diff options
Diffstat (limited to 'src/mongo/db/s/database_sharding_state.cpp')
-rw-r--r-- | src/mongo/db/s/database_sharding_state.cpp | 57 |
1 files changed, 36 insertions, 21 deletions
diff --git a/src/mongo/db/s/database_sharding_state.cpp b/src/mongo/db/s/database_sharding_state.cpp index 74351ce4336..67a6f306f27 100644 --- a/src/mongo/db/s/database_sharding_state.cpp +++ b/src/mongo/db/s/database_sharding_state.cpp @@ -98,41 +98,56 @@ const ServiceContext::Decoration<DatabaseShardingStateMap> DatabaseShardingState } // namespace -DatabaseShardingState::ScopedDatabaseShardingState::ScopedDatabaseShardingState( +DatabaseShardingState::DatabaseShardingState(const DatabaseName& dbName) : _dbName(dbName) {} + +DatabaseShardingState::ScopedExclusiveDatabaseShardingState::ScopedExclusiveDatabaseShardingState( Lock::ResourceLock lock, DatabaseShardingState* dss) : _lock(std::move(lock)), _dss(dss) {} +DatabaseShardingState::ScopedSharedDatabaseShardingState::ScopedSharedDatabaseShardingState( + Lock::ResourceLock lock, DatabaseShardingState* dss) + : DatabaseShardingState::ScopedExclusiveDatabaseShardingState(std::move(lock), dss) {} -DatabaseShardingState::ScopedDatabaseShardingState::ScopedDatabaseShardingState( - ScopedDatabaseShardingState&& other) - : _lock(std::move(other._lock)), _dss(other._dss) { - other._dss = nullptr; -} - -DatabaseShardingState::ScopedDatabaseShardingState::~ScopedDatabaseShardingState() = default; +DatabaseShardingState::ScopedExclusiveDatabaseShardingState DatabaseShardingState::acquireExclusive( + OperationContext* opCtx, const DatabaseName& dbName) { -DatabaseShardingState::DatabaseShardingState(const DatabaseName& dbName) : _dbName(dbName) {} + DatabaseShardingStateMap::DSSAndLock* dssAndLock = + DatabaseShardingStateMap::get(opCtx->getServiceContext()).getOrCreate(dbName); -DatabaseShardingState::ScopedDatabaseShardingState DatabaseShardingState::assertDbLockedAndAcquire( - OperationContext* opCtx, const DatabaseName& dbName, DSSAcquisitionMode mode) { - dassert(opCtx->lockState()->isDbLockedForMode(dbName, MODE_IS)); + // First lock the RESOURCE_MUTEX associated to this dbName to guarantee stability of the + // DatabaseShardingState pointer. After that, it is safe to get and store the + // DatabaseShadingState*, as long as the RESOURCE_MUTEX is kept locked. + Lock::ResourceLock lock(opCtx->lockState(), dssAndLock->dssMutex.getRid(), MODE_X); - return acquire(opCtx, dbName, mode); + return ScopedExclusiveDatabaseShardingState(std::move(lock), dssAndLock->dss.get()); } -DatabaseShardingState::ScopedDatabaseShardingState DatabaseShardingState::acquire( - OperationContext* opCtx, const DatabaseName& dbName, DSSAcquisitionMode mode) { +DatabaseShardingState::ScopedSharedDatabaseShardingState DatabaseShardingState::acquireShared( + OperationContext* opCtx, const DatabaseName& dbName) { + DatabaseShardingStateMap::DSSAndLock* dssAndLock = DatabaseShardingStateMap::get(opCtx->getServiceContext()).getOrCreate(dbName); // First lock the RESOURCE_MUTEX associated to this dbName to guarantee stability of the // DatabaseShardingState pointer. After that, it is safe to get and store the // DatabaseShadingState*, as long as the RESOURCE_MUTEX is kept locked. - Lock::ResourceLock lock(opCtx->lockState(), - dssAndLock->dssMutex.getRid(), - mode == DSSAcquisitionMode::kShared ? MODE_IS : MODE_X); + Lock::ResourceLock lock(opCtx->lockState(), dssAndLock->dssMutex.getRid(), MODE_IS); + + return ScopedSharedDatabaseShardingState(std::move(lock), dssAndLock->dss.get()); +} - return ScopedDatabaseShardingState(std::move(lock), dssAndLock->dss.get()); +DatabaseShardingState::ScopedExclusiveDatabaseShardingState +DatabaseShardingState::assertDbLockedAndAcquireExclusive(OperationContext* opCtx, + const DatabaseName& dbName) { + dassert(opCtx->lockState()->isDbLockedForMode(dbName, MODE_IS)); + return acquireExclusive(opCtx, dbName); +} + +DatabaseShardingState::ScopedSharedDatabaseShardingState +DatabaseShardingState::assertDbLockedAndAcquireShared(OperationContext* opCtx, + const DatabaseName& dbName) { + dassert(opCtx->lockState()->isDbLockedForMode(dbName, MODE_IS)); + return acquireShared(opCtx, dbName); } std::vector<DatabaseName> DatabaseShardingState::getDatabaseNames(OperationContext* opCtx) { @@ -153,7 +168,7 @@ void DatabaseShardingState::assertMatchingDbVersion(OperationContext* opCtx, void DatabaseShardingState::assertMatchingDbVersion(OperationContext* opCtx, const DatabaseName& dbName, const DatabaseVersion& receivedVersion) { - const auto scopedDss = acquire(opCtx, dbName, DSSAcquisitionMode::kShared); + const auto scopedDss = acquireShared(opCtx, dbName); { const auto critSecSignal = scopedDss->getCriticalSectionSignal( @@ -192,7 +207,7 @@ void DatabaseShardingState::assertIsPrimaryShardForDb(OperationContext* opCtx, Lock::DBLock dbLock(opCtx, dbName, MODE_IS); assertMatchingDbVersion(opCtx, dbName); - const auto scopedDss = assertDbLockedAndAcquire(opCtx, dbName, DSSAcquisitionMode::kShared); + const auto scopedDss = assertDbLockedAndAcquireShared(opCtx, dbName); const auto primaryShardId = scopedDss->_dbInfo->getPrimary(); const auto thisShardId = ShardingState::get(opCtx)->shardId(); uassert(ErrorCodes::IllegalOperation, |