diff options
author | samantharitter <samantha.ritter@10gen.com> | 2017-06-27 12:09:40 -0400 |
---|---|---|
committer | Jason Carey <jcarey@argv.me> | 2017-07-13 17:40:53 -0400 |
commit | e1cae24805e3e7282958ee67a01555dd6ce40039 (patch) | |
tree | ebce77d9a502a193784483b2201b65e1a5010d98 /src/mongo/db | |
parent | 9a49ee3a03e02597086e577f06a71a0723bc0582 (diff) | |
download | mongo-e1cae24805e3e7282958ee67a01555dd6ce40039.tar.gz |
SERVER-29610 Allow LogicalSessionIds to contain signed user information
Diffstat (limited to 'src/mongo/db')
37 files changed, 1707 insertions, 787 deletions
diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript index eba59679fc9..ae96b374c7f 100644 --- a/src/mongo/db/SConscript +++ b/src/mongo/db/SConscript @@ -448,6 +448,7 @@ env.Library( 'operation_context_group.cpp' ], LIBDEPS=[ + '$BUILD_DIR/mongo/db/logical_session_id', '$BUILD_DIR/mongo/util/clock_sources', '$BUILD_DIR/mongo/util/concurrency/spin_lock', '$BUILD_DIR/mongo/util/decorable', @@ -455,7 +456,6 @@ env.Library( '$BUILD_DIR/mongo/util/net/network', '$BUILD_DIR/mongo/util/periodic_runner', '$BUILD_DIR/mongo/transport/transport_layer_common', - 'logical_session_cache', ], ) @@ -894,6 +894,32 @@ env.CppUnitTest( ) env.Library( + target='signed_logical_session_id', + source=[ + 'signed_logical_session_id.cpp', + env.Idlc('signed_logical_session_id.idl')[0], + ], + LIBDEPS=[ + '$BUILD_DIR/mongo/base', + '$BUILD_DIR/mongo/crypto/sha1_block', + '$BUILD_DIR/mongo/idl/idl_parser', + '$BUILD_DIR/mongo/util/uuid', + 'logical_session_id' + ], +) + +env.CppUnitTest( + target='signed_logical_session_id_test', + source=[ + 'signed_logical_session_id_test.cpp', + ], + LIBDEPS=[ + '$BUILD_DIR/mongo/base', + 'signed_logical_session_id', + ], +) + +env.Library( target='logical_session_record', source=[ 'logical_session_record.cpp', @@ -903,7 +929,8 @@ env.Library( '$BUILD_DIR/mongo/base', '$BUILD_DIR/mongo/db/auth/user_name', '$BUILD_DIR/mongo/idl/idl_parser', - 'logical_session_id' + 'logical_session_id', + 'signed_logical_session_id', ], ) @@ -925,6 +952,12 @@ env.Library( 'service_liason.cpp', ], LIBDEPS=[ + '$BUILD_DIR/mongo/crypto/sha1_block', + 'keys_collection_document', + 'keys_collection_manager', + 'keys_collection_manager_zero', + 'logical_clock', + 'signed_logical_session_id', ], ) @@ -985,6 +1018,7 @@ env.Library( 'logical_session_cache.cpp', ], LIBDEPS=[ + 'logical_session_id', 'logical_session_record', 'sessions_collection', 'server_parameters', @@ -1000,9 +1034,13 @@ envWithAsio.CppUnitTest( LIBDEPS=[ '$BUILD_DIR/mongo/db/service_context_noop_init', '$BUILD_DIR/mongo/executor/async_timer_mock', + 'keys_collection_manager', + 'keys_collection_document', + 'logical_clock', 'logical_session_cache', 'service_liason_mock', 'sessions_collection_mock', + 'signed_logical_session_id', ], ) @@ -1079,12 +1117,46 @@ env.Library( target='keys_collection_manager', source=[ 'keys_collection_manager.cpp', + ], + LIBDEPS=[ + ], +) + +env.Library( + target='keys_collection_manager_direct', + source=[ + 'keys_collection_manager_direct.cpp', + ], + LIBDEPS=[ + 'dbdirectclient', + 'keys_collection_manager', + 'logical_clock', + 'logical_time', + ], +) + +env.Library( + target='keys_collection_manager_zero', + source=[ + 'keys_collection_manager_zero.cpp', + ], + LIBDEPS=[ + 'keys_collection_manager', + 'logical_time', + ], +) + +env.Library( + target='keys_collection_manager_sharding', + source=[ + 'keys_collection_manager_sharding.cpp', 'keys_collection_cache_reader.cpp', 'keys_collection_cache_reader_and_updater.cpp', ], LIBDEPS=[ 'logical_clock', 'keys_collection_document', + 'keys_collection_manager', 'logical_time', 'server_options', '$BUILD_DIR/mongo/s/catalog/sharding_catalog_client', @@ -1109,7 +1181,7 @@ env.Library( 'logical_time_validator.cpp', ], LIBDEPS=[ - 'keys_collection_manager', + 'keys_collection_manager_sharding', 'service_context', 'signed_logical_time', 'time_proof_service', @@ -1167,6 +1239,7 @@ env.CppUnitTest( 'logical_time_validator_test.cpp', ], LIBDEPS=[ + 'keys_collection_manager_sharding', 'logical_time_validator', '$BUILD_DIR/mongo/s/config_server_test_fixture', '$BUILD_DIR/mongo/s/coreshard', @@ -1233,12 +1306,12 @@ env.CppUnitTest( ) env.CppUnitTest( - target='keys_collection_manager_test', + target='keys_collection_manager_sharding_test', source=[ - 'keys_collection_manager_test.cpp', + 'keys_collection_manager_sharding_test.cpp', ], LIBDEPS=[ - 'keys_collection_manager', + 'keys_collection_manager_sharding', '$BUILD_DIR/mongo/s/config_server_test_fixture', ], ) diff --git a/src/mongo/db/db.cpp b/src/mongo/db/db.cpp index 64446785910..569958ccbad 100644 --- a/src/mongo/db/db.cpp +++ b/src/mongo/db/db.cpp @@ -73,6 +73,7 @@ #include "mongo/db/initialize_snmp.h" #include "mongo/db/introspect.h" #include "mongo/db/json.h" +#include "mongo/db/keys_collection_manager.h" #include "mongo/db/log_process_details.h" #include "mongo/db/logical_clock.h" #include "mongo/db/logical_session_cache.h" @@ -724,7 +725,7 @@ ExitCode _initAndListen(int listenPort) { } auto sessionCache = makeLogicalSessionCacheD(kind); - globalServiceContext->setLogicalSessionCache(std::move(sessionCache)); + LogicalSessionCache::set(globalServiceContext, std::move(sessionCache)); // MessageServer::run will return when exit code closes its socket and we don't need the // operation context anymore diff --git a/src/mongo/db/keys_collection_manager.cpp b/src/mongo/db/keys_collection_manager.cpp index e91d9af8dad..bed56f87847 100644 --- a/src/mongo/db/keys_collection_manager.cpp +++ b/src/mongo/db/keys_collection_manager.cpp @@ -30,317 +30,10 @@ #include "mongo/db/keys_collection_manager.h" -#include "mongo/db/keys_collection_cache_reader.h" -#include "mongo/db/keys_collection_cache_reader_and_updater.h" -#include "mongo/db/logical_clock.h" -#include "mongo/db/logical_time.h" -#include "mongo/db/operation_context.h" -#include "mongo/db/server_options.h" -#include "mongo/db/service_context.h" -#include "mongo/stdx/memory.h" -#include "mongo/util/concurrency/idle_thread_block.h" -#include "mongo/util/fail_point_service.h" -#include "mongo/util/mongoutils/str.h" -#include "mongo/util/time_support.h" - namespace mongo { const Seconds KeysCollectionManager::kKeyValidInterval{3 * 30 * 24 * 60 * 60}; // ~3 months -namespace { - -Milliseconds kDefaultRefreshWaitTime(30 * 1000); -Milliseconds kRefreshIntervalIfErrored(200); -Milliseconds kMaxRefreshWaitTime(10 * 60 * 1000); - -// Prevents the refresher thread from waiting longer than the given number of milliseconds, even on -// a successful refresh. -MONGO_FP_DECLARE(maxKeyRefreshWaitTimeOverrideMS); - -/** - * Returns the amount of time to wait until the monitoring thread should attempt to refresh again. - */ -Milliseconds howMuchSleepNeedFor(const LogicalTime& currentTime, - const LogicalTime& latestExpiredAt, - const Milliseconds& interval) { - auto currentSecs = currentTime.asTimestamp().getSecs(); - auto expiredSecs = latestExpiredAt.asTimestamp().getSecs(); - - if (currentSecs >= expiredSecs) { - // This means that the last round didn't generate a usable key for the current time. - // However, we don't want to poll too hard as well, so use a low interval. - return kRefreshIntervalIfErrored; - } - - auto millisBeforeExpire = 1000 * (expiredSecs - currentSecs); - - if (interval.count() <= millisBeforeExpire) { - return interval; - } - - return Milliseconds(millisBeforeExpire); -} - -} // unnamed namespace - -KeysCollectionManager::KeysCollectionManager(std::string purpose, - ShardingCatalogClient* client, - Seconds keyValidForInterval) - : _purpose(std::move(purpose)), - _keyValidForInterval(keyValidForInterval), - _catalogClient(client), - _keysCache(_purpose, client) {} - -StatusWith<KeysCollectionDocument> KeysCollectionManager::getKeyForValidation( - OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) { - auto keyStatus = _getKeyWithKeyIdCheck(keyId, forThisTime); - - if (keyStatus != ErrorCodes::KeyNotFound) { - return keyStatus; - } - - _refresher.refreshNow(opCtx); - - return _getKeyWithKeyIdCheck(keyId, forThisTime); -} - -StatusWith<KeysCollectionDocument> KeysCollectionManager::getKeyForSigning( - const LogicalTime& forThisTime) { - return _getKey(forThisTime); -} - -StatusWith<KeysCollectionDocument> KeysCollectionManager::_getKeyWithKeyIdCheck( - long long keyId, const LogicalTime& forThisTime) { - auto keyStatus = _keysCache.getKeyById(keyId, forThisTime); - - if (!keyStatus.isOK()) { - return keyStatus; - } - - return keyStatus.getValue(); -} - -StatusWith<KeysCollectionDocument> KeysCollectionManager::_getKey(const LogicalTime& forThisTime) { - auto keyStatus = _keysCache.getKey(forThisTime); - - if (!keyStatus.isOK()) { - return keyStatus; - } - - const auto& key = keyStatus.getValue(); - - if (key.getExpiresAt() < forThisTime) { - return {ErrorCodes::KeyNotFound, - str::stream() << "No keys found for " << _purpose << " that is valid for " - << forThisTime.toString()}; - } - - return key; -} - -void KeysCollectionManager::refreshNow(OperationContext* opCtx) { - _refresher.refreshNow(opCtx); -} - -void KeysCollectionManager::startMonitoring(ServiceContext* service) { - _refresher.setFunc([this](OperationContext* opCtx) { return _keysCache.refresh(opCtx); }); - _refresher.start( - service, str::stream() << "monitoring keys for " << _purpose, _keyValidForInterval); -} - -void KeysCollectionManager::stopMonitoring() { - _refresher.stop(); -} - -void KeysCollectionManager::enableKeyGenerator(OperationContext* opCtx, bool doEnable) { - if (doEnable) { - _refresher.switchFunc(opCtx, [this](OperationContext* opCtx) { - KeysCollectionCacheReaderAndUpdater keyGenerator( - _purpose, _catalogClient, _keyValidForInterval); - auto keyGenerationStatus = keyGenerator.refresh(opCtx); - - if (ErrorCodes::isShutdownError(keyGenerationStatus.getStatus().code())) { - return keyGenerationStatus; - } - - // An error encountered by the keyGenerator should not prevent refreshing the cache - auto cacheRefreshStatus = _keysCache.refresh(opCtx); - - if (!keyGenerationStatus.isOK()) { - return keyGenerationStatus; - } - - return cacheRefreshStatus; - }); - } else { - _refresher.switchFunc( - opCtx, [this](OperationContext* opCtx) { return _keysCache.refresh(opCtx); }); - } -} - -bool KeysCollectionManager::hasSeenKeys() { - return _refresher.hasSeenKeys(); -} - -void KeysCollectionManager::PeriodicRunner::refreshNow(OperationContext* opCtx) { - auto refreshRequest = [this]() { - stdx::lock_guard<stdx::mutex> lk(_mutex); - - if (_inShutdown) { - throw DBException("aborting keys cache refresh because node is shutting down", - ErrorCodes::ShutdownInProgress); - } - - if (_refreshRequest) { - return _refreshRequest; - } - - _refreshNeededCV.notify_all(); - _refreshRequest = std::make_shared<Notification<void>>(); - return _refreshRequest; - }(); - - // note: waitFor waits min(maxTimeMS, kDefaultRefreshWaitTime). - // waitFor also throws if timeout, so also throw when notification was not satisfied after - // waiting. - if (!refreshRequest->waitFor(opCtx, kDefaultRefreshWaitTime)) { - throw DBException("timed out waiting for refresh", ErrorCodes::ExceededTimeLimit); - } -} - -void KeysCollectionManager::PeriodicRunner::_doPeriodicRefresh(ServiceContext* service, - std::string threadName, - Milliseconds refreshInterval) { - Client::initThreadIfNotAlready(threadName); - - while (true) { - auto opCtx = cc().makeOperationContext(); - - bool hasRefreshRequestInitially = false; - unsigned errorCount = 0; - std::shared_ptr<RefreshFunc> doRefresh; - { - stdx::lock_guard<stdx::mutex> lock(_mutex); - - if (_inShutdown) { - break; - } - - invariant(_doRefresh.get() != nullptr); - doRefresh = _doRefresh; - hasRefreshRequestInitially = _refreshRequest.get() != nullptr; - } - - Milliseconds nextWakeup = kRefreshIntervalIfErrored; - - // No need to refresh keys in FCV 3.4, since key generation will be disabled. - if (serverGlobalParams.featureCompatibility.version.load() != - ServerGlobalParams::FeatureCompatibility::Version::k34) { - auto latestKeyStatusWith = (*doRefresh)(opCtx.get()); - if (latestKeyStatusWith.getStatus().isOK()) { - errorCount = 0; - const auto& latestKey = latestKeyStatusWith.getValue(); - auto currentTime = LogicalClock::get(service)->getClusterTime(); - - { - stdx::unique_lock<stdx::mutex> lock(_mutex); - _hasSeenKeys = true; - } - - nextWakeup = - howMuchSleepNeedFor(currentTime, latestKey.getExpiresAt(), refreshInterval); - } else { - errorCount += 1; - nextWakeup = Milliseconds(kRefreshIntervalIfErrored.count() * errorCount); - if (nextWakeup > kMaxRefreshWaitTime) { - nextWakeup = kMaxRefreshWaitTime; - } - } - } else { - nextWakeup = kDefaultRefreshWaitTime; - } - - MONGO_FAIL_POINT_BLOCK(maxKeyRefreshWaitTimeOverrideMS, data) { - const BSONObj& dataObj = data.getData(); - auto overrideMS = Milliseconds(dataObj["overrideMS"].numberInt()); - if (nextWakeup > overrideMS) { - nextWakeup = overrideMS; - } - } - - stdx::unique_lock<stdx::mutex> lock(_mutex); - - if (_refreshRequest) { - if (!hasRefreshRequestInitially) { - // A fresh request came in, fulfill the request before going to sleep. - continue; - } - - _refreshRequest->set(); - _refreshRequest.reset(); - } - - if (_inShutdown) { - break; - } - - MONGO_IDLE_THREAD_BLOCK; - auto sleepStatus = opCtx->waitForConditionOrInterruptNoAssertUntil( - _refreshNeededCV, lock, Date_t::now() + nextWakeup); - - if (ErrorCodes::isShutdownError(sleepStatus.getStatus().code())) { - break; - } - } - - stdx::unique_lock<stdx::mutex> lock(_mutex); - if (_refreshRequest) { - _refreshRequest->set(); - _refreshRequest.reset(); - } -} - -void KeysCollectionManager::PeriodicRunner::setFunc(RefreshFunc newRefreshStrategy) { - stdx::lock_guard<stdx::mutex> lock(_mutex); - _doRefresh = std::make_shared<RefreshFunc>(std::move(newRefreshStrategy)); -} - -void KeysCollectionManager::PeriodicRunner::switchFunc(OperationContext* opCtx, - RefreshFunc newRefreshStrategy) { - setFunc(newRefreshStrategy); -} - -void KeysCollectionManager::PeriodicRunner::start(ServiceContext* service, - const std::string& threadName, - Milliseconds refreshInterval) { - stdx::lock_guard<stdx::mutex> lock(_mutex); - invariant(!_backgroundThread.joinable()); - invariant(!_inShutdown); - - _backgroundThread = - stdx::thread(stdx::bind(&KeysCollectionManager::PeriodicRunner::_doPeriodicRefresh, - this, - service, - threadName, - refreshInterval)); -} - -void KeysCollectionManager::PeriodicRunner::stop() { - { - stdx::lock_guard<stdx::mutex> lock(_mutex); - if (!_backgroundThread.joinable()) { - return; - } - - _inShutdown = true; - _refreshNeededCV.notify_all(); - } - - _backgroundThread.join(); -} -bool KeysCollectionManager::PeriodicRunner::hasSeenKeys() { - stdx::lock_guard<stdx::mutex> lock(_mutex); - return _hasSeenKeys; -} +KeysCollectionManager::~KeysCollectionManager() = default; } // namespace mongo diff --git a/src/mongo/db/keys_collection_manager.h b/src/mongo/db/keys_collection_manager.h index 53de6257b32..87b57ef1250 100644 --- a/src/mongo/db/keys_collection_manager.h +++ b/src/mongo/db/keys_collection_manager.h @@ -31,21 +31,13 @@ #include <memory> #include "mongo/base/status_with.h" -#include "mongo/db/keys_collection_cache_reader.h" -#include "mongo/db/keys_collection_cache_reader_and_updater.h" #include "mongo/db/keys_collection_document.h" -#include "mongo/stdx/functional.h" -#include "mongo/stdx/mutex.h" -#include "mongo/stdx/thread.h" -#include "mongo/util/concurrency/notification.h" -#include "mongo/util/duration.h" namespace mongo { class OperationContext; class LogicalTime; class ServiceContext; -class ShardingCatalogClient; /** * This is responsible for providing keys that can be used for HMAC computation. This also supports @@ -55,19 +47,16 @@ class KeysCollectionManager { public: static const Seconds kKeyValidInterval; - KeysCollectionManager(std::string purpose, - ShardingCatalogClient* client, - Seconds keyValidForInterval); + virtual ~KeysCollectionManager(); /** * Return a key that is valid for the given time and also matches the keyId. Note that this call - * can block if it will need to do a refresh. + * can block if it will need to do a refresh and we are on a sharded cluster. * * Throws ErrorCode::ExceededTimeLimit if it times out. */ - StatusWith<KeysCollectionDocument> getKeyForValidation(OperationContext* opCtx, - long long keyId, - const LogicalTime& forThisTime); + virtual StatusWith<KeysCollectionDocument> getKeyForValidation( + OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) = 0; /** * Returns a key that is valid for the given time. Note that unlike getKeyForValidation, this @@ -75,120 +64,8 @@ public: * * Throws ErrorCode::ExceededTimeLimit if it times out. */ - StatusWith<KeysCollectionDocument> getKeyForSigning(const LogicalTime& forThisTime); - - /** - * Request this manager to perform a refresh. - */ - void refreshNow(OperationContext* opCtx); - - /** - * Starts a background thread that will constantly update the internal cache of keys. - * - * Cannot call this after stopMonitoring was called at least once. - */ - void startMonitoring(ServiceContext* service); - - /** - * Stops the background thread updating the cache. - */ - void stopMonitoring(); - - /** - * Enable writing new keys to the config server primary. Should only be called if current node - * is the config primary. - */ - void enableKeyGenerator(OperationContext* opCtx, bool doEnable); - - /** - * Returns true if the refresher has ever successfully returned keys from the config server. - */ - bool hasSeenKeys(); - -private: - /** - * This is responsible for periodically performing refresh in the background. - */ - class PeriodicRunner { - public: - using RefreshFunc = stdx::function<StatusWith<KeysCollectionDocument>(OperationContext*)>; - - /** - * Preemptively inform the monitoring thread it needs to perform a refresh. Returns an - * object - * that gets notified after the current round of refresh is over. Note that being notified - * can - * mean either of these things: - * - * 1. An error occurred and refresh was not performed. - * 2. No error occurred but no new key was found. - * 3. No error occurred and new keys were found. - */ - void refreshNow(OperationContext* opCtx); - - /** - * Sets the refresh function to use. - * Should only be used to bootstrap this refresher with initial strategy. - */ - void setFunc(RefreshFunc newRefreshStrategy); - - /** - * Switches the current strategy with a new one. This also waits to make sure that the old - * strategy is not being used and will no longer be used after this call. - */ - void switchFunc(OperationContext* opCtx, RefreshFunc newRefreshStrategy); - - /** - * Starts the refresh thread. - */ - void start(ServiceContext* service, - const std::string& threadName, - Milliseconds refreshInterval); - - /** - * Stops the refresh thread. - */ - void stop(); - - /** - * Returns true if keys have ever successfully been returned from the config server. - */ - bool hasSeenKeys(); - - private: - void _doPeriodicRefresh(ServiceContext* service, - std::string threadName, - Milliseconds refreshInterval); - - stdx::mutex _mutex; // protects all the member variables below. - std::shared_ptr<Notification<void>> _refreshRequest; - stdx::condition_variable _refreshNeededCV; - - stdx::thread _backgroundThread; - std::shared_ptr<RefreshFunc> _doRefresh; - - bool _hasSeenKeys = false; - bool _inShutdown = false; - }; - - /** - * Return a key that is valid for the given time and also matches the keyId. - */ - StatusWith<KeysCollectionDocument> _getKeyWithKeyIdCheck(long long keyId, - const LogicalTime& forThisTime); - - /** - * Return a key that is valid for the given time. - */ - StatusWith<KeysCollectionDocument> _getKey(const LogicalTime& forThisTime); - - const std::string _purpose; - const Seconds _keyValidForInterval; - ShardingCatalogClient* _catalogClient; - - // No mutex needed since the members below have their own mutexes. - KeysCollectionCacheReader _keysCache; - PeriodicRunner _refresher; + virtual StatusWith<KeysCollectionDocument> getKeyForSigning(OperationContext* opCtx, + const LogicalTime& forThisTime) = 0; }; } // namespace mongo diff --git a/src/mongo/db/keys_collection_manager_direct.cpp b/src/mongo/db/keys_collection_manager_direct.cpp new file mode 100644 index 00000000000..60cf5672e6e --- /dev/null +++ b/src/mongo/db/keys_collection_manager_direct.cpp @@ -0,0 +1,137 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/keys_collection_manager_direct.h" + +#include "mongo/db/dbdirectclient.h" +#include "mongo/db/logical_clock.h" +#include "mongo/db/logical_time.h" +#include "mongo/db/operation_context.h" +#include "mongo/db/server_options.h" +#include "mongo/db/service_context.h" + +namespace mongo { + +namespace { +const char kLogicalTimeKeysCollection[] = "admin.system.keys"; +const int kMaxCachedKeys = 20; +} // namespace + +KeysCollectionManagerDirect::KeysCollectionManagerDirect(std::string purpose, + Seconds keyValidForInterval) + : _purpose(std::move(purpose)), + _keyValidForInterval(keyValidForInterval), + _cache(kMaxCachedKeys) {} + +StatusWith<KeysCollectionDocument> KeysCollectionManagerDirect::getKeyForValidation( + OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) { + // First, attempt to find the key in our cache. + { + stdx::lock_guard<stdx::mutex> lk(_mutex); + auto it = _cache.find(keyId); + if (it != _cache.end()) { + return it->second; + } + } + + // Query admin.system.keys for an active key with this id. + DBDirectClient client(opCtx); + + BSONObjBuilder queryBuilder; + queryBuilder.append("purpose", _purpose); + queryBuilder.append("_id", keyId); + queryBuilder.append("expiresAt", BSON("$gt" << forThisTime.asTimestamp())); + + auto cursor = client.query(KeysCollectionDocument::ConfigNS, queryBuilder.obj()); + + if (!cursor->more()) { + return {ErrorCodes::KeyNotFound, "Could not find matching key"}; + } + + // Parse the key. + auto res = KeysCollectionDocument::fromBSON(cursor->next()); + if (!res.isOK()) { + return res.getStatus(); + } + + // Add to our cache. + { + stdx::lock_guard<stdx::mutex> lk(_mutex); + _cache.add(keyId, res.getValue()); + } + + return res.getValue(); +} + +StatusWith<KeysCollectionDocument> KeysCollectionManagerDirect::getKeyForSigning( + OperationContext* opCtx, const LogicalTime& forThisTime) { + // Search through the cache for active keys. + { + stdx::lock_guard<stdx::mutex> lk(_mutex); + for (auto& it : _cache) { + auto keyDoc = it.second; + auto expiration = keyDoc.getExpiresAt(); + if (expiration > forThisTime) { + return keyDoc; + } + } + } + + // Query admin.system.keys for active keys. + DBDirectClient client(opCtx); + + BSONObjBuilder queryBuilder; + queryBuilder.append("purpose", _purpose); + queryBuilder.append("expiresAt", BSON("$gt" << forThisTime.asTimestamp())); + + auto cursor = client.query(KeysCollectionDocument::ConfigNS, queryBuilder.obj()); + + if (!cursor->more()) { + return {ErrorCodes::KeyNotFound, "Could not find an active key for signing"}; + } + + // Parse and return the key. + auto res = KeysCollectionDocument::fromBSON(cursor->next()); + if (!res.isOK()) { + return res.getStatus(); + } + + auto keyDoc = res.getValue(); + + // Add to our cache. + { + stdx::lock_guard<stdx::mutex> lk(_mutex); + _cache.add(keyDoc.getKeyId(), keyDoc); + } + + return keyDoc; +} + +} // namespace mongo diff --git a/src/mongo/db/keys_collection_manager_direct.h b/src/mongo/db/keys_collection_manager_direct.h new file mode 100644 index 00000000000..2f1228ff056 --- /dev/null +++ b/src/mongo/db/keys_collection_manager_direct.h @@ -0,0 +1,73 @@ +/** + * Copyright (C) 2017 MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <memory> + +#include "mongo/db/keys_collection_document.h" +#include "mongo/db/keys_collection_manager.h" +#include "mongo/stdx/mutex.h" +#include "mongo/util/lru_cache.h" + +namespace mongo { + +class OperationContext; +class LogicalTime; +class ServiceContext; + +/** + * This implementation of the KeysCollectionManager uses DBDirectclient to query the + * keys collection local to this server. + */ +class KeysCollectionManagerDirect : public KeysCollectionManager { +public: + KeysCollectionManagerDirect(std::string purpose, Seconds keyValidForInterval); + + /** + * Return a key that is valid for the given time and also matches the keyId. + */ + StatusWith<KeysCollectionDocument> getKeyForValidation(OperationContext* opCtx, + long long keyId, + const LogicalTime& forThisTime) override; + + /** + * Returns a key that is valid for the given time. + */ + StatusWith<KeysCollectionDocument> getKeyForSigning(OperationContext* opCtx, + const LogicalTime& forThisTime) override; + +private: + const std::string _purpose; + const Seconds _keyValidForInterval; + + stdx::mutex _mutex; + LRUCache<long long, KeysCollectionDocument> _cache; +}; + +} // namespace mongo diff --git a/src/mongo/db/keys_collection_manager_sharding.cpp b/src/mongo/db/keys_collection_manager_sharding.cpp new file mode 100644 index 00000000000..c43779cbcbe --- /dev/null +++ b/src/mongo/db/keys_collection_manager_sharding.cpp @@ -0,0 +1,345 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/keys_collection_manager_sharding.h" + +#include "mongo/db/keys_collection_cache_reader.h" +#include "mongo/db/keys_collection_cache_reader_and_updater.h" +#include "mongo/db/logical_clock.h" +#include "mongo/db/logical_time.h" +#include "mongo/db/operation_context.h" +#include "mongo/db/server_options.h" +#include "mongo/db/service_context.h" +#include "mongo/stdx/memory.h" +#include "mongo/util/concurrency/idle_thread_block.h" +#include "mongo/util/fail_point_service.h" +#include "mongo/util/mongoutils/str.h" +#include "mongo/util/time_support.h" + +namespace mongo { + +namespace { + +Milliseconds kDefaultRefreshWaitTime(30 * 1000); +Milliseconds kRefreshIntervalIfErrored(200); +Milliseconds kMaxRefreshWaitTime(10 * 60 * 1000); + +// Prevents the refresher thread from waiting longer than the given number of milliseconds, even on +// a successful refresh. +MONGO_FP_DECLARE(maxKeyRefreshWaitTimeOverrideMS); + +/** + * Returns the amount of time to wait until the monitoring thread should attempt to refresh again. + */ +Milliseconds howMuchSleepNeedFor(const LogicalTime& currentTime, + const LogicalTime& latestExpiredAt, + const Milliseconds& interval) { + auto currentSecs = currentTime.asTimestamp().getSecs(); + auto expiredSecs = latestExpiredAt.asTimestamp().getSecs(); + + if (currentSecs >= expiredSecs) { + // This means that the last round didn't generate a usable key for the current time. + // However, we don't want to poll too hard as well, so use a low interval. + return kRefreshIntervalIfErrored; + } + + auto millisBeforeExpire = 1000 * (expiredSecs - currentSecs); + + if (interval.count() <= millisBeforeExpire) { + return interval; + } + + return Milliseconds(millisBeforeExpire); +} + +} // unnamed namespace + +KeysCollectionManagerSharding::KeysCollectionManagerSharding(std::string purpose, + ShardingCatalogClient* client, + Seconds keyValidForInterval) + : _purpose(std::move(purpose)), + _keyValidForInterval(keyValidForInterval), + _catalogClient(client), + _keysCache(_purpose, client) {} + +StatusWith<KeysCollectionDocument> KeysCollectionManagerSharding::getKeyForValidation( + OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) { + auto keyStatus = _getKeyWithKeyIdCheck(keyId, forThisTime); + + if (keyStatus != ErrorCodes::KeyNotFound) { + return keyStatus; + } + + _refresher.refreshNow(opCtx); + + return _getKeyWithKeyIdCheck(keyId, forThisTime); +} + +StatusWith<KeysCollectionDocument> KeysCollectionManagerSharding::getKeyForSigning( + OperationContext* opCtx, const LogicalTime& forThisTime) { + return _getKey(forThisTime); +} + +StatusWith<KeysCollectionDocument> KeysCollectionManagerSharding::_getKeyWithKeyIdCheck( + long long keyId, const LogicalTime& forThisTime) { + auto keyStatus = _keysCache.getKeyById(keyId, forThisTime); + + if (!keyStatus.isOK()) { + return keyStatus; + } + + return keyStatus.getValue(); +} + +StatusWith<KeysCollectionDocument> KeysCollectionManagerSharding::_getKey( + const LogicalTime& forThisTime) { + auto keyStatus = _keysCache.getKey(forThisTime); + + if (!keyStatus.isOK()) { + return keyStatus; + } + + const auto& key = keyStatus.getValue(); + + if (key.getExpiresAt() < forThisTime) { + return {ErrorCodes::KeyNotFound, + str::stream() << "No keys found for " << _purpose << " that is valid for " + << forThisTime.toString()}; + } + + return key; +} + +void KeysCollectionManagerSharding::refreshNow(OperationContext* opCtx) { + _refresher.refreshNow(opCtx); +} + +void KeysCollectionManagerSharding::startMonitoring(ServiceContext* service) { + _refresher.setFunc([this](OperationContext* opCtx) { return _keysCache.refresh(opCtx); }); + _refresher.start( + service, str::stream() << "monitoring keys for " << _purpose, _keyValidForInterval); +} + +void KeysCollectionManagerSharding::stopMonitoring() { + _refresher.stop(); +} + +void KeysCollectionManagerSharding::enableKeyGenerator(OperationContext* opCtx, bool doEnable) { + if (doEnable) { + _refresher.switchFunc(opCtx, [this](OperationContext* opCtx) { + KeysCollectionCacheReaderAndUpdater keyGenerator( + _purpose, _catalogClient, _keyValidForInterval); + auto keyGenerationStatus = keyGenerator.refresh(opCtx); + + if (ErrorCodes::isShutdownError(keyGenerationStatus.getStatus().code())) { + return keyGenerationStatus; + } + + // An error encountered by the keyGenerator should not prevent refreshing the cache + auto cacheRefreshStatus = _keysCache.refresh(opCtx); + + if (!keyGenerationStatus.isOK()) { + return keyGenerationStatus; + } + + return cacheRefreshStatus; + }); + } else { + _refresher.switchFunc( + opCtx, [this](OperationContext* opCtx) { return _keysCache.refresh(opCtx); }); + } +} + +bool KeysCollectionManagerSharding::hasSeenKeys() { + return _refresher.hasSeenKeys(); +} + +void KeysCollectionManagerSharding::PeriodicRunner::refreshNow(OperationContext* opCtx) { + auto refreshRequest = [this]() { + stdx::lock_guard<stdx::mutex> lk(_mutex); + + if (_inShutdown) { + throw DBException("aborting keys cache refresh because node is shutting down", + ErrorCodes::ShutdownInProgress); + } + + if (_refreshRequest) { + return _refreshRequest; + } + + _refreshNeededCV.notify_all(); + _refreshRequest = std::make_shared<Notification<void>>(); + return _refreshRequest; + }(); + + // note: waitFor waits min(maxTimeMS, kDefaultRefreshWaitTime). + // waitFor also throws if timeout, so also throw when notification was not satisfied after + // waiting. + if (!refreshRequest->waitFor(opCtx, kDefaultRefreshWaitTime)) { + throw DBException("timed out waiting for refresh", ErrorCodes::ExceededTimeLimit); + } +} + +void KeysCollectionManagerSharding::PeriodicRunner::_doPeriodicRefresh( + ServiceContext* service, std::string threadName, Milliseconds refreshInterval) { + Client::initThreadIfNotAlready(threadName); + + while (true) { + auto opCtx = cc().makeOperationContext(); + + bool hasRefreshRequestInitially = false; + unsigned errorCount = 0; + std::shared_ptr<RefreshFunc> doRefresh; + { + stdx::lock_guard<stdx::mutex> lock(_mutex); + + if (_inShutdown) { + break; + } + + invariant(_doRefresh.get() != nullptr); + doRefresh = _doRefresh; + hasRefreshRequestInitially = _refreshRequest.get() != nullptr; + } + + Milliseconds nextWakeup = kRefreshIntervalIfErrored; + + // No need to refresh keys in FCV 3.4, since key generation will be disabled. + if (serverGlobalParams.featureCompatibility.version.load() != + ServerGlobalParams::FeatureCompatibility::Version::k34) { + auto latestKeyStatusWith = (*doRefresh)(opCtx.get()); + if (latestKeyStatusWith.getStatus().isOK()) { + errorCount = 0; + const auto& latestKey = latestKeyStatusWith.getValue(); + auto currentTime = LogicalClock::get(service)->getClusterTime(); + + { + stdx::unique_lock<stdx::mutex> lock(_mutex); + _hasSeenKeys = true; + } + + nextWakeup = + howMuchSleepNeedFor(currentTime, latestKey.getExpiresAt(), refreshInterval); + } else { + errorCount += 1; + nextWakeup = Milliseconds(kRefreshIntervalIfErrored.count() * errorCount); + if (nextWakeup > kMaxRefreshWaitTime) { + nextWakeup = kMaxRefreshWaitTime; + } + } + } else { + nextWakeup = kDefaultRefreshWaitTime; + } + + MONGO_FAIL_POINT_BLOCK(maxKeyRefreshWaitTimeOverrideMS, data) { + const BSONObj& dataObj = data.getData(); + auto overrideMS = Milliseconds(dataObj["overrideMS"].numberInt()); + if (nextWakeup > overrideMS) { + nextWakeup = overrideMS; + } + } + + stdx::unique_lock<stdx::mutex> lock(_mutex); + + if (_refreshRequest) { + if (!hasRefreshRequestInitially) { + // A fresh request came in, fulfill the request before going to sleep. + continue; + } + + _refreshRequest->set(); + _refreshRequest.reset(); + } + + if (_inShutdown) { + break; + } + + MONGO_IDLE_THREAD_BLOCK; + auto sleepStatus = opCtx->waitForConditionOrInterruptNoAssertUntil( + _refreshNeededCV, lock, Date_t::now() + nextWakeup); + + if (ErrorCodes::isShutdownError(sleepStatus.getStatus().code())) { + break; + } + } + + stdx::unique_lock<stdx::mutex> lock(_mutex); + if (_refreshRequest) { + _refreshRequest->set(); + _refreshRequest.reset(); + } +} + +void KeysCollectionManagerSharding::PeriodicRunner::setFunc(RefreshFunc newRefreshStrategy) { + stdx::lock_guard<stdx::mutex> lock(_mutex); + _doRefresh = std::make_shared<RefreshFunc>(std::move(newRefreshStrategy)); +} + +void KeysCollectionManagerSharding::PeriodicRunner::switchFunc(OperationContext* opCtx, + RefreshFunc newRefreshStrategy) { + setFunc(newRefreshStrategy); +} + +void KeysCollectionManagerSharding::PeriodicRunner::start(ServiceContext* service, + const std::string& threadName, + Milliseconds refreshInterval) { + stdx::lock_guard<stdx::mutex> lock(_mutex); + invariant(!_backgroundThread.joinable()); + invariant(!_inShutdown); + + _backgroundThread = + stdx::thread(stdx::bind(&KeysCollectionManagerSharding::PeriodicRunner::_doPeriodicRefresh, + this, + service, + threadName, + refreshInterval)); +} + +void KeysCollectionManagerSharding::PeriodicRunner::stop() { + { + stdx::lock_guard<stdx::mutex> lock(_mutex); + if (!_backgroundThread.joinable()) { + return; + } + + _inShutdown = true; + _refreshNeededCV.notify_all(); + } + + _backgroundThread.join(); +} + +bool KeysCollectionManagerSharding::PeriodicRunner::hasSeenKeys() { + stdx::lock_guard<stdx::mutex> lock(_mutex); + return _hasSeenKeys; +} + +} // namespace mongo diff --git a/src/mongo/db/keys_collection_manager_sharding.h b/src/mongo/db/keys_collection_manager_sharding.h new file mode 100644 index 00000000000..82f358ebe18 --- /dev/null +++ b/src/mongo/db/keys_collection_manager_sharding.h @@ -0,0 +1,197 @@ +/** + * Copyright (C) 2017 MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <memory> + +#include "mongo/base/status_with.h" +#include "mongo/db/keys_collection_cache_reader.h" +#include "mongo/db/keys_collection_cache_reader_and_updater.h" +#include "mongo/db/keys_collection_document.h" +#include "mongo/db/keys_collection_manager.h" +#include "mongo/stdx/functional.h" +#include "mongo/stdx/mutex.h" +#include "mongo/stdx/thread.h" +#include "mongo/util/concurrency/notification.h" +#include "mongo/util/duration.h" + +namespace mongo { + +class OperationContext; +class LogicalTime; +class ServiceContext; +class ShardingCatalogClient; + +/** + * This implementation of the KeysCollectionManager queries the config servers for keys. + * It maintains in internal background thread that is used to periodically refresh + * the local key cache against the keys collection stored on the config servers. + */ +class KeysCollectionManagerSharding : public KeysCollectionManager { +public: + static const Seconds kKeyValidInterval; + + KeysCollectionManagerSharding(std::string purpose, + ShardingCatalogClient* client, + Seconds keyValidForInterval); + + /** + * Return a key that is valid for the given time and also matches the keyId. Note that this call + * can block if it will need to do a refresh. + * + * Throws ErrorCode::ExceededTimeLimit if it times out. + */ + StatusWith<KeysCollectionDocument> getKeyForValidation(OperationContext* opCtx, + long long keyId, + const LogicalTime& forThisTime) override; + + /** + * Returns a key that is valid for the given time. Note that unlike getKeyForValidation, this + * will never do a refresh. + * + * Throws ErrorCode::ExceededTimeLimit if it times out. + */ + StatusWith<KeysCollectionDocument> getKeyForSigning(OperationContext* opCtx, + const LogicalTime& forThisTime) override; + + /** + * Request this manager to perform a refresh. + */ + void refreshNow(OperationContext* opCtx); + + /** + * Starts a background thread that will constantly update the internal cache of keys. + * + * Cannot call this after stopMonitoring was called at least once. + */ + void startMonitoring(ServiceContext* service); + + /** + * Stops the background thread updating the cache. + */ + void stopMonitoring(); + + /** + * Enable writing new keys to the config server primary. Should only be called if current node + * is the config primary. + */ + void enableKeyGenerator(OperationContext* opCtx, bool doEnable); + + /** + * Returns true if the refresher has ever successfully returned keys from the config server. + */ + bool hasSeenKeys(); + +private: + /** + * This is responsible for periodically performing refresh in the background. + */ + class PeriodicRunner { + public: + using RefreshFunc = stdx::function<StatusWith<KeysCollectionDocument>(OperationContext*)>; + + /** + * Preemptively inform the monitoring thread it needs to perform a refresh. Returns an + * object + * that gets notified after the current round of refresh is over. Note that being notified + * can + * mean either of these things: + * + * 1. An error occurred and refresh was not performed. + * 2. No error occurred but no new key was found. + * 3. No error occurred and new keys were found. + */ + void refreshNow(OperationContext* opCtx); + + /** + * Sets the refresh function to use. + * Should only be used to bootstrap this refresher with initial strategy. + */ + void setFunc(RefreshFunc newRefreshStrategy); + + /** + * Switches the current strategy with a new one. This also waits to make sure that the old + * strategy is not being used and will no longer be used after this call. + */ + void switchFunc(OperationContext* opCtx, RefreshFunc newRefreshStrategy); + + /** + * Starts the refresh thread. + */ + void start(ServiceContext* service, + const std::string& threadName, + Milliseconds refreshInterval); + + /** + * Stops the refresh thread. + */ + void stop(); + + /** + * Returns true if keys have ever successfully been returned from the config server. + */ + bool hasSeenKeys(); + + private: + void _doPeriodicRefresh(ServiceContext* service, + std::string threadName, + Milliseconds refreshInterval); + + stdx::mutex _mutex; // protects all the member variables below. + std::shared_ptr<Notification<void>> _refreshRequest; + stdx::condition_variable _refreshNeededCV; + + stdx::thread _backgroundThread; + std::shared_ptr<RefreshFunc> _doRefresh; + + bool _hasSeenKeys = false; + bool _inShutdown = false; + }; + + /** + * Return a key that is valid for the given time and also matches the keyId. + */ + StatusWith<KeysCollectionDocument> _getKeyWithKeyIdCheck(long long keyId, + const LogicalTime& forThisTime); + + /** + * Return a key that is valid for the given time. + */ + StatusWith<KeysCollectionDocument> _getKey(const LogicalTime& forThisTime); + + const std::string _purpose; + const Seconds _keyValidForInterval; + ShardingCatalogClient* _catalogClient; + + // No mutex needed since the members below have their own mutexes. + KeysCollectionCacheReader _keysCache; + PeriodicRunner _refresher; +}; + +} // namespace mongo diff --git a/src/mongo/db/keys_collection_manager_test.cpp b/src/mongo/db/keys_collection_manager_sharding_test.cpp index e68d9285cc0..d7af13a09fc 100644 --- a/src/mongo/db/keys_collection_manager_test.cpp +++ b/src/mongo/db/keys_collection_manager_sharding_test.cpp @@ -33,7 +33,7 @@ #include "mongo/db/jsobj.h" #include "mongo/db/keys_collection_document.h" -#include "mongo/db/keys_collection_manager.h" +#include "mongo/db/keys_collection_manager_sharding.h" #include "mongo/db/logical_clock.h" #include "mongo/db/namespace_string.h" #include "mongo/db/operation_context.h" @@ -48,9 +48,9 @@ namespace mongo { -class KeysManagerTest : public ConfigServerTestFixture { +class KeysManagerShardedTest : public ConfigServerTestFixture { public: - KeysCollectionManager* keyManager() { + KeysCollectionManagerSharding* keyManager() { return _keyManager.get(); } @@ -65,7 +65,8 @@ protected: auto clockSource = stdx::make_unique<ClockSourceMock>(); operationContext()->getServiceContext()->setFastClockSource(std::move(clockSource)); auto catalogClient = Grid::get(operationContext())->catalogClient(); - _keyManager = stdx::make_unique<KeysCollectionManager>("dummy", catalogClient, Seconds(1)); + _keyManager = + stdx::make_unique<KeysCollectionManagerSharding>("dummy", catalogClient, Seconds(1)); } void tearDown() override { @@ -81,10 +82,10 @@ protected: } private: - std::unique_ptr<KeysCollectionManager> _keyManager; + std::unique_ptr<KeysCollectionManagerSharding> _keyManager; }; -TEST_F(KeysManagerTest, GetKeyForValidationTimesOutIfRefresherIsNotRunning) { +TEST_F(KeysManagerShardedTest, GetKeyForValidationTimesOutIfRefresherIsNotRunning) { operationContext()->setDeadlineAfterNowBy(Microseconds(250 * 1000)); ASSERT_THROWS(keyManager() @@ -93,7 +94,7 @@ TEST_F(KeysManagerTest, GetKeyForValidationTimesOutIfRefresherIsNotRunning) { DBException); } -TEST_F(KeysManagerTest, GetKeyForValidationErrorsIfKeyDoesntExist) { +TEST_F(KeysManagerShardedTest, GetKeyForValidationErrorsIfKeyDoesntExist) { keyManager()->startMonitoring(getServiceContext()); auto keyStatus = @@ -101,7 +102,7 @@ TEST_F(KeysManagerTest, GetKeyForValidationErrorsIfKeyDoesntExist) { ASSERT_EQ(ErrorCodes::KeyNotFound, keyStatus.getStatus()); } -TEST_F(KeysManagerTest, GetKeyWithSingleKey) { +TEST_F(KeysManagerShardedTest, GetKeyWithSingleKey) { keyManager()->startMonitoring(getServiceContext()); KeysCollectionDocument origKey1( @@ -119,7 +120,7 @@ TEST_F(KeysManagerTest, GetKeyWithSingleKey) { ASSERT_EQ(Timestamp(105, 0), key.getExpiresAt().asTimestamp()); } -TEST_F(KeysManagerTest, GetKeyWithMultipleKeys) { +TEST_F(KeysManagerShardedTest, GetKeyWithMultipleKeys) { keyManager()->startMonitoring(getServiceContext()); KeysCollectionDocument origKey1( @@ -151,7 +152,7 @@ TEST_F(KeysManagerTest, GetKeyWithMultipleKeys) { ASSERT_EQ(Timestamp(205, 0), key.getExpiresAt().asTimestamp()); } -TEST_F(KeysManagerTest, GetKeyShouldErrorIfKeyIdMismatchKey) { +TEST_F(KeysManagerShardedTest, GetKeyShouldErrorIfKeyIdMismatchKey) { keyManager()->startMonitoring(getServiceContext()); KeysCollectionDocument origKey1( @@ -164,7 +165,7 @@ TEST_F(KeysManagerTest, GetKeyShouldErrorIfKeyIdMismatchKey) { ASSERT_EQ(ErrorCodes::KeyNotFound, keyStatus.getStatus()); } -TEST_F(KeysManagerTest, GetKeyWithoutRefreshShouldReturnRightKey) { +TEST_F(KeysManagerShardedTest, GetKeyWithoutRefreshShouldReturnRightKey) { keyManager()->startMonitoring(getServiceContext()); KeysCollectionDocument origKey1( @@ -199,7 +200,7 @@ TEST_F(KeysManagerTest, GetKeyWithoutRefreshShouldReturnRightKey) { } } -TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightKey) { +TEST_F(KeysManagerShardedTest, GetKeyForSigningShouldReturnRightKey) { keyManager()->startMonitoring(getServiceContext()); KeysCollectionDocument origKey1( @@ -209,7 +210,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightKey) { keyManager()->refreshNow(operationContext()); - auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 0))); + auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 0))); ASSERT_OK(keyStatus.getStatus()); auto key = keyStatus.getValue(); @@ -218,7 +219,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightKey) { ASSERT_EQ(Timestamp(105, 0), key.getExpiresAt().asTimestamp()); } -TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightOldKey) { +TEST_F(KeysManagerShardedTest, GetKeyForSigningShouldReturnRightOldKey) { keyManager()->startMonitoring(getServiceContext()); KeysCollectionDocument origKey1( @@ -233,7 +234,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightOldKey) { keyManager()->refreshNow(operationContext()); { - auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 0))); + auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 0))); ASSERT_OK(keyStatus.getStatus()); auto key = keyStatus.getValue(); @@ -243,7 +244,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightOldKey) { } { - auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(105, 0))); + auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(105, 0))); ASSERT_OK(keyStatus.getStatus()); auto key = keyStatus.getValue(); @@ -253,7 +254,7 @@ TEST_F(KeysManagerTest, GetKeyForSigningShouldReturnRightOldKey) { } } -TEST_F(KeysManagerTest, ShouldCreateKeysIfKeyGeneratorEnabled) { +TEST_F(KeysManagerShardedTest, ShouldCreateKeysIfKeyGeneratorEnabled) { keyManager()->startMonitoring(getServiceContext()); const LogicalTime currentTime(LogicalTime(Timestamp(100, 0))); @@ -262,14 +263,14 @@ TEST_F(KeysManagerTest, ShouldCreateKeysIfKeyGeneratorEnabled) { keyManager()->enableKeyGenerator(operationContext(), true); keyManager()->refreshNow(operationContext()); - auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 100))); + auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 100))); ASSERT_OK(keyStatus.getStatus()); auto key = keyStatus.getValue(); ASSERT_EQ(Timestamp(101, 0), key.getExpiresAt().asTimestamp()); } -TEST_F(KeysManagerTest, EnableModeFlipFlopStressTest) { +TEST_F(KeysManagerShardedTest, EnableModeFlipFlopStressTest) { keyManager()->startMonitoring(getServiceContext()); const LogicalTime currentTime(LogicalTime(Timestamp(100, 0))); @@ -281,7 +282,7 @@ TEST_F(KeysManagerTest, EnableModeFlipFlopStressTest) { keyManager()->enableKeyGenerator(operationContext(), doEnable); keyManager()->refreshNow(operationContext()); - auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 100))); + auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 100))); ASSERT_OK(keyStatus.getStatus()); auto key = keyStatus.getValue(); @@ -291,7 +292,7 @@ TEST_F(KeysManagerTest, EnableModeFlipFlopStressTest) { } } -TEST_F(KeysManagerTest, ShouldStillBeAbleToUpdateCacheEvenIfItCantCreateKeys) { +TEST_F(KeysManagerShardedTest, ShouldStillBeAbleToUpdateCacheEvenIfItCantCreateKeys) { KeysCollectionDocument origKey1( 1, "dummy", TimeProofService::generateRandomKey(), LogicalTime(Timestamp(105, 0))); ASSERT_OK(insertToConfigCollection( @@ -319,7 +320,7 @@ TEST_F(KeysManagerTest, ShouldStillBeAbleToUpdateCacheEvenIfItCantCreateKeys) { ASSERT_EQ(Timestamp(105, 0), key.getExpiresAt().asTimestamp()); } -TEST_F(KeysManagerTest, ShouldNotCreateKeysWithDisableKeyGenerationFailPoint) { +TEST_F(KeysManagerShardedTest, ShouldNotCreateKeysWithDisableKeyGenerationFailPoint) { const LogicalTime currentTime(Timestamp(100, 0)); LogicalClock::get(operationContext())->setClusterTimeFromTrustedSource(currentTime); @@ -336,11 +337,11 @@ TEST_F(KeysManagerTest, ShouldNotCreateKeysWithDisableKeyGenerationFailPoint) { // Once the failpoint is disabled, the generator can make keys again. keyManager()->refreshNow(operationContext()); - auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 0))); + auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 0))); ASSERT_OK(keyStatus.getStatus()); } -TEST_F(KeysManagerTest, HasSeenKeysIsFalseUntilKeysAreFound) { +TEST_F(KeysManagerShardedTest, HasSeenKeysIsFalseUntilKeysAreFound) { const LogicalTime currentTime(Timestamp(100, 0)); LogicalClock::get(operationContext())->setClusterTimeFromTrustedSource(currentTime); @@ -361,13 +362,13 @@ TEST_F(KeysManagerTest, HasSeenKeysIsFalseUntilKeysAreFound) { // Once the failpoint is disabled, the generator can make keys again. keyManager()->refreshNow(operationContext()); - auto keyStatus = keyManager()->getKeyForSigning(LogicalTime(Timestamp(100, 0))); + auto keyStatus = keyManager()->getKeyForSigning(nullptr, LogicalTime(Timestamp(100, 0))); ASSERT_OK(keyStatus.getStatus()); ASSERT_EQ(true, keyManager()->hasSeenKeys()); } -TEST_F(KeysManagerTest, ShouldNotReturnKeysInFeatureCompatibilityVersion34) { +TEST_F(KeysManagerShardedTest, ShouldNotReturnKeysInFeatureCompatibilityVersion34) { serverGlobalParams.featureCompatibility.version.store( ServerGlobalParams::FeatureCompatibility::Version::k34); diff --git a/src/mongo/db/keys_collection_manager_zero.cpp b/src/mongo/db/keys_collection_manager_zero.cpp new file mode 100644 index 00000000000..5f773b0152d --- /dev/null +++ b/src/mongo/db/keys_collection_manager_zero.cpp @@ -0,0 +1,57 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/keys_collection_manager_zero.h" + +#include "mongo/db/keys_collection_document.h" + +namespace mongo { + +namespace { + +const TimeProofService::Key kTimeProofServiceKey{}; +const long long kKeyId = 1; + +} // namespace + +KeysCollectionManagerZero::KeysCollectionManagerZero(std::string purpose) + : _purpose(std::move(purpose)) {} + +StatusWith<KeysCollectionDocument> KeysCollectionManagerZero::getKeyForValidation( + OperationContext* opCtx, long long keyId, const LogicalTime& forThisTime) { + return KeysCollectionDocument(keyId, _purpose, kTimeProofServiceKey, forThisTime); +} + +StatusWith<KeysCollectionDocument> KeysCollectionManagerZero::getKeyForSigning( + OperationContext* opCtx, const LogicalTime& forThisTime) { + return KeysCollectionDocument(kKeyId, _purpose, kTimeProofServiceKey, forThisTime); +} + +} // namespace mongo diff --git a/src/mongo/db/keys_collection_manager_zero.h b/src/mongo/db/keys_collection_manager_zero.h new file mode 100644 index 00000000000..71f2dcc89ee --- /dev/null +++ b/src/mongo/db/keys_collection_manager_zero.h @@ -0,0 +1,69 @@ +/** + * Copyright (C) 2017 MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <memory> +#include <string> + +#include "mongo/db/keys_collection_document.h" +#include "mongo/db/keys_collection_manager.h" + +namespace mongo { + +class OperationContext; +class LogicalTime; +class ServiceContext; + +/** + * This implementation of the KeysCollectionManager always returns a zeroed key. This is a + * transitional type to bridge the gap until we decide how to store and rotate keys in standalones + * and non-clustered replica sets. + */ +class KeysCollectionManagerZero : public KeysCollectionManager { +public: + KeysCollectionManagerZero(std::string purpose); + + /** + * Return a key that is valid for the given time and also matches the keyId. + */ + StatusWith<KeysCollectionDocument> getKeyForValidation(OperationContext* opCtx, + long long keyId, + const LogicalTime& forThisTime) override; + + /** + * Returns a key that is valid for the given time. + */ + StatusWith<KeysCollectionDocument> getKeyForSigning(OperationContext* opCtx, + const LogicalTime& forThisTime) override; + +private: + std::string _purpose; +}; + +} // namespace mongo diff --git a/src/mongo/db/logical_session_cache.cpp b/src/mongo/db/logical_session_cache.cpp index 8aa1eaba802..7474c398159 100644 --- a/src/mongo/db/logical_session_cache.cpp +++ b/src/mongo/db/logical_session_cache.cpp @@ -32,13 +32,20 @@ #include "mongo/db/logical_session_cache.h" +#include "mongo/db/operation_context.h" #include "mongo/db/server_parameters.h" +#include "mongo/db/service_context.h" #include "mongo/util/duration.h" #include "mongo/util/log.h" #include "mongo/util/periodic_runner.h" namespace mongo { +namespace { +const auto getLogicalSessionCache = + ServiceContext::declareDecoration<std::unique_ptr<LogicalSessionCache>>(); +} // namespace + MONGO_EXPORT_STARTUP_SERVER_PARAMETER(logicalSessionRecordCacheSize, int, LogicalSessionCache::kLogicalSessionCacheDefaultCapacity); @@ -55,6 +62,20 @@ constexpr int LogicalSessionCache::kLogicalSessionCacheDefaultCapacity; constexpr Minutes LogicalSessionCache::kLogicalSessionDefaultTimeout; constexpr Minutes LogicalSessionCache::kLogicalSessionDefaultRefresh; +LogicalSessionCache* LogicalSessionCache::get(ServiceContext* service) { + return getLogicalSessionCache(service).get(); +} + +LogicalSessionCache* LogicalSessionCache::get(OperationContext* ctx) { + return get(ctx->getClient()->getServiceContext()); +} + +void LogicalSessionCache::set(ServiceContext* service, + std::unique_ptr<LogicalSessionCache> sessionCache) { + auto& cache = getLogicalSessionCache(service); + cache = std::move(sessionCache); +} + LogicalSessionCache::LogicalSessionCache(std::unique_ptr<ServiceLiason> service, std::unique_ptr<SessionsCollection> collection, Options options) @@ -78,42 +99,34 @@ LogicalSessionCache::~LogicalSessionCache() { } } -StatusWith<LogicalSessionRecord::Owner> LogicalSessionCache::getOwner(LogicalSessionId lsid) { +Status LogicalSessionCache::fetchAndPromote(SignedLogicalSessionId slsid) { // Search our local cache first - auto owner = getOwnerFromCache(lsid); - if (owner.isOK()) { - return owner; + auto promoteRes = promote(slsid); + if (promoteRes.isOK()) { + return promoteRes; } // Cache miss, must fetch from the sessions collection. - auto res = _sessionsColl->fetchRecord(lsid); + auto res = _sessionsColl->fetchRecord(slsid); // If we got a valid record, add it to our cache. if (res.isOK()) { auto& record = res.getValue(); record.setLastUse(_service->now()); + // Any duplicate records here are actually the same record with different + // lastUse times, ignore them. auto oldRecord = _addToCache(record); - - // If we had a conflicting record for this id, and they aren't the same record, - // it could mean that an interloper called endSession and startSession for the - // same lsid while we were fetching its record from the sessions collection. - // This means our session has been written over, do not allow the caller to use it. - // Note: we could find expired versions of our same record here, but they'll compare equal. - if (oldRecord && *oldRecord != record) { - return {ErrorCodes::NoSuchSession, "no matching session record found"}; - } - - return record.getSessionOwner(); + return Status::OK(); } + // If we could not get a valid record, return the error. return res.getStatus(); } -StatusWith<LogicalSessionRecord::Owner> LogicalSessionCache::getOwnerFromCache( - LogicalSessionId lsid) { +Status LogicalSessionCache::promote(SignedLogicalSessionId slsid) { stdx::unique_lock<stdx::mutex> lk(_cacheMutex); - auto it = _cache.find(lsid); + auto it = _cache.find(slsid.getLsid()); if (it == _cache.end()) { return {ErrorCodes::NoSuchSession, "no matching session record found in the cache"}; } @@ -126,13 +139,13 @@ StatusWith<LogicalSessionRecord::Owner> LogicalSessionCache::getOwnerFromCache( // Update the last use time before returning. it->second.setLastUse(now); - return it->second.getSessionOwner(); + return Status::OK(); } -Status LogicalSessionCache::startSession(LogicalSessionRecord authoritativeRecord) { +Status LogicalSessionCache::startSession(SignedLogicalSessionId slsid) { // Make sure the timestamp makes sense - auto now = _service->now(); - authoritativeRecord.setLastUse(now); + auto authoritativeRecord = + LogicalSessionRecord::makeAuthoritativeRecord(slsid, _service->now()); // Attempt to insert into the sessions collection first. This collection enforces // unique session ids, so it will act as concurrency control for us. @@ -148,7 +161,7 @@ Status LogicalSessionCache::startSession(LogicalSessionRecord authoritativeRecor auto oldRecord = _addToCache(authoritativeRecord); if (oldRecord) { if (*oldRecord != authoritativeRecord) { - if (!_isDead(*oldRecord, now)) { + if (!_isDead(*oldRecord, _service->now())) { return {ErrorCodes::DuplicateSession, "session with this id already exists"}; } } @@ -157,6 +170,17 @@ Status LogicalSessionCache::startSession(LogicalSessionRecord authoritativeRecor return Status::OK(); } +StatusWith<SignedLogicalSessionId> LogicalSessionCache::signLsid(OperationContext* opCtx, + LogicalSessionId* id, + boost::optional<OID> userId) { + return _service->signLsid(opCtx, id, std::move(userId)); +} + +Status LogicalSessionCache::validateLsid(OperationContext* opCtx, + const SignedLogicalSessionId& slsid) { + return _service->validateLsid(opCtx, slsid); +} + void LogicalSessionCache::_refresh() { LogicalSessionIdSet activeSessions; LogicalSessionIdSet deadSessions; @@ -177,9 +201,9 @@ void LogicalSessionCache::_refresh() { for (auto& it : cacheCopy) { auto record = it.second; if (!_isDead(record, now)) { - activeSessions.insert(record.getLsid()); + activeSessions.insert(record.getSignedLsid().getLsid()); } else { - deadSessions.insert(record.getLsid()); + deadSessions.insert(record.getSignedLsid().getLsid()); } } @@ -236,7 +260,7 @@ bool LogicalSessionCache::_isDead(const LogicalSessionRecord& record, Date_t now boost::optional<LogicalSessionRecord> LogicalSessionCache::_addToCache( LogicalSessionRecord record) { stdx::unique_lock<stdx::mutex> lk(_cacheMutex); - return _cache.add(record.getLsid(), std::move(record)); + return _cache.add(record.getSignedLsid().getLsid(), std::move(record)); } } // namespace mongo diff --git a/src/mongo/db/logical_session_cache.h b/src/mongo/db/logical_session_cache.h index a88bb3f31d0..56ce96eec34 100644 --- a/src/mongo/db/logical_session_cache.h +++ b/src/mongo/db/logical_session_cache.h @@ -33,12 +33,17 @@ #include "mongo/db/logical_session_record.h" #include "mongo/db/service_liason.h" #include "mongo/db/sessions_collection.h" +#include "mongo/db/signed_logical_session_id.h" +#include "mongo/db/time_proof_service.h" #include "mongo/platform/atomic_word.h" #include "mongo/stdx/thread.h" #include "mongo/util/lru_cache.h" namespace mongo { +class OperationContext; +class ServiceContext; + extern int logicalSessionRecordCacheSize; extern int localLogicalSessionTimeoutMinutes; extern int logicalSessionRefreshMinutes; @@ -51,6 +56,13 @@ extern int logicalSessionRefreshMinutes; */ class LogicalSessionCache { public: + /** + * Decorate the ServiceContext with a LogicalSessionCache instance. + */ + static LogicalSessionCache* get(ServiceContext* service); + static LogicalSessionCache* get(OperationContext* opCtx); + static void set(ServiceContext* service, std::unique_ptr<LogicalSessionCache> sessionCache); + static constexpr int kLogicalSessionCacheDefaultCapacity = 10000; static constexpr Minutes kLogicalSessionDefaultTimeout = Minutes(30); static constexpr Minutes kLogicalSessionDefaultRefresh = Minutes(5); @@ -102,26 +114,22 @@ public: ~LogicalSessionCache(); /** - * Returns the owner for the given session, or return an error if there - * is no authoritative record for this session. - * - * If the cache does not already contain a record for this session, this - * method may issue networking operations to obtain the record. Afterwards, - * the cache will keep the record for future use. + * If the cache contains a record for this LogicalSessionId, promotes that lsid + * to be the most recently used and updates its lastUse date to be the current + * time. Otherwise, returns an error. * - * This call will promote any record it touches to be the most-recently-used - * record in the cache. + * This method does not issue networking calls. */ - StatusWith<LogicalSessionRecord::Owner> getOwner(LogicalSessionId lsid); + Status promote(SignedLogicalSessionId lsid); /** - * Returns the owner for the given session if we already have its record in the - * cache. Do not fetch the record from the network if we do not already have it. + * If the cache contains a record for this LogicalSessionId, promotes it. + * Otherwise, attempts to fetch the record for this LogicalSessionId from the + * sessions collection, and returns the record if found. Otherwise, returns an error. * - * This call will promote any record it touches to be the most-recently-used - * record in the cache. + * This method may issue networking calls. */ - StatusWith<LogicalSessionRecord::Owner> getOwnerFromCache(LogicalSessionId lsid); + Status fetchAndPromote(SignedLogicalSessionId lsid); /** * Inserts a new authoritative session record into the cache. This method will @@ -129,7 +137,22 @@ public: * should only be used when starting new sessions and should not be used to * insert records for existing sessions. */ - Status startSession(LogicalSessionRecord authoritativeRecord); + Status startSession(SignedLogicalSessionId lsid); + + /** + * Generates and sets a signature for the fields in this LogicalSessionId. + * + * If this method is not able to acquire a key to perform the signature + * this call will return an error. + */ + StatusWith<SignedLogicalSessionId> signLsid(OperationContext* opCtx, + LogicalSessionId* id, + boost::optional<OID> userId); + + /** + * Validates that this LogicalSessionId was signed with the correct key. + */ + Status validateLsid(OperationContext* opCtx, const SignedLogicalSessionId& lsid); /** * Removes all local records in this cache. Does not remove the corresponding diff --git a/src/mongo/db/logical_session_cache_test.cpp b/src/mongo/db/logical_session_cache_test.cpp index a9d3c5eb1c1..baba6814b34 100644 --- a/src/mongo/db/logical_session_cache_test.cpp +++ b/src/mongo/db/logical_session_cache_test.cpp @@ -57,9 +57,7 @@ class LogicalSessionCacheTest : public unittest::Test { public: LogicalSessionCacheTest() : _service(std::make_shared<MockServiceLiasonImpl>()), - _sessions(std::make_shared<MockSessionsCollectionImpl>()), - _user("sam", "test"), - _userId(OID::gen()) {} + _sessions(std::make_shared<MockSessionsCollectionImpl>()) {} void setUp() override { auto mockService = stdx::make_unique<MockServiceLiason>(_service); @@ -78,11 +76,6 @@ public: } } - LogicalSessionRecord newRecord() { - return LogicalSessionRecord::makeAuthoritativeRecord( - LogicalSessionId::gen(), _user, _userId, _service->now()); - } - std::unique_ptr<LogicalSessionCache>& cache() { return _cache; } @@ -100,142 +93,132 @@ private: std::shared_ptr<MockSessionsCollectionImpl> _sessions; std::unique_ptr<LogicalSessionCache> _cache; - - UserName _user; - boost::optional<OID> _userId; }; // Test that session cache fetches new records from the sessions collection TEST_F(LogicalSessionCacheTest, CacheFetchesNewRecords) { - auto record = newRecord(); - auto lsid = record.getLsid(); + auto signedLsid = SignedLogicalSessionId::gen(); // When the record is not present (and not in the sessions collection) returns an error - auto res = cache()->getOwner(lsid); + auto res = cache()->fetchAndPromote(signedLsid); ASSERT(!res.isOK()); // When the record is not present (but is in the sessions collection) returns it - sessions()->add(record); - res = cache()->getOwner(lsid); + sessions()->add(LogicalSessionRecord::makeAuthoritativeRecord(signedLsid, service()->now())); + res = cache()->fetchAndPromote(signedLsid); ASSERT(res.isOK()); - ASSERT(res.getValue() == record.getSessionOwner()); // When the record is present in the cache, returns it - sessions()->setFetchHook([](LogicalSessionId id) -> StatusWith<LogicalSessionRecord> { + sessions()->setFetchHook([](SignedLogicalSessionId id) -> StatusWith<LogicalSessionRecord> { // We should not be querying the sessions collection on the next call ASSERT(false); return {ErrorCodes::NoSuchSession, "no such session"}; }); - res = cache()->getOwner(lsid); + res = cache()->fetchAndPromote(signedLsid); ASSERT(res.isOK()); - ASSERT(res.getValue() == record.getSessionOwner()); } // Test that the getFromCache method does not make calls to the sessions collection TEST_F(LogicalSessionCacheTest, TestCacheHitsOnly) { - auto record = newRecord(); - auto lsid = record.getLsid(); + auto signedLsid = SignedLogicalSessionId::gen(); // When the record is not present (and not in the sessions collection), returns an error - auto res = cache()->getOwnerFromCache(lsid); + auto res = cache()->promote(signedLsid); ASSERT(!res.isOK()); // When the record is not present (but is in the sessions collection), returns an error - sessions()->add(record); - res = cache()->getOwnerFromCache(lsid); + sessions()->add(LogicalSessionRecord::makeAuthoritativeRecord(signedLsid, service()->now())); + res = cache()->promote(signedLsid); ASSERT(!res.isOK()); // When the record is present, returns the owner - cache()->getOwner(lsid).status_with_transitional_ignore(); - res = cache()->getOwnerFromCache(lsid); + cache()->fetchAndPromote(signedLsid).transitional_ignore(); + res = cache()->promote(signedLsid); ASSERT(res.isOK()); - auto fetched = res.getValue(); - ASSERT(res.getValue() == record.getSessionOwner()); } // Test that fetching from the cache updates the lastUse date of records TEST_F(LogicalSessionCacheTest, FetchUpdatesLastUse) { - auto record = newRecord(); - auto lsid = record.getLsid(); + auto signedLsid = SignedLogicalSessionId::gen(); auto start = service()->now(); // Insert the record into the sessions collection with 'start' - record.setLastUse(start); - sessions()->add(record); + sessions()->add(LogicalSessionRecord::makeAuthoritativeRecord(signedLsid, start)); // Fast forward time and fetch service()->fastForward(Milliseconds(500)); ASSERT(start != service()->now()); - auto res = cache()->getOwner(lsid); + auto res = cache()->fetchAndPromote(signedLsid); ASSERT(res.isOK()); // Now that we fetched, lifetime of session should be extended service()->fastForward(kSessionTimeout - Milliseconds(500)); - res = cache()->getOwner(lsid); + res = cache()->fetchAndPromote(signedLsid); ASSERT(res.isOK()); // We fetched again, so lifetime extended again service()->fastForward(kSessionTimeout - Milliseconds(10)); - res = cache()->getOwner(lsid); + res = cache()->fetchAndPromote(signedLsid); ASSERT(res.isOK()); // Fast forward and hit-only fetch service()->fastForward(kSessionTimeout - Milliseconds(10)); - res = cache()->getOwnerFromCache(lsid); + res = cache()->promote(signedLsid); ASSERT(res.isOK()); // Lifetime extended again service()->fastForward(Milliseconds(11)); - res = cache()->getOwnerFromCache(lsid); + res = cache()->promote(signedLsid); ASSERT(res.isOK()); // Let record expire, we should not be able to get it from the cache service()->fastForward(kSessionTimeout + Milliseconds(1)); - res = cache()->getOwnerFromCache(lsid); + res = cache()->promote(signedLsid); ASSERT(!res.isOK()); } // Test the startSession method TEST_F(LogicalSessionCacheTest, StartSession) { - auto record = newRecord(); - auto lsid = record.getLsid(); + auto signedLsid = SignedLogicalSessionId::gen(); // Test starting a new session - auto res = cache()->startSession(record); + auto res = cache()->startSession(signedLsid); ASSERT(res.isOK()); - ASSERT(sessions()->has(lsid)); + ASSERT(sessions()->has(signedLsid.getLsid())); // Try to start a session that is already in the sessions collection and our // local cache, should fail - res = cache()->startSession(record); + res = cache()->startSession(signedLsid); ASSERT(!res.isOK()); // Try to start a session that is already in the sessions collection but // is not in our local cache, should fail - auto record2 = newRecord(); - sessions()->add(record2); - res = cache()->startSession(record2); + auto record2 = LogicalSessionRecord::makeAuthoritativeRecord(SignedLogicalSessionId::gen(), + service()->now()); + auto signedLsid2 = record2.getSignedLsid(); + sessions()->add(std::move(record2)); + res = cache()->startSession(signedLsid2); ASSERT(!res.isOK()); // Try to start a session that has expired from our cache, and is no // longer in the sessions collection, should succeed service()->fastForward(Milliseconds(kSessionTimeout.count() + 5)); - sessions()->remove(lsid); - ASSERT(!sessions()->has(lsid)); - res = cache()->startSession(record); + sessions()->remove(signedLsid.getLsid()); + ASSERT(!sessions()->has(signedLsid.getLsid())); + res = cache()->startSession(signedLsid); ASSERT(res.isOK()); - ASSERT(sessions()->has(lsid)); + ASSERT(sessions()->has(signedLsid.getLsid())); } // Test that records in the cache are properly refreshed until they expire TEST_F(LogicalSessionCacheTest, CacheRefreshesOwnRecords) { // Insert two records into the cache - auto record1 = newRecord(); - auto record2 = newRecord(); - cache()->startSession(record1).transitional_ignore(); - cache()->startSession(record2).transitional_ignore(); + auto signedLsid1 = SignedLogicalSessionId::gen(); + auto signedLsid2 = SignedLogicalSessionId::gen(); + cache()->startSession(signedLsid1).transitional_ignore(); + cache()->startSession(signedLsid2).transitional_ignore(); stdx::promise<int> hitRefresh; auto refreshFuture = hitRefresh.get_future(); @@ -258,8 +241,7 @@ TEST_F(LogicalSessionCacheTest, CacheRefreshesOwnRecords) { auto refresh2Future = refresh2.get_future(); // Use one of the records - auto lsid = record1.getLsid(); - auto res = cache()->getOwner(lsid); + auto res = cache()->fetchAndPromote(signedLsid1); ASSERT(res.isOK()); // Advance time so that one record expires @@ -276,25 +258,25 @@ TEST_F(LogicalSessionCacheTest, CacheRefreshesOwnRecords) { service()->fastForward(kSessionTimeout - kForceRefresh + Milliseconds(1)); refresh2Future.wait(); - ASSERT_EQ(refresh2Future.get(), lsid); + ASSERT_EQ(refresh2Future.get(), signedLsid1.getLsid()); } // Test that cache deletes records that fail to refresh TEST_F(LogicalSessionCacheTest, CacheDeletesRecordsThatFailToRefresh) { // Put two sessions into the cache - auto record1 = newRecord(); - auto record2 = newRecord(); - cache()->startSession(record1).transitional_ignore(); - cache()->startSession(record2).transitional_ignore(); + auto signedLsid1 = SignedLogicalSessionId::gen(); + auto signedLsid2 = SignedLogicalSessionId::gen(); + cache()->startSession(signedLsid1).transitional_ignore(); + cache()->startSession(signedLsid2).transitional_ignore(); stdx::promise<void> hitRefresh; auto refreshFuture = hitRefresh.get_future(); // Record 1 fails to refresh - sessions()->setRefreshHook([&hitRefresh, &record1](LogicalSessionIdSet sessions) { + sessions()->setRefreshHook([&hitRefresh, &signedLsid1](LogicalSessionIdSet sessions) { ASSERT_EQ(sessions.size(), size_t(2)); hitRefresh.set_value(); - return LogicalSessionIdSet{record1.getLsid()}; + return LogicalSessionIdSet{signedLsid1.getLsid()}; }); // Force a refresh @@ -302,72 +284,71 @@ TEST_F(LogicalSessionCacheTest, CacheDeletesRecordsThatFailToRefresh) { refreshFuture.wait(); // Ensure that one record is still there and the other is gone - auto res = cache()->getOwnerFromCache(record1.getLsid()); + auto res = cache()->promote(signedLsid1); ASSERT(!res.isOK()); - res = cache()->getOwnerFromCache(record2.getLsid()); + res = cache()->promote(signedLsid2); ASSERT(res.isOK()); } // Test that we don't remove records that fail to refresh if they are active on the service TEST_F(LogicalSessionCacheTest, KeepActiveSessionAliveEvenIfRefreshFails) { // Put two sessions into the cache, one into the service - auto record1 = newRecord(); - auto record2 = newRecord(); - cache()->startSession(record1).transitional_ignore(); - service()->add(record1.getLsid()); - cache()->startSession(record2).transitional_ignore(); + auto signedLsid1 = SignedLogicalSessionId::gen(); + auto signedLsid2 = SignedLogicalSessionId::gen(); + cache()->startSession(signedLsid1).transitional_ignore(); + service()->add(signedLsid1.getLsid()); + cache()->startSession(signedLsid2).transitional_ignore(); stdx::promise<void> hitRefresh; auto refreshFuture = hitRefresh.get_future(); - // Record 1 fails to refresh - sessions()->setRefreshHook([&hitRefresh, &record1](LogicalSessionIdSet sessions) { + // SignedLsid 1 fails to refresh + sessions()->setRefreshHook([&hitRefresh, &signedLsid1](LogicalSessionIdSet sessions) { ASSERT_EQ(sessions.size(), size_t(2)); hitRefresh.set_value(); - return LogicalSessionIdSet{record1.getLsid()}; + return LogicalSessionIdSet{signedLsid1.getLsid()}; }); // Force a refresh service()->fastForward(kForceRefresh); refreshFuture.wait(); - // Ensure that both records are still there - auto res = cache()->getOwnerFromCache(record1.getLsid()); + // Ensure that both signedLsids are still there + auto res = cache()->promote(signedLsid1); ASSERT(res.isOK()); - res = cache()->getOwnerFromCache(record2.getLsid()); + res = cache()->promote(signedLsid2); ASSERT(res.isOK()); } -// Test that session cache properly expires records after 30 minutes of no use +// Test that session cache properly expires signedLsids after 30 minutes of no use TEST_F(LogicalSessionCacheTest, BasicSessionExpiration) { - // Insert a record - auto record = newRecord(); - cache()->startSession(record).transitional_ignore(); - auto res = cache()->getOwnerFromCache(record.getLsid()); + // Insert a signedLsid + auto signedLsid = SignedLogicalSessionId::gen(); + cache()->startSession(signedLsid).transitional_ignore(); + auto res = cache()->promote(signedLsid); ASSERT(res.isOK()); // Force it to expire service()->fastForward(Milliseconds(kSessionTimeout.count() + 5)); // Check that it is no longer in the cache - res = cache()->getOwnerFromCache(record.getLsid()); + res = cache()->promote(signedLsid); ASSERT(!res.isOK()); } // Test that we keep refreshing sessions that are active on the service TEST_F(LogicalSessionCacheTest, LongRunningQueriesAreRefreshed) { - auto record = newRecord(); - auto lsid = record.getLsid(); + auto signedLsid = SignedLogicalSessionId::gen(); - // Insert one active record on the service, none in the cache - service()->add(lsid); + // Insert one active signedLsid on the service, none in the cache + service()->add(signedLsid.getLsid()); stdx::mutex mutex; stdx::condition_variable cv; int count = 0; - sessions()->setRefreshHook([&cv, &mutex, &count, &lsid](LogicalSessionIdSet sessions) { - ASSERT_EQ(*(sessions.begin()), lsid); + sessions()->setRefreshHook([&cv, &mutex, &count, &signedLsid](LogicalSessionIdSet sessions) { + ASSERT_EQ(*(sessions.begin()), signedLsid.getLsid()); { stdx::unique_lock<stdx::mutex> lk(mutex); count++; @@ -397,7 +378,7 @@ TEST_F(LogicalSessionCacheTest, LongRunningQueriesAreRefreshed) { // Wait until the next job has been scheduled waitUntilRefreshScheduled(); - // Force another refresh, check that it refreshes that active record again + // Force another refresh, check that it refreshes that active signedLsid again service()->fastForward(kForceRefresh); { stdx::unique_lock<stdx::mutex> lk(mutex); @@ -405,18 +386,18 @@ TEST_F(LogicalSessionCacheTest, LongRunningQueriesAreRefreshed) { } } -// Test that the set of records we refresh is a sum of cached + active records -TEST_F(LogicalSessionCacheTest, RefreshCachedAndServiceRecordsTogether) { +// Test that the set of signedLsids we refresh is a sum of cached + active signedLsids +TEST_F(LogicalSessionCacheTest, RefreshCachedAndServiceSignedLsidsTogether) { // Put one session into the cache, one into the service - auto record1 = newRecord(); - service()->add(record1.getLsid()); - auto record2 = newRecord(); - cache()->startSession(record2).transitional_ignore(); + auto signedLsid1 = SignedLogicalSessionId::gen(); + service()->add(signedLsid1.getLsid()); + auto signedLsid2 = SignedLogicalSessionId::gen(); + cache()->startSession(signedLsid2).transitional_ignore(); stdx::promise<void> hitRefresh; auto refreshFuture = hitRefresh.get_future(); - // Both records refresh + // Both signedLsids refresh sessions()->setRefreshHook([&hitRefresh](LogicalSessionIdSet sessions) { ASSERT_EQ(sessions.size(), size_t(2)); hitRefresh.set_value(); @@ -428,18 +409,18 @@ TEST_F(LogicalSessionCacheTest, RefreshCachedAndServiceRecordsTogether) { refreshFuture.wait(); } -// Test large sets of cache-only session records -TEST_F(LogicalSessionCacheTest, ManyRecordsInCacheRefresh) { +// Test large sets of cache-only session signedLsids +TEST_F(LogicalSessionCacheTest, ManySignedLsidsInCacheRefresh) { int count = LogicalSessionCache::kLogicalSessionCacheDefaultCapacity; for (int i = 0; i < count; i++) { - auto record = newRecord(); - cache()->startSession(record).transitional_ignore(); + auto signedLsid = SignedLogicalSessionId::gen(); + cache()->startSession(signedLsid).transitional_ignore(); } stdx::promise<void> hitRefresh; auto refreshFuture = hitRefresh.get_future(); - // Check that all records refresh + // Check that all signedLsids refresh sessions()->setRefreshHook([&hitRefresh, &count](LogicalSessionIdSet sessions) { ASSERT_EQ(sessions.size(), size_t(count)); hitRefresh.set_value(); @@ -451,18 +432,18 @@ TEST_F(LogicalSessionCacheTest, ManyRecordsInCacheRefresh) { refreshFuture.wait(); } -// Test larger sets of service-only session records +// Test larger sets of service-only session signedLsids TEST_F(LogicalSessionCacheTest, ManyLongRunningSessionsRefresh) { int count = LogicalSessionCache::kLogicalSessionCacheDefaultCapacity; for (int i = 0; i < count; i++) { - auto record = newRecord(); - service()->add(record.getLsid()); + auto lsid = LogicalSessionId::gen(); + service()->add(lsid); } stdx::promise<void> hitRefresh; auto refreshFuture = hitRefresh.get_future(); - // Check that all records refresh + // Check that all signedLsids refresh sessions()->setRefreshHook([&hitRefresh, &count](LogicalSessionIdSet sessions) { ASSERT_EQ(sessions.size(), size_t(count)); hitRefresh.set_value(); @@ -478,11 +459,11 @@ TEST_F(LogicalSessionCacheTest, ManyLongRunningSessionsRefresh) { TEST_F(LogicalSessionCacheTest, ManySessionsRefreshComboDeluxe) { int count = LogicalSessionCache::kLogicalSessionCacheDefaultCapacity; for (int i = 0; i < count; i++) { - auto record = newRecord(); - service()->add(record.getLsid()); + auto lsid = LogicalSessionId::gen(); + service()->add(lsid); - auto record2 = newRecord(); - cache()->startSession(record2).transitional_ignore(); + auto lsid2 = SignedLogicalSessionId::gen(); + cache()->startSession(lsid2).transitional_ignore(); } stdx::mutex mutex; @@ -490,7 +471,7 @@ TEST_F(LogicalSessionCacheTest, ManySessionsRefreshComboDeluxe) { int refreshes = 0; int nRefreshed = 0; - // Check that all records refresh successfully + // Check that all signedLsids refresh successfully sessions()->setRefreshHook( [&refreshes, &mutex, &cv, &nRefreshed](LogicalSessionIdSet sessions) { { @@ -550,7 +531,7 @@ TEST_F(LogicalSessionCacheTest, ManySessionsRefreshComboDeluxe) { cv.wait(lk, [&refreshes] { return refreshes == 3; }); } - // Since all but one record failed to refresh, third set should just have one record + // Since all but one signedLsid failed to refresh, third set should just have one signedLsid ASSERT_EQ(nRefreshed, 1); } diff --git a/src/mongo/db/logical_session_id.h b/src/mongo/db/logical_session_id.h index ef73bc0fcd1..230b892d934 100644 --- a/src/mongo/db/logical_session_id.h +++ b/src/mongo/db/logical_session_id.h @@ -46,21 +46,25 @@ const TxnNumber kUninitializedTxnNumber = -1; class BSONObjBuilder; /** - * A 128-bit identifier for a logical session. + * A 128-bit unique identifier for a logical session. */ class LogicalSessionId : public Logical_session_id { public: LogicalSessionId(); LogicalSessionId(Logical_session_id&& lsid); + friend class Logical_session_id; + friend class Logical_session_record; + friend class SignedLogicalSessionId; + /** * Create and return a new LogicalSessionId with a random UUID. */ static LogicalSessionId gen(); /** - * If the given string represents a valid LogicalSessionId, constructs and returns, - * the id, otherwise returns an error. + * If the given string represents a valid UUID, constructs and returns + * a new LogicalSessionId. Otherwise returns an error. */ static StatusWith<LogicalSessionId> parse(const std::string& s); diff --git a/src/mongo/db/logical_session_id_test.cpp b/src/mongo/db/logical_session_id_test.cpp index 13f3375ab33..c9e86639c4a 100644 --- a/src/mongo/db/logical_session_id_test.cpp +++ b/src/mongo/db/logical_session_id_test.cpp @@ -59,7 +59,10 @@ TEST(LogicalSessionIdTest, ToAndFromStringTest) { TEST(LogicalSessionIdTest, FromBSONTest) { auto uuid = UUID::gen(); - auto bson = BSON("id" << uuid.toBSON()); + + BSONObjBuilder b; + b.append("id", uuid.toBSON()); + auto bson = b.done(); auto lsid = LogicalSessionId::parse(bson); ASSERT_EQUALS(lsid.toString(), uuid.toString()); @@ -73,7 +76,6 @@ TEST(LogicalSessionIdTest, FromBSONTest) { << "there")), UserException); - // TODO: use these and add more once there is bindata ASSERT_THROWS(LogicalSessionId::parse(BSON("id" << "not a session id!")), UserException); diff --git a/src/mongo/db/logical_session_record.cpp b/src/mongo/db/logical_session_record.cpp index d6904a36f5f..51ca629962b 100644 --- a/src/mongo/db/logical_session_record.cpp +++ b/src/mongo/db/logical_session_record.cpp @@ -39,25 +39,17 @@ namespace mongo { StatusWith<LogicalSessionRecord> LogicalSessionRecord::parse(const BSONObj& bson) { try { IDLParserErrorContext ctxt("logical session record"); - LogicalSessionRecord record; record.parseProtected(ctxt, bson); - - auto owner = record.getOwner(); - UserName user{owner.getUserName(), owner.getDbName()}; - record._owner = std::make_pair(user, owner.getUserId()); - return record; } catch (std::exception e) { return exceptionToStatus(); } } -LogicalSessionRecord LogicalSessionRecord::makeAuthoritativeRecord(LogicalSessionId id, - UserName user, - boost::optional<OID> userId, +LogicalSessionRecord LogicalSessionRecord::makeAuthoritativeRecord(SignedLogicalSessionId id, Date_t now) { - return LogicalSessionRecord(std::move(id), std::move(user), std::move(userId), std::move(now)); + return LogicalSessionRecord(std::move(id), std::move(now)); } BSONObj LogicalSessionRecord::toBSON() const { @@ -68,28 +60,13 @@ BSONObj LogicalSessionRecord::toBSON() const { std::string LogicalSessionRecord::toString() const { return str::stream() << "LogicalSessionRecord" - << " Id: '" << getLsid() << "'" - << " Owner name: '" << getOwner().getUserName() << "'" + << " Id: '" << getSignedLsid() << "'" << " Last-use: " << getLastUse().toString(); } -LogicalSessionRecord::LogicalSessionRecord(LogicalSessionId id, - UserName user, - boost::optional<OID> userId, - Date_t now) - : _owner(std::make_pair(std::move(user), std::move(userId))) { - setLsid(std::move(id)); +LogicalSessionRecord::LogicalSessionRecord(SignedLogicalSessionId id, Date_t now) { + setSignedLsid(std::move(id)); setLastUse(now); - - Session_owner owner; - owner.setUserName(_owner.first.getUser()); - owner.setDbName(_owner.first.getDB()); - owner.setUserId(_owner.second); - setOwner(std::move(owner)); -} - -LogicalSessionRecord::Owner LogicalSessionRecord::getSessionOwner() const { - return _owner; } } // namespace mongo diff --git a/src/mongo/db/logical_session_record.h b/src/mongo/db/logical_session_record.h index e576cd82eb5..6d8662327ab 100644 --- a/src/mongo/db/logical_session_record.h +++ b/src/mongo/db/logical_session_record.h @@ -32,10 +32,8 @@ #include <utility> #include "mongo/bson/bsonobj.h" -#include "mongo/bson/oid.h" -#include "mongo/db/auth/user_name.h" -#include "mongo/db/logical_session_id.h" #include "mongo/db/logical_session_record_gen.h" +#include "mongo/db/signed_logical_session_id.h" #include "mongo/util/time_support.h" namespace mongo { @@ -46,17 +44,12 @@ namespace mongo { * The BSON representation of a session record follows this form: * * { - * lsid : LogicalSessionId, - * lastUse : Date_t, - * owner : { - * user : UserName, - * userId : OID - * } + * lsid : SignedLogicalSessionId, + * lastUse : Date_t + * } */ class LogicalSessionRecord : public Logical_session_record { public: - using Owner = std::pair<UserName, boost::optional<OID>>; - /** * Constructs and returns a LogicalSessionRecord from a BSON representation, * or throws an error. For IDL. @@ -69,10 +62,7 @@ public: * only be used when the caller is intending to make a new authoritative * record and subsequently insert that record into the sessions collection. */ - static LogicalSessionRecord makeAuthoritativeRecord(LogicalSessionId id, - UserName user, - boost::optional<OID> userId, - Date_t now); + static LogicalSessionRecord makeAuthoritativeRecord(SignedLogicalSessionId id, Date_t now); /** * Return a BSON representation of this session record. @@ -85,34 +75,17 @@ public: std::string toString() const; inline bool operator==(const LogicalSessionRecord& rhs) const { - return getLsid() == rhs.getLsid() && getSessionOwner() == rhs.getSessionOwner(); + return getSignedLsid() == rhs.getSignedLsid(); } inline bool operator!=(const LogicalSessionRecord& rhs) const { return !(*this == rhs); } - /** - * Return the username and id of the User who owns this session. Only a User - * that matches both the name and id returned by this method should be - * permitted to use this session. - * - * Note: if the returned optional OID is set to boost::none, this implies that - * the owning user is a pre-3.6 user that has no id. In this case, only a User - * with a matching UserName who also has an unset optional id should be - * permitted to use this session. - */ - Owner getSessionOwner() const; - private: LogicalSessionRecord() = default; - LogicalSessionRecord(LogicalSessionId id, - UserName user, - boost::optional<OID> userId, - Date_t now); - - Owner _owner; + LogicalSessionRecord(SignedLogicalSessionId id, Date_t now); }; inline std::ostream& operator<<(std::ostream& s, const LogicalSessionRecord& record) { diff --git a/src/mongo/db/logical_session_record.idl b/src/mongo/db/logical_session_record.idl index ac4df1c3b6c..25e2ac65b76 100644 --- a/src/mongo/db/logical_session_record.idl +++ b/src/mongo/db/logical_session_record.idl @@ -21,40 +21,28 @@ global: cpp_namespace: "mongo" cpp_includes: - "mongo/db/logical_session_id.h" + - "mongo/db/signed_logical_session_id.h" imports: - "mongo/idl/basic_types.idl" - - "mongo/db/logical_session_id.idl" + - "mongo/db/signed_logical_session_id.idl" types: - - LogicalSessionIdIDL: - description: "IDL representation of the LogicalSessionId cpp type" + SignedLogicalSessionIdIDL: + description: "IDL representation of the SignedLogicalSessionId cpp type" bson_serialization_type: object - cpp_type: "mongo::LogicalSessionId" - deserializer: "mongo::LogicalSessionId::parse" - serializer: "mongo::LogicalSessionId::toBSON" + cpp_type: "mongo::SignedLogicalSessionId" + deserializer: "mongo::SignedLogicalSessionId::parse" + serializer: "mongo::SignedLogicalSessionId::toBSON" structs: - session_owner: - description: "A sub-document of a session record with owner info" - fields: - userName: string - dbName: string - userId: - type: objectid - optional: true - logical_session_record: description: "A struct representing a LogicalSessionRecord" fields: - lsid: + signedLsid: description: "The id for this session record" - type: LogicalSessionIdIDL + type: SignedLogicalSessionIdIDL lastUse: description: "The time at which this record was last used. Note: the date expressed in this in-memory record object may not match the date on the authoritative record for this session, which is stored in the sessions collection." type: date - owner: - description: "The username and id of the User who owns this session." - type: session_owner diff --git a/src/mongo/db/logical_session_record_test.cpp b/src/mongo/db/logical_session_record_test.cpp index a307095a21b..659d23efa5b 100644 --- a/src/mongo/db/logical_session_record_test.cpp +++ b/src/mongo/db/logical_session_record_test.cpp @@ -44,32 +44,18 @@ namespace { TEST(LogicalSessionRecordTest, ToAndFromBSONTest) { // Round-trip a BSON obj auto lsid = LogicalSessionId::gen(); + SignedLogicalSessionId slsid{lsid, boost::none, 1, SHA1Block{}}; auto lastUse = Date_t::now(); - auto oid = OID::gen(); - - auto owner = BSON("userName" - << "Sam" - << "dbName" - << "test" - << "userId" - << oid); - auto bson = BSON("lsid" << lsid.toBSON() << "lastUse" << lastUse << "owner" << owner); + auto bson = BSON("signedLsid" << slsid.toBSON() << "lastUse" << lastUse); // Make a session record out of this auto res = LogicalSessionRecord::parse(bson); ASSERT(res.isOK()); auto record = res.getValue(); - ASSERT_EQ(record.getLsid(), lsid); + ASSERT_EQ(record.getSignedLsid(), slsid); ASSERT_EQ(record.getLastUse(), lastUse); - auto recordOwner = record.getSessionOwner(); - - ASSERT_EQ(recordOwner.first.getUser(), "Sam"); - ASSERT_EQ(recordOwner.first.getDB(), "test"); - ASSERT(recordOwner.second); - ASSERT_EQ(*(recordOwner.second), oid); - // Dump back to bson, make sure we get the same thing auto bsonDump = record.toBSON(); ASSERT_EQ(bsonDump.woCompare(bson), 0); diff --git a/src/mongo/db/logical_time_validator.cpp b/src/mongo/db/logical_time_validator.cpp index d17a728435c..98e4f740ff6 100644 --- a/src/mongo/db/logical_time_validator.cpp +++ b/src/mongo/db/logical_time_validator.cpp @@ -35,7 +35,7 @@ #include "mongo/db/auth/action_type.h" #include "mongo/db/auth/authorization_session.h" #include "mongo/db/auth/privilege.h" -#include "mongo/db/keys_collection_manager.h" +#include "mongo/db/keys_collection_manager_sharding.h" #include "mongo/db/operation_context.h" #include "mongo/db/service_context.h" #include "mongo/util/assert_util.h" @@ -77,8 +77,9 @@ void LogicalTimeValidator::set(ServiceContext* service, validator = std::move(newValidator); } -LogicalTimeValidator::LogicalTimeValidator(std::unique_ptr<KeysCollectionManager> keyManager) - : _keyManager(std::move(keyManager)) {} +LogicalTimeValidator::LogicalTimeValidator( + std::shared_ptr<KeysCollectionManagerSharding> keyManager) + : _keyManager(keyManager) {} SignedLogicalTime LogicalTimeValidator::_getProof(const KeysCollectionDocument& keyDoc, LogicalTime newTime) { @@ -103,7 +104,7 @@ SignedLogicalTime LogicalTimeValidator::_getProof(const KeysCollectionDocument& } SignedLogicalTime LogicalTimeValidator::trySignLogicalTime(const LogicalTime& newTime) { - auto keyStatusWith = _keyManager->getKeyForSigning(newTime); + auto keyStatusWith = _keyManager->getKeyForSigning(nullptr, newTime); auto keyStatus = keyStatusWith.getStatus(); if (keyStatus == ErrorCodes::KeyNotFound) { @@ -117,13 +118,13 @@ SignedLogicalTime LogicalTimeValidator::trySignLogicalTime(const LogicalTime& ne SignedLogicalTime LogicalTimeValidator::signLogicalTime(OperationContext* opCtx, const LogicalTime& newTime) { - auto keyStatusWith = _keyManager->getKeyForSigning(newTime); + auto keyStatusWith = _keyManager->getKeyForSigning(nullptr, newTime); auto keyStatus = keyStatusWith.getStatus(); while (keyStatus == ErrorCodes::KeyNotFound) { _keyManager->refreshNow(opCtx); - keyStatusWith = _keyManager->getKeyForSigning(newTime); + keyStatusWith = _keyManager->getKeyForSigning(nullptr, newTime); keyStatus = keyStatusWith.getStatus(); if (keyStatus == ErrorCodes::KeyNotFound) { diff --git a/src/mongo/db/logical_time_validator.h b/src/mongo/db/logical_time_validator.h index 44bb34922e5..7d11450bf1c 100644 --- a/src/mongo/db/logical_time_validator.h +++ b/src/mongo/db/logical_time_validator.h @@ -39,7 +39,7 @@ namespace mongo { class OperationContext; class ServiceContext; class KeysCollectionDocument; -class KeysCollectionManager; +class KeysCollectionManagerSharding; /** * This is responsible for signing cluster times that can be used to sent to other servers and @@ -52,7 +52,11 @@ public: static LogicalTimeValidator* get(OperationContext* ctx); static void set(ServiceContext* service, std::unique_ptr<LogicalTimeValidator> validator); - explicit LogicalTimeValidator(std::unique_ptr<KeysCollectionManager> keyManager); + /** + * Constructs a new LogicalTimeValidator that uses the given key manager. The passed-in + * key manager must outlive this object. + */ + explicit LogicalTimeValidator(std::shared_ptr<KeysCollectionManagerSharding> keyManager); /** * Tries to sign the newTime with a valid signature. Can return an empty signature and keyId @@ -103,7 +107,7 @@ private: stdx::mutex _mutex; SignedLogicalTime _lastSeenValidTime; TimeProofService _timeProofService; - std::unique_ptr<KeysCollectionManager> _keyManager; + std::shared_ptr<KeysCollectionManagerSharding> _keyManager; }; } // namespace mongo diff --git a/src/mongo/db/logical_time_validator_test.cpp b/src/mongo/db/logical_time_validator_test.cpp index 9a4c289d983..8292ee75db4 100644 --- a/src/mongo/db/logical_time_validator_test.cpp +++ b/src/mongo/db/logical_time_validator_test.cpp @@ -30,6 +30,7 @@ #include "mongo/bson/timestamp.h" #include "mongo/db/keys_collection_manager.h" +#include "mongo/db/keys_collection_manager_sharding.h" #include "mongo/db/logical_clock.h" #include "mongo/db/logical_time.h" #include "mongo/db/logical_time_validator.h" @@ -70,10 +71,9 @@ protected: const LogicalTime currentTime(LogicalTime(Timestamp(1, 0))); LogicalClock::get(operationContext())->setClusterTimeFromTrustedSource(currentTime); - auto keyManager = - stdx::make_unique<KeysCollectionManager>("dummy", catalogClient, Seconds(1000)); - _keyManager = keyManager.get(); - _validator = stdx::make_unique<LogicalTimeValidator>(std::move(keyManager)); + _keyManager = + std::make_shared<KeysCollectionManagerSharding>("dummy", catalogClient, Seconds(1000)); + _validator = stdx::make_unique<LogicalTimeValidator>(_keyManager); _validator->init(operationContext()->getServiceContext()); } @@ -97,7 +97,7 @@ protected: private: std::unique_ptr<LogicalTimeValidator> _validator; - KeysCollectionManager* _keyManager; + std::shared_ptr<KeysCollectionManagerSharding> _keyManager; }; TEST_F(LogicalTimeValidatorTest, GetTimeWithIncreasingTimes) { @@ -133,7 +133,7 @@ TEST_F(LogicalTimeValidatorTest, ValidateErrorsOnInvalidTime) { refreshKeyManager(); auto newTime = validator()->trySignLogicalTime(t1); - TimeProofService::TimeProof invalidProof = {{1, 2, 3}}; + TimeProofService::TimeProof invalidProof = {{{1, 2, 3}}}; SignedLogicalTime invalidTime(LogicalTime(Timestamp(30, 0)), invalidProof, newTime.getKeyId()); // ASSERT_THROWS_CODE(validator()->validate(operationContext(), invalidTime), DBException, // ErrorCodes::TimeProofMismatch); @@ -156,7 +156,7 @@ TEST_F(LogicalTimeValidatorTest, ValidateErrorsOnInvalidTimeWithImplicitRefresh) LogicalTime t1(Timestamp(20, 0)); auto newTime = validator()->signLogicalTime(operationContext(), t1); - TimeProofService::TimeProof invalidProof = {{1, 2, 3}}; + TimeProofService::TimeProof invalidProof = {{{1, 2, 3}}}; SignedLogicalTime invalidTime(LogicalTime(Timestamp(30, 0)), invalidProof, newTime.getKeyId()); // ASSERT_THROWS_CODE(validator()->validate(operationContext(), invalidTime), DBException, // ErrorCodes::TimeProofMismatch); diff --git a/src/mongo/db/service_context.cpp b/src/mongo/db/service_context.cpp index f4083be9901..05d89c4cc6f 100644 --- a/src/mongo/db/service_context.cpp +++ b/src/mongo/db/service_context.cpp @@ -161,15 +161,6 @@ PeriodicRunner* ServiceContext::getPeriodicRunner() const { return _runner.get(); } -void ServiceContext::setLogicalSessionCache(std::unique_ptr<LogicalSessionCache> cache)& { - invariant(!_sessionCache); - _sessionCache = std::move(cache); -} - -LogicalSessionCache* ServiceContext::getLogicalSessionCache() const& { - return _sessionCache.get(); -} - transport::TransportLayer* ServiceContext::getTransportLayer() const { return _transportLayer.get(); } diff --git a/src/mongo/db/service_context.h b/src/mongo/db/service_context.h index 3cf830c3fe7..9f97445d69f 100644 --- a/src/mongo/db/service_context.h +++ b/src/mongo/db/service_context.h @@ -31,7 +31,8 @@ #include <vector> #include "mongo/base/disallow_copying.h" -#include "mongo/db/logical_session_cache.h" +#include "mongo/db/keys_collection_manager.h" +#include "mongo/db/logical_session_id.h" #include "mongo/db/storage/storage_engine.h" #include "mongo/platform/atomic_word.h" #include "mongo/platform/unordered_set.h" @@ -266,6 +267,28 @@ public: virtual StorageEngine* getGlobalStorageEngine() = 0; // + // Key manager, for HMAC keys. + // + + /** + * Sets the key manager on this service context. + */ + void setKeyManager(std::shared_ptr<KeysCollectionManager> keyManager) & { + stdx::lock_guard<stdx::mutex> lk(_mutex); + _keyManager = std::move(keyManager); + } + + /** + * Returns a pointer to the keys collection manager owned by this service context. + */ + std::shared_ptr<KeysCollectionManager> getKeyManager() & { + stdx::lock_guard<stdx::mutex> lk(_mutex); + return _keyManager; + } + + std::shared_ptr<KeysCollectionManager> getKeyManager() && = delete; + + // // Global operation management. This may not belong here and there may be too many methods // here. // @@ -329,21 +352,6 @@ public: PeriodicRunner* getPeriodicRunner() const; // - // Logical sessions. - // - - /** - * Set the logical session cache on this service context. - */ - void setLogicalSessionCache(std::unique_ptr<LogicalSessionCache> cache) &; - - /** - * Return a pointer to the logical session cache on this service context. - */ - LogicalSessionCache* getLogicalSessionCache() const&; - LogicalSessionCache* getLogicalSessionCache() && = delete; - - // // Transport. // @@ -459,14 +467,14 @@ private: void _killOperation_inlock(OperationContext* opCtx, ErrorCodes::Error killCode); /** - * The periodic runner. + * The key manager. */ - std::unique_ptr<PeriodicRunner> _runner; + std::shared_ptr<KeysCollectionManager> _keyManager; /** - * The logical session cache. + * The periodic runner. */ - std::unique_ptr<LogicalSessionCache> _sessionCache; + std::unique_ptr<PeriodicRunner> _runner; /** * The TransportLayer. diff --git a/src/mongo/db/service_liason.cpp b/src/mongo/db/service_liason.cpp index e12a243b6e2..a4a576c17f6 100644 --- a/src/mongo/db/service_liason.cpp +++ b/src/mongo/db/service_liason.cpp @@ -30,8 +30,79 @@ #include "mongo/db/service_liason.h" +#include "mongo/db/keys_collection_manager_zero.h" +#include "mongo/db/logical_clock.h" +#include "mongo/db/service_context.h" + namespace mongo { +namespace { + +const int kSignatureSize = sizeof(UUID) + sizeof(OID); + +SHA1Block computeSignature(const SignedLogicalSessionId* id, TimeProofService::Key key) { + // Write the uuid and user id to a block for signing. + char signatureBlock[kSignatureSize] = {0}; + DataRangeCursor cursor(signatureBlock, signatureBlock + kSignatureSize); + auto res = cursor.writeAndAdvance<ConstDataRange>(id->getLsid().getId().toCDR()); + invariant(res.isOK()); + if (auto userId = id->getUserId()) { + res = cursor.writeAndAdvance<ConstDataRange>(userId->toCDR()); + invariant(res.isOK()); + } + + // Compute the signature. + return SHA1Block::computeHmac( + key.data(), key.size(), reinterpret_cast<uint8_t*>(signatureBlock), kSignatureSize); +} + +KeysCollectionManagerZero kKeysCollectionManagerZero{"HMAC"}; + +} // namespace + ServiceLiason::~ServiceLiason() = default; +StatusWith<SignedLogicalSessionId> ServiceLiason::signLsid(OperationContext* opCtx, + LogicalSessionId* lsid, + boost::optional<OID> userId) { + auto& keyManager = kKeysCollectionManagerZero; + + auto logicalTime = LogicalClock::get(_context())->getClusterTime(); + auto res = keyManager.getKeyForSigning(opCtx, logicalTime); + if (!res.isOK()) { + return res.getStatus(); + } + + SignedLogicalSessionId signedLsid; + signedLsid.setUserId(std::move(userId)); + signedLsid.setLsid(*lsid); + + auto keyDoc = res.getValue(); + signedLsid.setKeyId(keyDoc.getKeyId()); + + auto signature = computeSignature(&signedLsid, keyDoc.getKey()); + signedLsid.setSignature(std::move(signature)); + + return signedLsid; +} + +Status ServiceLiason::validateLsid(OperationContext* opCtx, const SignedLogicalSessionId& id) { + auto& keyManager = kKeysCollectionManagerZero; + + // Attempt to get the correct key. + auto logicalTime = LogicalClock::get(_context())->getClusterTime(); + auto res = keyManager.getKeyForValidation(opCtx, id.getKeyId(), logicalTime); + if (!res.isOK()) { + return res.getStatus(); + } + + // Re-compute the signature, and see that it matches. + auto signature = computeSignature(&id, res.getValue().getKey()); + if (signature != id.getSignature()) { + return {ErrorCodes::NoSuchSession, "Signature validation failed."}; + } + + return Status::OK(); +} + } // namespace mongo diff --git a/src/mongo/db/service_liason.h b/src/mongo/db/service_liason.h index 2389245f6a5..dca11f50182 100644 --- a/src/mongo/db/service_liason.h +++ b/src/mongo/db/service_liason.h @@ -29,12 +29,15 @@ #pragma once #include "mongo/db/logical_session_id.h" +#include "mongo/db/signed_logical_session_id.h" #include "mongo/stdx/functional.h" #include "mongo/util/periodic_runner.h" #include "mongo/util/time_support.h" namespace mongo { +class ServiceContext; + /** * A service-dependent type for the LogicalSessionCache to use to find the * current time, schedule periodic refresh jobs, and get a list of sessions @@ -73,6 +76,27 @@ public: * Return the current time. */ virtual Date_t now() const = 0; + + /** + * Generates and sets a signature for the fields in this LogicalSessionId. + * + * If this method is not able to acquire a key to perform the signature + * this call will return an error. + */ + StatusWith<SignedLogicalSessionId> signLsid(OperationContext* opCtx, + LogicalSessionId* lsid, + boost::optional<OID> userId); + + /** + * Validates that this LogicalSessionId was signed with the correct key. + */ + Status validateLsid(OperationContext* opCtx, const SignedLogicalSessionId& id); + +protected: + /** + * Returns the service context. + */ + virtual ServiceContext* _context() = 0; }; } // namespace mongo diff --git a/src/mongo/db/service_liason_mock.h b/src/mongo/db/service_liason_mock.h index bfd2d281ca1..39323057e59 100644 --- a/src/mongo/db/service_liason_mock.h +++ b/src/mongo/db/service_liason_mock.h @@ -28,6 +28,8 @@ #pragma once +#include "mongo/db/service_context.h" +#include "mongo/db/service_context_noop.h" #include "mongo/db/service_liason.h" #include "mongo/executor/async_timer_mock.h" #include "mongo/platform/atomic_word.h" @@ -101,8 +103,14 @@ public: return _impl->join(); } +protected: + ServiceContext* _context() override { + return _serviceContext.get(); + } + private: std::shared_ptr<MockServiceLiasonImpl> _impl; + std::unique_ptr<ServiceContextNoop> _serviceContext; }; } // namespace mongo diff --git a/src/mongo/db/service_liason_mongod.cpp b/src/mongo/db/service_liason_mongod.cpp index fdbf811de18..08c49127936 100644 --- a/src/mongo/db/service_liason_mongod.cpp +++ b/src/mongo/db/service_liason_mongod.cpp @@ -88,4 +88,8 @@ Date_t ServiceLiasonMongod::now() const { return getGlobalServiceContext()->getFastClockSource()->now(); } +ServiceContext* ServiceLiasonMongod::_context() { + return getGlobalServiceContext(); +} + } // namespace mongo diff --git a/src/mongo/db/service_liason_mongod.h b/src/mongo/db/service_liason_mongod.h index 21c677feb8c..31304b94573 100644 --- a/src/mongo/db/service_liason_mongod.h +++ b/src/mongo/db/service_liason_mongod.h @@ -57,6 +57,12 @@ public: void join() override; Date_t now() const override; + +protected: + /** + * Returns the service context. + */ + ServiceContext* _context() override; }; } // namespace mongo diff --git a/src/mongo/db/sessions_collection.h b/src/mongo/db/sessions_collection.h index ec59f28c597..7692e634275 100644 --- a/src/mongo/db/sessions_collection.h +++ b/src/mongo/db/sessions_collection.h @@ -30,6 +30,7 @@ #include "mongo/db/logical_session_id.h" #include "mongo/db/logical_session_record.h" +#include "mongo/db/signed_logical_session_id.h" namespace mongo { @@ -44,10 +45,10 @@ public: virtual ~SessionsCollection(); /** - * Returns a LogicalSessionRecord for the given LogicalSessionId. This method + * Returns a LogicalSessionRecord for the given session id. This method * may run networking operations on the calling thread. */ - virtual StatusWith<LogicalSessionRecord> fetchRecord(LogicalSessionId lsid) = 0; + virtual StatusWith<LogicalSessionRecord> fetchRecord(SignedLogicalSessionId id) = 0; /** * Inserts the given record into the sessions collection. This method may run diff --git a/src/mongo/db/sessions_collection_mock.cpp b/src/mongo/db/sessions_collection_mock.cpp index 4abc5a326d1..1079605badd 100644 --- a/src/mongo/db/sessions_collection_mock.cpp +++ b/src/mongo/db/sessions_collection_mock.cpp @@ -65,8 +65,9 @@ void MockSessionsCollectionImpl::clearHooks() { _remove = stdx::bind(&MockSessionsCollectionImpl::_removeRecords, this, stdx::placeholders::_1); } -StatusWith<LogicalSessionRecord> MockSessionsCollectionImpl::fetchRecord(LogicalSessionId lsid) { - return _fetch(std::move(lsid)); +StatusWith<LogicalSessionRecord> MockSessionsCollectionImpl::fetchRecord( + SignedLogicalSessionId id) { + return _fetch(std::move(id)); } Status MockSessionsCollectionImpl::insertRecord(LogicalSessionRecord record) { @@ -83,7 +84,7 @@ void MockSessionsCollectionImpl::removeRecords(LogicalSessionIdSet sessions) { void MockSessionsCollectionImpl::add(LogicalSessionRecord record) { stdx::unique_lock<stdx::mutex> lk(_mutex); - _sessions.insert({record.getLsid(), std::move(record)}); + _sessions.insert({record.getSignedLsid().getLsid(), std::move(record)}); } void MockSessionsCollectionImpl::remove(LogicalSessionId lsid) { @@ -105,11 +106,12 @@ const MockSessionsCollectionImpl::SessionMap& MockSessionsCollectionImpl::sessio return _sessions; } -StatusWith<LogicalSessionRecord> MockSessionsCollectionImpl::_fetchRecord(LogicalSessionId lsid) { +StatusWith<LogicalSessionRecord> MockSessionsCollectionImpl::_fetchRecord( + SignedLogicalSessionId id) { stdx::unique_lock<stdx::mutex> lk(_mutex); // If we do not have this record, return an error - auto it = _sessions.find(lsid); + auto it = _sessions.find(id.getLsid()); if (it == _sessions.end()) { return {ErrorCodes::NoSuchSession, "No matching record in the sessions collection"}; } @@ -119,7 +121,7 @@ StatusWith<LogicalSessionRecord> MockSessionsCollectionImpl::_fetchRecord(Logica Status MockSessionsCollectionImpl::_insertRecord(LogicalSessionRecord record) { stdx::unique_lock<stdx::mutex> lk(_mutex); - auto res = _sessions.insert({record.getLsid(), std::move(record)}); + auto res = _sessions.insert({record.getSignedLsid().getLsid(), std::move(record)}); // We should never try to insert the same record twice. In theory this could // happen because of a UUID conflict. diff --git a/src/mongo/db/sessions_collection_mock.h b/src/mongo/db/sessions_collection_mock.h index 64e2a76fc7b..afa04bfd9db 100644 --- a/src/mongo/db/sessions_collection_mock.h +++ b/src/mongo/db/sessions_collection_mock.h @@ -59,7 +59,7 @@ public: MockSessionsCollectionImpl(); - using FetchHook = stdx::function<StatusWith<LogicalSessionRecord>(LogicalSessionId)>; + using FetchHook = stdx::function<StatusWith<LogicalSessionRecord>(SignedLogicalSessionId)>; using InsertHook = stdx::function<Status(LogicalSessionRecord)>; using RefreshHook = stdx::function<LogicalSessionIdSet(LogicalSessionIdSet)>; using RemoveHook = stdx::function<void(LogicalSessionIdSet)>; @@ -74,7 +74,7 @@ public: void clearHooks(); // Forwarding methods from the MockSessionsCollection - StatusWith<LogicalSessionRecord> fetchRecord(LogicalSessionId lsid); + StatusWith<LogicalSessionRecord> fetchRecord(SignedLogicalSessionId id); Status insertRecord(LogicalSessionRecord record); LogicalSessionIdSet refreshSessions(LogicalSessionIdSet sessions); void removeRecords(LogicalSessionIdSet sessions); @@ -88,7 +88,7 @@ public: private: // Default implementations, may be overridden with custom hooks. - StatusWith<LogicalSessionRecord> _fetchRecord(LogicalSessionId lsid); + StatusWith<LogicalSessionRecord> _fetchRecord(SignedLogicalSessionId id); Status _insertRecord(LogicalSessionRecord record); LogicalSessionIdSet _refreshSessions(LogicalSessionIdSet sessions); void _removeRecords(LogicalSessionIdSet sessions); @@ -112,8 +112,8 @@ public: explicit MockSessionsCollection(std::shared_ptr<MockSessionsCollectionImpl> impl) : _impl(std::move(impl)) {} - StatusWith<LogicalSessionRecord> fetchRecord(LogicalSessionId lsid) override { - return _impl->fetchRecord(std::move(lsid)); + StatusWith<LogicalSessionRecord> fetchRecord(SignedLogicalSessionId id) override { + return _impl->fetchRecord(std::move(id)); } Status insertRecord(LogicalSessionRecord record) override { diff --git a/src/mongo/db/signed_logical_session_id.cpp b/src/mongo/db/signed_logical_session_id.cpp new file mode 100644 index 00000000000..a284fb6344b --- /dev/null +++ b/src/mongo/db/signed_logical_session_id.cpp @@ -0,0 +1,73 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/signed_logical_session_id.h" + +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/util/assert_util.h" + +namespace mongo { + +SignedLogicalSessionId SignedLogicalSessionId::gen() { + return SignedLogicalSessionId(); +} + +SignedLogicalSessionId::SignedLogicalSessionId() { + setLsid(LogicalSessionId::gen()); +} + +SignedLogicalSessionId::SignedLogicalSessionId(LogicalSessionId lsid, + boost::optional<OID> userId, + long long keyId, + SHA1Block signature) { + setLsid(std::move(lsid)); + setUserId(std::move(userId)); + setKeyId(std::move(keyId)); + setSignature(std::move(signature)); +} + +SignedLogicalSessionId SignedLogicalSessionId::parse(const BSONObj& doc) { + IDLParserErrorContext ctx("signed logical session id"); + SignedLogicalSessionId lsid; + lsid.parseProtected(ctx, doc); + return lsid; +} + +BSONObj SignedLogicalSessionId::toBSON() const { + BSONObjBuilder builder; + serialize(&builder); + return builder.obj(); +} + +std::string SignedLogicalSessionId::toString() const { + return getLsid().toString(); +} + +} // namespace mongo diff --git a/src/mongo/db/signed_logical_session_id.h b/src/mongo/db/signed_logical_session_id.h new file mode 100644 index 00000000000..632f645c623 --- /dev/null +++ b/src/mongo/db/signed_logical_session_id.h @@ -0,0 +1,115 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <string> + +#include "mongo/base/status_with.h" +#include "mongo/bson/oid.h" +#include "mongo/db/logical_session_id.h" +#include "mongo/db/signed_logical_session_id_gen.h" +#include "mongo/stdx/unordered_set.h" +#include "mongo/util/uuid.h" + +namespace mongo { + +class BSONObjBuilder; +class OperationContext; + +/** + * An identifier for a logical session. A LogicalSessionId has the following components: + * + * - A 128-bit unique identifier (UUID) + * - An optional user id (ObjectId) + * - A key id (long long) + * - An HMAC signature (SHA1Block) + */ +class SignedLogicalSessionId : public Signed_logical_session_id { +public: + using Owner = boost::optional<OID>; + + friend class Logical_session_id; + friend class Logical_session_record; + + using keyIdType = long long; + + /** + * Create and return a new LogicalSessionId with a random UUID. This method + * should be used for testing only. The generated SignedLogicalSessionId will + * not be signed, and will have no owner. + */ + static SignedLogicalSessionId gen(); + + /** + * Creates a new SignedLogicalSessionId. + */ + SignedLogicalSessionId(LogicalSessionId lsid, + boost::optional<OID> userId, + long long keyId, + SHA1Block signature); + + /** + * Constructs a new LogicalSessionId out of a BSONObj. For IDL. + */ + static SignedLogicalSessionId parse(const BSONObj& doc); + + /** + * Returns a string representation of this session id. + */ + std::string toString() const; + + /** + * Serialize this object to BSON. + */ + BSONObj toBSON() const; + + inline bool operator==(const SignedLogicalSessionId& rhs) const { + return getLsid() == rhs.getLsid() && getUserId() == rhs.getUserId() && + getKeyId() == rhs.getKeyId() && getSignature() == rhs.getSignature(); + } + + inline bool operator!=(const SignedLogicalSessionId& rhs) const { + return !(*this == rhs); + } + + /** + * This constructor exists for IDL only. + */ + SignedLogicalSessionId(); +}; + +inline std::ostream& operator<<(std::ostream& s, const SignedLogicalSessionId& lsid) { + return (s << lsid.toString()); +} + +inline StringBuilder& operator<<(StringBuilder& s, const SignedLogicalSessionId& lsid) { + return (s << lsid.toString()); +} + +} // namespace mongo diff --git a/src/mongo/db/signed_logical_session_id.idl b/src/mongo/db/signed_logical_session_id.idl new file mode 100644 index 00000000000..634d4731a52 --- /dev/null +++ b/src/mongo/db/signed_logical_session_id.idl @@ -0,0 +1,50 @@ +# Copyright (C) 2017 MongoDB Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License, version 3, +# as published by the Free Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see <http://www.gnu.org/licenses/>. +# + +# This IDL file describes the BSON format for a LogicalSessionId, and +# handles the serialization to and deserialization from its BSON representation +# for that class. + +global: + cpp_namespace: "mongo" + cpp_includes: + - "mongo/util/uuid.h" + - "mongo/db/logical_session_id.h" + +imports: + - "mongo/crypto/sha1_block.idl" + - "mongo/db/logical_session_id.idl" + - "mongo/idl/basic_types.idl" + +types: + LogicalSessionIdIdl: + description: "IDL representation of the LogicalSessionId cpp type" + bson_serialization_type: object + cpp_type: "mongo::LogicalSessionId" + deserializer: "mongo::LogicalSessionId::parse" + serializer: "mongo::LogicalSessionId::toBSON" + +structs: + + signed_logical_session_id: + description: "A struct representing a SignedLogicalSessionId" + strict: true + fields: + lsid: LogicalSessionIdIdl + userId: + optional: true + type: objectid + keyId: long + signature: sha1Block diff --git a/src/mongo/db/signed_logical_session_id_test.cpp b/src/mongo/db/signed_logical_session_id_test.cpp new file mode 100644 index 00000000000..2622d65f2f8 --- /dev/null +++ b/src/mongo/db/signed_logical_session_id_test.cpp @@ -0,0 +1,81 @@ +/** + * Copyright (C) 2017 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include <string> + +#include "mongo/bson/bsonmisc.h" +#include "mongo/bson/bsonobj.h" +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/bson/bsontypes.h" +#include "mongo/crypto/sha1_block.h" +#include "mongo/db/logical_session_id.h" +#include "mongo/db/signed_logical_session_id.h" +#include "mongo/unittest/unittest.h" +#include "mongo/util/uuid.h" + +namespace mongo { +namespace { + +TEST(SignedLogicalSessionIdTest, ConstructWithLsid) { + auto lsid = LogicalSessionId::gen(); + SignedLogicalSessionId slsid(lsid, boost::none, 1, SHA1Block{}); + ASSERT_EQ(slsid.getLsid(), lsid); +} + +TEST(SignedLogicalSessionIdTest, FromBSONTest) { + auto lsid = LogicalSessionId::gen(); + + BSONObjBuilder b; + b.append("lsid", lsid.toBSON()); + b.append("keyId", 4ll); + char buffer[SHA1Block::kHashLength] = {0}; + b.appendBinData("signature", SHA1Block::kHashLength, BinDataGeneral, buffer); + auto bson = b.done(); + + auto slsid = SignedLogicalSessionId::parse(bson); + ASSERT_EQ(slsid.getLsid(), lsid); + + // Dump back to BSON, make sure we get the same thing + auto bsonDump = slsid.toBSON(); + ASSERT_EQ(bsonDump.woCompare(bson), 0); + + // Try parsing mal-formatted bson objs + ASSERT_THROWS(SignedLogicalSessionId::parse(BSON("hi" + << "there")), + UserException); + + ASSERT_THROWS(SignedLogicalSessionId::parse(BSON("lsid" + << "not a session id!")), + UserException); + ASSERT_THROWS(SignedLogicalSessionId::parse(BSON("lsid" << 14)), UserException); +} + +} // namespace +} // namespace mongo |