summaryrefslogtreecommitdiff
path: root/src/mongo/db/s/database_sharding_state.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/s/database_sharding_state.cpp')
-rw-r--r--src/mongo/db/s/database_sharding_state.cpp57
1 files changed, 36 insertions, 21 deletions
diff --git a/src/mongo/db/s/database_sharding_state.cpp b/src/mongo/db/s/database_sharding_state.cpp
index 74351ce4336..67a6f306f27 100644
--- a/src/mongo/db/s/database_sharding_state.cpp
+++ b/src/mongo/db/s/database_sharding_state.cpp
@@ -98,41 +98,56 @@ const ServiceContext::Decoration<DatabaseShardingStateMap> DatabaseShardingState
} // namespace
-DatabaseShardingState::ScopedDatabaseShardingState::ScopedDatabaseShardingState(
+DatabaseShardingState::DatabaseShardingState(const DatabaseName& dbName) : _dbName(dbName) {}
+
+DatabaseShardingState::ScopedExclusiveDatabaseShardingState::ScopedExclusiveDatabaseShardingState(
Lock::ResourceLock lock, DatabaseShardingState* dss)
: _lock(std::move(lock)), _dss(dss) {}
+DatabaseShardingState::ScopedSharedDatabaseShardingState::ScopedSharedDatabaseShardingState(
+ Lock::ResourceLock lock, DatabaseShardingState* dss)
+ : DatabaseShardingState::ScopedExclusiveDatabaseShardingState(std::move(lock), dss) {}
-DatabaseShardingState::ScopedDatabaseShardingState::ScopedDatabaseShardingState(
- ScopedDatabaseShardingState&& other)
- : _lock(std::move(other._lock)), _dss(other._dss) {
- other._dss = nullptr;
-}
-
-DatabaseShardingState::ScopedDatabaseShardingState::~ScopedDatabaseShardingState() = default;
+DatabaseShardingState::ScopedExclusiveDatabaseShardingState DatabaseShardingState::acquireExclusive(
+ OperationContext* opCtx, const DatabaseName& dbName) {
-DatabaseShardingState::DatabaseShardingState(const DatabaseName& dbName) : _dbName(dbName) {}
+ DatabaseShardingStateMap::DSSAndLock* dssAndLock =
+ DatabaseShardingStateMap::get(opCtx->getServiceContext()).getOrCreate(dbName);
-DatabaseShardingState::ScopedDatabaseShardingState DatabaseShardingState::assertDbLockedAndAcquire(
- OperationContext* opCtx, const DatabaseName& dbName, DSSAcquisitionMode mode) {
- dassert(opCtx->lockState()->isDbLockedForMode(dbName, MODE_IS));
+ // First lock the RESOURCE_MUTEX associated to this dbName to guarantee stability of the
+ // DatabaseShardingState pointer. After that, it is safe to get and store the
+ // DatabaseShadingState*, as long as the RESOURCE_MUTEX is kept locked.
+ Lock::ResourceLock lock(opCtx->lockState(), dssAndLock->dssMutex.getRid(), MODE_X);
- return acquire(opCtx, dbName, mode);
+ return ScopedExclusiveDatabaseShardingState(std::move(lock), dssAndLock->dss.get());
}
-DatabaseShardingState::ScopedDatabaseShardingState DatabaseShardingState::acquire(
- OperationContext* opCtx, const DatabaseName& dbName, DSSAcquisitionMode mode) {
+DatabaseShardingState::ScopedSharedDatabaseShardingState DatabaseShardingState::acquireShared(
+ OperationContext* opCtx, const DatabaseName& dbName) {
+
DatabaseShardingStateMap::DSSAndLock* dssAndLock =
DatabaseShardingStateMap::get(opCtx->getServiceContext()).getOrCreate(dbName);
// First lock the RESOURCE_MUTEX associated to this dbName to guarantee stability of the
// DatabaseShardingState pointer. After that, it is safe to get and store the
// DatabaseShadingState*, as long as the RESOURCE_MUTEX is kept locked.
- Lock::ResourceLock lock(opCtx->lockState(),
- dssAndLock->dssMutex.getRid(),
- mode == DSSAcquisitionMode::kShared ? MODE_IS : MODE_X);
+ Lock::ResourceLock lock(opCtx->lockState(), dssAndLock->dssMutex.getRid(), MODE_IS);
+
+ return ScopedSharedDatabaseShardingState(std::move(lock), dssAndLock->dss.get());
+}
- return ScopedDatabaseShardingState(std::move(lock), dssAndLock->dss.get());
+DatabaseShardingState::ScopedExclusiveDatabaseShardingState
+DatabaseShardingState::assertDbLockedAndAcquireExclusive(OperationContext* opCtx,
+ const DatabaseName& dbName) {
+ dassert(opCtx->lockState()->isDbLockedForMode(dbName, MODE_IS));
+ return acquireExclusive(opCtx, dbName);
+}
+
+DatabaseShardingState::ScopedSharedDatabaseShardingState
+DatabaseShardingState::assertDbLockedAndAcquireShared(OperationContext* opCtx,
+ const DatabaseName& dbName) {
+ dassert(opCtx->lockState()->isDbLockedForMode(dbName, MODE_IS));
+ return acquireShared(opCtx, dbName);
}
std::vector<DatabaseName> DatabaseShardingState::getDatabaseNames(OperationContext* opCtx) {
@@ -153,7 +168,7 @@ void DatabaseShardingState::assertMatchingDbVersion(OperationContext* opCtx,
void DatabaseShardingState::assertMatchingDbVersion(OperationContext* opCtx,
const DatabaseName& dbName,
const DatabaseVersion& receivedVersion) {
- const auto scopedDss = acquire(opCtx, dbName, DSSAcquisitionMode::kShared);
+ const auto scopedDss = acquireShared(opCtx, dbName);
{
const auto critSecSignal = scopedDss->getCriticalSectionSignal(
@@ -192,7 +207,7 @@ void DatabaseShardingState::assertIsPrimaryShardForDb(OperationContext* opCtx,
Lock::DBLock dbLock(opCtx, dbName, MODE_IS);
assertMatchingDbVersion(opCtx, dbName);
- const auto scopedDss = assertDbLockedAndAcquire(opCtx, dbName, DSSAcquisitionMode::kShared);
+ const auto scopedDss = assertDbLockedAndAcquireShared(opCtx, dbName);
const auto primaryShardId = scopedDss->_dbInfo->getPrimary();
const auto thisShardId = ShardingState::get(opCtx)->shardId();
uassert(ErrorCodes::IllegalOperation,