diff options
author | Jennifer Peshansky <jennifer.peshansky@mongodb.com> | 2022-11-03 16:13:20 +0000 |
---|---|---|
committer | Jennifer Peshansky <jennifer.peshansky@mongodb.com> | 2022-11-03 16:13:20 +0000 |
commit | e74d2910bbe76790ad131d53fee277829cd95982 (patch) | |
tree | cabe148764529c9623652374fbc36323a550cd44 /src/mongo/db/s | |
parent | 280145e9940729480bb8a35453d4056afac87641 (diff) | |
parent | ba467f46cc1bc49965e1d72b541eff0cf1d7b22e (diff) | |
download | mongo-jenniferpeshansky/SERVER-70854.tar.gz |
Merge branch 'master' into jenniferpeshansky/SERVER-70854 (branch: jenniferpeshansky/SERVER-70854)
Diffstat (limited to 'src/mongo/db/s')
31 files changed, 2543 insertions, 190 deletions
diff --git a/src/mongo/db/s/SConscript b/src/mongo/db/s/SConscript index fa54218910d..3248c1b296d 100644 --- a/src/mongo/db/s/SConscript +++ b/src/mongo/db/s/SConscript @@ -11,7 +11,6 @@ env = env.Clone() env.Library( target='sharding_api_d', source=[ - 'balancer_stats_registry.cpp', 'collection_metadata.cpp', 'collection_sharding_state_factory_standalone.cpp', 'collection_sharding_state.cpp', @@ -36,14 +35,25 @@ env.Library( LIBDEPS_PRIVATE=[ '$BUILD_DIR/mongo/db/catalog/index_catalog', '$BUILD_DIR/mongo/db/concurrency/lock_manager', - '$BUILD_DIR/mongo/db/dbdirectclient', - '$BUILD_DIR/mongo/db/repl/replica_set_aware_service', '$BUILD_DIR/mongo/db/server_base', '$BUILD_DIR/mongo/db/write_block_bypass', ], ) env.Library( + target='balancer_stats_registry', + source=[ + 'balancer_stats_registry.cpp', + ], + LIBDEPS_PRIVATE=[ + '$BUILD_DIR/mongo/db/dbdirectclient', + '$BUILD_DIR/mongo/db/repl/replica_set_aware_service', + '$BUILD_DIR/mongo/db/shard_role', + '$BUILD_DIR/mongo/s/grid', + ], +) + +env.Library( target='sharding_catalog', source=[ 'global_index_ddl_util.cpp', @@ -64,6 +74,22 @@ env.Library( ) env.Library( + target="query_analysis_writer", + source=[ + "query_analysis_writer.cpp", + ], + LIBDEPS_PRIVATE=[ + "$BUILD_DIR/mongo/db/dbdirectclient", + '$BUILD_DIR/mongo/db/ops/write_ops_parsers', + '$BUILD_DIR/mongo/db/server_base', + '$BUILD_DIR/mongo/db/service_context', + '$BUILD_DIR/mongo/db/shard_role', + '$BUILD_DIR/mongo/idl/idl_parser', + '$BUILD_DIR/mongo/s/analyze_shard_key_common', + ], +) + +env.Library( target='sharding_runtime_d', source=[ 'active_migrations_registry.cpp', @@ -228,6 +254,8 @@ env.Library( '$BUILD_DIR/mongo/db/transaction/transaction_operations', '$BUILD_DIR/mongo/s/common_s', '$BUILD_DIR/mongo/util/future_util', + 'balancer_stats_registry', + 'query_analysis_writer', 'sharding_catalog', 'sharding_logging', ], @@ -550,6 +578,7 @@ env.Library( '$BUILD_DIR/mongo/s/sharding_initialization', '$BUILD_DIR/mongo/s/sharding_router_api', 
'$BUILD_DIR/mongo/s/startup_initialization', + 'balancer_stats_registry', 'forwardable_operation_metadata', 'sharding_catalog', 'sharding_logging', @@ -653,6 +682,7 @@ env.CppUnitTest( 'op_observer_sharding_test.cpp', 'operation_sharding_state_test.cpp', 'persistent_task_queue_test.cpp', + 'query_analysis_writer_test.cpp', 'range_deleter_service_test.cpp', 'range_deleter_service_test_util.cpp', 'range_deleter_service_op_observer_test.cpp', @@ -734,6 +764,7 @@ env.CppUnitTest( '$BUILD_DIR/mongo/executor/thread_pool_task_executor_test_fixture', '$BUILD_DIR/mongo/s/catalog/sharding_catalog_client_mock', '$BUILD_DIR/mongo/s/sharding_router_test_fixture', + 'query_analysis_writer', 'shard_server_test_fixture', 'sharding_catalog', 'sharding_commands_d', diff --git a/src/mongo/db/s/balancer/balancer.cpp b/src/mongo/db/s/balancer/balancer.cpp index f1c96cc97cf..9103ec7f04c 100644 --- a/src/mongo/db/s/balancer/balancer.cpp +++ b/src/mongo/db/s/balancer/balancer.cpp @@ -281,7 +281,7 @@ Balancer::Balancer() std::make_unique<BalancerChunkSelectionPolicyImpl>(_clusterStats.get(), _random)), _commandScheduler(std::make_unique<BalancerCommandsSchedulerImpl>()), _defragmentationPolicy(std::make_unique<BalancerDefragmentationPolicyImpl>( - _clusterStats.get(), _random, [this]() { _onActionsStreamPolicyStateUpdate(); })), + _clusterStats.get(), [this]() { _onActionsStreamPolicyStateUpdate(); })), _clusterChunksResizePolicy(std::make_unique<ClusterChunksResizePolicyImpl>( [this] { _onActionsStreamPolicyStateUpdate(); })) {} @@ -1085,7 +1085,7 @@ int Balancer::_moveChunks(OperationContext* opCtx, opCtx, migrateInfo.uuid, repl::ReadConcernLevel::kMajorityReadConcern); ShardingCatalogManager::get(opCtx)->splitOrMarkJumbo( - opCtx, collection.getNss(), migrateInfo.minKey); + opCtx, collection.getNss(), migrateInfo.minKey, migrateInfo.getMaxChunkSizeBytes()); continue; } @@ -1151,7 +1151,12 @@ BalancerCollectionStatusResponse Balancer::getBalancerStatusForNs(OperationConte 
uasserted(ErrorCodes::NamespaceNotSharded, "Collection unsharded or undefined"); } - const auto maxChunkSizeMB = getMaxChunkSizeMB(opCtx, coll); + + const auto maxChunkSizeBytes = getMaxChunkSizeBytes(opCtx, coll); + double maxChunkSizeMB = (double)maxChunkSizeBytes / (1024 * 1024); + // Keep only 2 decimal digits to return a readable value + maxChunkSizeMB = std::ceil(maxChunkSizeMB * 100.0) / 100.0; + BalancerCollectionStatusResponse response(maxChunkSizeMB, true /*balancerCompliant*/); auto setViolationOnResponse = [&response](const StringData& reason, const boost::optional<BSONObj>& details = diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp index 1c5a88a6100..c8d51184bc3 100644 --- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp +++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp @@ -398,7 +398,8 @@ public: std::move(collectionChunks), std::move(shardInfos), std::move(collectionZones), - smallChunkSizeThresholdBytes)); + smallChunkSizeThresholdBytes, + maxChunkSizeBytes)); } DefragmentationPhaseEnum getType() const override { @@ -462,7 +463,7 @@ public: auto smallChunkVersion = getShardVersion(opCtx, nextSmallChunk->shard, _nss); _outstandingMigrations.emplace_back(nextSmallChunk, targetSibling); return _outstandingMigrations.back().asMigrateInfo( - _uuid, _nss, smallChunkVersion.placementVersion()); + _uuid, _nss, smallChunkVersion.placementVersion(), _maxChunkSizeBytes); } return boost::none; @@ -701,7 +702,8 @@ private: MigrateInfo asMigrateInfo(const UUID& collUuid, const NamespaceString& nss, - const ChunkVersion& version) const { + const ChunkVersion& version, + uint64_t maxChunkSizeBytes) const { return MigrateInfo(chunkToMergeWith->shard, chunkToMove->shard, nss, @@ -709,7 +711,8 @@ private: chunkToMove->range.getMin(), chunkToMove->range.getMax(), version, - ForceJumbo::kForceBalancer); + ForceJumbo::kForceBalancer, + 
maxChunkSizeBytes); } ChunkRange asMergedRange() const { @@ -774,6 +777,8 @@ private: const int64_t _smallChunkSizeThresholdBytes; + const uint64_t _maxChunkSizeBytes; + bool _aborted{false}; DefragmentationPhaseEnum _nextPhase{DefragmentationPhaseEnum::kMergeChunks}; @@ -783,7 +788,8 @@ private: std::vector<ChunkType>&& collectionChunks, stdx::unordered_map<ShardId, ShardInfo>&& shardInfos, ZoneInfo&& collectionZones, - uint64_t smallChunkSizeThresholdBytes) + uint64_t smallChunkSizeThresholdBytes, + uint64_t maxChunkSizeBytes) : _nss(nss), _uuid(uuid), _collectionChunks(), @@ -794,7 +800,8 @@ private: _actionableMerges(), _outstandingMerges(), _zoneInfo(std::move(collectionZones)), - _smallChunkSizeThresholdBytes(smallChunkSizeThresholdBytes) { + _smallChunkSizeThresholdBytes(smallChunkSizeThresholdBytes), + _maxChunkSizeBytes(maxChunkSizeBytes) { // Load the collection routing table in a std::list to ease later manipulation for (auto&& chunk : collectionChunks) { diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h index bc41346ca7f..bab4a6c58cf 100644 --- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h +++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h @@ -72,9 +72,10 @@ class BalancerDefragmentationPolicyImpl : public BalancerDefragmentationPolicy { public: BalancerDefragmentationPolicyImpl(ClusterStatistics* clusterStats, - BalancerRandomSource& random, const std::function<void()>& onStateUpdated) - : _clusterStats(clusterStats), _random(random), _onStateUpdated(onStateUpdated) {} + : _clusterStats(clusterStats), + _random(std::random_device{}()), + _onStateUpdated(onStateUpdated) {} ~BalancerDefragmentationPolicyImpl() {} @@ -145,7 +146,7 @@ private: ClusterStatistics* const _clusterStats; - BalancerRandomSource& _random; + BalancerRandomSource _random; const std::function<void()> _onStateUpdated; diff --git 
a/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp b/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp index 8646b5d965a..e9370bbb4d3 100644 --- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp +++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp @@ -29,7 +29,6 @@ #include "mongo/db/dbdirectclient.h" #include "mongo/db/s/balancer/balancer_defragmentation_policy_impl.h" -#include "mongo/db/s/balancer/balancer_random.h" #include "mongo/db/s/balancer/cluster_statistics_mock.h" #include "mongo/db/s/config/config_server_test_fixture.h" #include "mongo/idl/server_parameter_test_util.h" @@ -75,9 +74,7 @@ protected: const std::function<void()> onDefragmentationStateUpdated = [] {}; BalancerDefragmentationPolicyTest() - : _clusterStats(), - _random(std::random_device{}()), - _defragmentationPolicy(&_clusterStats, _random, onDefragmentationStateUpdated) {} + : _clusterStats(), _defragmentationPolicy(&_clusterStats, onDefragmentationStateUpdated) {} CollectionType setupCollectionWithPhase( const std::vector<ChunkType>& chunkList, @@ -137,7 +134,6 @@ protected: } ClusterStatisticsMock _clusterStats; - BalancerRandomSource _random; BalancerDefragmentationPolicyImpl _defragmentationPolicy; ShardStatistics buildShardStats(ShardId id, diff --git a/src/mongo/db/s/balancer/balancer_policy.cpp b/src/mongo/db/s/balancer/balancer_policy.cpp index 4e0ff4c3ccb..a8ad96ab2c7 100644 --- a/src/mongo/db/s/balancer/balancer_policy.cpp +++ b/src/mongo/db/s/balancer/balancer_policy.cpp @@ -36,6 +36,7 @@ #include "mongo/db/s/balancer/type_migration.h" #include "mongo/logv2/log.h" +#include "mongo/s/balancer_configuration.h" #include "mongo/s/catalog/type_shard.h" #include "mongo/s/catalog/type_tags.h" #include "mongo/s/grid.h" @@ -450,7 +451,16 @@ MigrateInfosWithReason BalancerPolicy::balance( } invariant(to != stat.shardId); - migrations.emplace_back(to, distribution.nss(), chunk, ForceJumbo::kForceBalancer); + + auto 
maxChunkSizeBytes = [&]() -> boost::optional<int64_t> { + if (collDataSizeInfo.has_value()) { + return collDataSizeInfo->maxChunkSizeBytes; + } + return boost::none; + }(); + + migrations.emplace_back( + to, distribution.nss(), chunk, ForceJumbo::kForceBalancer, maxChunkSizeBytes); if (firstReason == MigrationReason::none) { firstReason = MigrationReason::drain; } @@ -513,11 +523,20 @@ MigrateInfosWithReason BalancerPolicy::balance( } invariant(to != stat.shardId); + + auto maxChunkSizeBytes = [&]() -> boost::optional<int64_t> { + if (collDataSizeInfo.has_value()) { + return collDataSizeInfo->maxChunkSizeBytes; + } + return boost::none; + }(); + migrations.emplace_back(to, distribution.nss(), chunk, forceJumbo ? ForceJumbo::kForceBalancer - : ForceJumbo::kDoNotForce); + : ForceJumbo::kDoNotForce, + maxChunkSizeBytes); if (firstReason == MigrationReason::none) { firstReason = MigrationReason::zoneViolation; } @@ -796,7 +815,8 @@ string ZoneRange::toString() const { MigrateInfo::MigrateInfo(const ShardId& a_to, const NamespaceString& a_nss, const ChunkType& a_chunk, - const ForceJumbo a_forceJumbo) + const ForceJumbo a_forceJumbo, + boost::optional<int64_t> maxChunkSizeBytes) : nss(a_nss), uuid(a_chunk.getCollectionUUID()) { invariant(a_to.isValid()); @@ -807,6 +827,7 @@ MigrateInfo::MigrateInfo(const ShardId& a_to, maxKey = a_chunk.getMax(); version = a_chunk.getVersion(); forceJumbo = a_forceJumbo; + optMaxChunkSizeBytes = maxChunkSizeBytes; } MigrateInfo::MigrateInfo(const ShardId& a_to, @@ -858,6 +879,10 @@ string MigrateInfo::toString() const { << ", to " << to; } +boost::optional<int64_t> MigrateInfo::getMaxChunkSizeBytes() const { + return optMaxChunkSizeBytes; +} + SplitInfo::SplitInfo(const ShardId& inShardId, const NamespaceString& inNss, const ChunkVersion& inCollectionVersion, diff --git a/src/mongo/db/s/balancer/balancer_policy.h b/src/mongo/db/s/balancer/balancer_policy.h index bd464ca9499..a9a1f307310 100644 --- 
a/src/mongo/db/s/balancer/balancer_policy.h +++ b/src/mongo/db/s/balancer/balancer_policy.h @@ -60,7 +60,8 @@ struct MigrateInfo { MigrateInfo(const ShardId& a_to, const NamespaceString& a_nss, const ChunkType& a_chunk, - ForceJumbo a_forceJumbo); + ForceJumbo a_forceJumbo, + boost::optional<int64_t> maxChunkSizeBytes = boost::none); MigrateInfo(const ShardId& a_to, const ShardId& a_from, @@ -78,6 +79,8 @@ struct MigrateInfo { std::string toString() const; + boost::optional<int64_t> getMaxChunkSizeBytes() const; + NamespaceString nss; UUID uuid; ShardId to; diff --git a/src/mongo/db/s/balancer_stats_registry.cpp b/src/mongo/db/s/balancer_stats_registry.cpp index bfd789acb52..0f11a578c23 100644 --- a/src/mongo/db/s/balancer_stats_registry.cpp +++ b/src/mongo/db/s/balancer_stats_registry.cpp @@ -29,6 +29,7 @@ #include "mongo/db/s/balancer_stats_registry.h" +#include "mongo/db/catalog_raii.h" #include "mongo/db/dbdirectclient.h" #include "mongo/db/pipeline/aggregate_command_gen.h" #include "mongo/db/repl/replication_coordinator.h" @@ -58,23 +59,6 @@ ThreadPool::Options makeDefaultThreadPoolOptions() { } } // namespace -ScopedRangeDeleterLock::ScopedRangeDeleterLock(OperationContext* opCtx) - // TODO SERVER-62491 Use system tenantId for DBLock - : _configLock(opCtx, DatabaseName(boost::none, NamespaceString::kConfigDb), MODE_IX), - _rangeDeletionLock(opCtx, NamespaceString::kRangeDeletionNamespace, MODE_X) {} - -// Take DB and Collection lock in mode IX as well as collection UUID lock to serialize with -// operations that take the above version of the ScopedRangeDeleterLock such as FCV downgrade and -// BalancerStatsRegistry initialization. 
-ScopedRangeDeleterLock::ScopedRangeDeleterLock(OperationContext* opCtx, const UUID& collectionUuid) - // TODO SERVER-62491 Use system tenantId for DBLock - : _configLock(opCtx, DatabaseName(boost::none, NamespaceString::kConfigDb), MODE_IX), - _rangeDeletionLock(opCtx, NamespaceString::kRangeDeletionNamespace, MODE_IX), - _collectionUuidLock(Lock::ResourceLock( - opCtx, - ResourceId(RESOURCE_MUTEX, "RangeDeleterCollLock::" + collectionUuid.toString()), - MODE_X)) {} - const ReplicaSetAwareServiceRegistry::Registerer<BalancerStatsRegistry> balancerStatsRegistryRegisterer("BalancerStatsRegistry"); @@ -131,9 +115,13 @@ void BalancerStatsRegistry::initializeAsync(OperationContext* opCtx) { LOGV2_DEBUG(6419601, 2, "Initializing BalancerStatsRegistry"); try { - // Lock the range deleter to prevent - // concurrent modifications of orphans count - ScopedRangeDeleterLock rangeDeleterLock(opCtx); + // Lock the range deleter to prevent concurrent modifications of orphans count + ScopedRangeDeleterLock rangeDeleterLock(opCtx, LockMode::MODE_S); + // The collection lock is needed to serialize with direct writes to + // config.rangeDeletions + AutoGetCollection rangeDeletionLock( + opCtx, NamespaceString::kRangeDeletionNamespace, MODE_S); + // Load current ophans count from disk _loadOrphansCount(opCtx); LOGV2_DEBUG(6419602, 2, "Completed BalancerStatsRegistry initialization"); diff --git a/src/mongo/db/s/balancer_stats_registry.h b/src/mongo/db/s/balancer_stats_registry.h index 473285b627c..fdc68ab1021 100644 --- a/src/mongo/db/s/balancer_stats_registry.h +++ b/src/mongo/db/s/balancer_stats_registry.h @@ -38,18 +38,19 @@ namespace mongo { /** - * Acquires the config db lock in IX mode and the collection lock for config.rangeDeletions in X - * mode. + * Scoped lock to synchronize with the execution of range deletions. + * The range-deleter acquires a scoped lock in IX mode while orphans are being deleted. 
+ * Acquiring the scoped lock in MODE_X ensures that no orphan counter in `config.rangeDeletions` + * entries is going to be updated concurrently. */ class ScopedRangeDeleterLock { public: - ScopedRangeDeleterLock(OperationContext* opCtx); - ScopedRangeDeleterLock(OperationContext* opCtx, const UUID& collectionUuid); + ScopedRangeDeleterLock(OperationContext* opCtx, LockMode mode) + : _resourceLock(opCtx, _mutex.getRid(), mode) {} private: - Lock::DBLock _configLock; - Lock::CollectionLock _rangeDeletionLock; - boost::optional<Lock::ResourceLock> _collectionUuidLock; + const Lock::ResourceLock _resourceLock; + static inline const Lock::ResourceMutex _mutex{"ScopedRangeDeleterLock"}; }; /** diff --git a/src/mongo/db/s/cluster_count_cmd_d.cpp b/src/mongo/db/s/cluster_count_cmd_d.cpp index 27c64148cc9..a593a299b4f 100644 --- a/src/mongo/db/s/cluster_count_cmd_d.cpp +++ b/src/mongo/db/s/cluster_count_cmd_d.cpp @@ -62,6 +62,11 @@ struct ClusterCountCmdD { // which triggers an invariant, so only shard servers can run this. uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands()); } + + static void checkCanExplainHere(OperationContext* opCtx) { + uasserted(ErrorCodes::CommandNotSupported, + "Cannot explain a cluster count command on a mongod"); + } }; ClusterCountCmdBase<ClusterCountCmdD> clusterCountCmdD; diff --git a/src/mongo/db/s/cluster_find_cmd_d.cpp b/src/mongo/db/s/cluster_find_cmd_d.cpp index e97f07b70e7..e98315c06ad 100644 --- a/src/mongo/db/s/cluster_find_cmd_d.cpp +++ b/src/mongo/db/s/cluster_find_cmd_d.cpp @@ -61,6 +61,11 @@ struct ClusterFindCmdD { // which triggers an invariant, so only shard servers can run this. 
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands()); } + + static void checkCanExplainHere(OperationContext* opCtx) { + uasserted(ErrorCodes::CommandNotSupported, + "Cannot explain a cluster find command on a mongod"); + } }; ClusterFindCmdBase<ClusterFindCmdD> clusterFindCmdD; diff --git a/src/mongo/db/s/cluster_pipeline_cmd_d.cpp b/src/mongo/db/s/cluster_pipeline_cmd_d.cpp index 437e7418355..1b3eed4242a 100644 --- a/src/mongo/db/s/cluster_pipeline_cmd_d.cpp +++ b/src/mongo/db/s/cluster_pipeline_cmd_d.cpp @@ -61,6 +61,11 @@ struct ClusterPipelineCommandD { uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands()); } + static void checkCanExplainHere(OperationContext* opCtx) { + uasserted(ErrorCodes::CommandNotSupported, + "Cannot explain a cluster aggregate command on a mongod"); + } + static AggregateCommandRequest parseAggregationRequest( OperationContext* opCtx, const OpMsgRequest& opMsgRequest, diff --git a/src/mongo/db/s/cluster_write_cmd_d.cpp b/src/mongo/db/s/cluster_write_cmd_d.cpp index c8f75d69e15..66167ab2b30 100644 --- a/src/mongo/db/s/cluster_write_cmd_d.cpp +++ b/src/mongo/db/s/cluster_write_cmd_d.cpp @@ -57,6 +57,11 @@ struct ClusterInsertCmdD { // which triggers an invariant, so only shard servers can run this. uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands()); } + + static void checkCanExplainHere(OperationContext* opCtx) { + uasserted(ErrorCodes::CommandNotSupported, + "Cannot explain a cluster insert command on a mongod"); + } }; ClusterInsertCmdBase<ClusterInsertCmdD> clusterInsertCmdD; @@ -83,6 +88,10 @@ struct ClusterUpdateCmdD { // which triggers an invariant, so only shard servers can run this. 
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands()); } + + static void checkCanExplainHere(OperationContext* opCtx) { + uasserted(ErrorCodes::CommandNotSupported, "Explain on a clusterDelete is not supported"); + } }; ClusterUpdateCmdBase<ClusterUpdateCmdD> clusterUpdateCmdD; @@ -109,6 +118,11 @@ struct ClusterDeleteCmdD { // which triggers an invariant, so only shard servers can run this. uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands()); } + + static void checkCanExplainHere(OperationContext* opCtx) { + uasserted(ErrorCodes::CommandNotSupported, + "Cannot explain a cluster delete command on a mongod"); + } }; ClusterDeleteCmdBase<ClusterDeleteCmdD> clusterDeleteCmdD; diff --git a/src/mongo/db/s/collection_sharding_runtime.cpp b/src/mongo/db/s/collection_sharding_runtime.cpp index 89ceb0cfd4c..ee530f6ea04 100644 --- a/src/mongo/db/s/collection_sharding_runtime.cpp +++ b/src/mongo/db/s/collection_sharding_runtime.cpp @@ -486,8 +486,10 @@ void CollectionShardingRuntime::resetShardVersionRecoverRefreshFuture() { _shardVersionInRecoverOrRefresh = boost::none; } -boost::optional<Timestamp> CollectionShardingRuntime::getIndexVersion(OperationContext* opCtx) { - return _globalIndexesInfo ? _globalIndexesInfo->getVersion() : boost::none; +boost::optional<CollectionIndexes> CollectionShardingRuntime::getCollectionIndexes( + OperationContext* opCtx) { + return _globalIndexesInfo ? 
boost::make_optional(_globalIndexesInfo->getCollectionIndexes()) + : boost::none; } boost::optional<GlobalIndexesCache>& CollectionShardingRuntime::getIndexes( @@ -497,22 +499,22 @@ boost::optional<GlobalIndexesCache>& CollectionShardingRuntime::getIndexes( void CollectionShardingRuntime::addIndex(OperationContext* opCtx, const IndexCatalogType& index, - const Timestamp& indexVersion) { + const CollectionIndexes& collectionIndexes) { if (_globalIndexesInfo) { - _globalIndexesInfo->add(index, indexVersion); + _globalIndexesInfo->add(index, collectionIndexes); } else { IndexCatalogTypeMap indexMap; indexMap.emplace(index.getName(), index); - _globalIndexesInfo.emplace(indexVersion, std::move(indexMap)); + _globalIndexesInfo.emplace(collectionIndexes, std::move(indexMap)); } } void CollectionShardingRuntime::removeIndex(OperationContext* opCtx, const std::string& name, - const Timestamp& indexVersion) { + const CollectionIndexes& collectionIndexes) { tassert( 7019500, "Index information does not exist on CSR", _globalIndexesInfo.is_initialized()); - _globalIndexesInfo->remove(name, indexVersion); + _globalIndexesInfo->remove(name, collectionIndexes); } void CollectionShardingRuntime::clearIndexes(OperationContext* opCtx) { diff --git a/src/mongo/db/s/collection_sharding_runtime.h b/src/mongo/db/s/collection_sharding_runtime.h index 44e46d0bc4a..786425d8e92 100644 --- a/src/mongo/db/s/collection_sharding_runtime.h +++ b/src/mongo/db/s/collection_sharding_runtime.h @@ -229,7 +229,7 @@ public: /** * Gets an index version under a lock. */ - boost::optional<Timestamp> getIndexVersion(OperationContext* opCtx); + boost::optional<CollectionIndexes> getCollectionIndexes(OperationContext* opCtx); /** * Gets the index list under a lock. 
@@ -241,17 +241,17 @@ public: */ void addIndex(OperationContext* opCtx, const IndexCatalogType& index, - const Timestamp& indexVersion); + const CollectionIndexes& collectionIndexes); /** * Removes an index from the shard-role index info under a lock. */ void removeIndex(OperationContext* opCtx, const std::string& name, - const Timestamp& indexVersion); + const CollectionIndexes& collectionIndexes); /** - * Clears the shard-role index info and set the indexVersion to boost::none. + * Clears the shard-role index info and set the collectionIndexes to boost::none. */ void clearIndexes(OperationContext* opCtx); diff --git a/src/mongo/db/s/collection_sharding_runtime_test.cpp b/src/mongo/db/s/collection_sharding_runtime_test.cpp index 6e792a8db0e..334225663e3 100644 --- a/src/mongo/db/s/collection_sharding_runtime_test.cpp +++ b/src/mongo/db/s/collection_sharding_runtime_test.cpp @@ -377,8 +377,10 @@ public: AutoGetCollection autoColl(operationContext(), kTestNss, MODE_IX); _uuid = autoColl.getCollection()->uuid(); - RangeDeleterService::get(operationContext())->onStepUpComplete(operationContext(), 0L); - RangeDeleterService::get(operationContext())->_waitForRangeDeleterServiceUp_FOR_TESTING(); + auto opCtx = operationContext(); + RangeDeleterService::get(opCtx)->onStartup(opCtx); + RangeDeleterService::get(opCtx)->onStepUpComplete(opCtx, 0L); + RangeDeleterService::get(opCtx)->_waitForRangeDeleterServiceUp_FOR_TESTING(); } void tearDown() override { @@ -611,8 +613,7 @@ TEST_F(CollectionShardingRuntimeWithCatalogTest, TestGlobalIndexesCache) { opCtx, kTestNss, "x_1", BSON("x" << 1), BSONObj(), uuid(), indexVersion, boost::none); ASSERT_EQ(true, csr()->getIndexes(opCtx).is_initialized()); - ASSERT_EQ(indexVersion, *csr()->getIndexes(opCtx)->getVersion()); - ASSERT_EQ(indexVersion, *csr()->getIndexVersion(opCtx)); + ASSERT_EQ(CollectionIndexes(uuid(), indexVersion), *csr()->getCollectionIndexes(opCtx)); } } // namespace } // namespace mongo diff --git 
a/src/mongo/db/s/config/sharding_catalog_manager.h b/src/mongo/db/s/config/sharding_catalog_manager.h index 4e8a64f59a6..fba732dce7a 100644 --- a/src/mongo/db/s/config/sharding_catalog_manager.h +++ b/src/mongo/db/s/config/sharding_catalog_manager.h @@ -359,7 +359,8 @@ public: */ void splitOrMarkJumbo(OperationContext* opCtx, const NamespaceString& nss, - const BSONObj& minKey); + const BSONObj& minKey, + boost::optional<int64_t> optMaxChunkSizeBytes); /** * In a transaction, sets the 'allowMigrations' to the requested state and bumps the collection diff --git a/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp b/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp index 2b0692acb5d..93ac3b56099 100644 --- a/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp +++ b/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp @@ -1775,19 +1775,31 @@ void ShardingCatalogManager::bumpMultipleCollectionVersionsAndChangeMetadataInTx void ShardingCatalogManager::splitOrMarkJumbo(OperationContext* opCtx, const NamespaceString& nss, - const BSONObj& minKey) { + const BSONObj& minKey, + boost::optional<int64_t> optMaxChunkSizeBytes) { const auto cm = uassertStatusOK( Grid::get(opCtx)->catalogCache()->getShardedCollectionRoutingInfoWithRefresh(opCtx, nss)); auto chunk = cm.findIntersectingChunkWithSimpleCollation(minKey); try { - const auto splitPoints = uassertStatusOK(shardutil::selectChunkSplitPoints( - opCtx, - chunk.getShardId(), - nss, - cm.getShardKeyPattern(), - ChunkRange(chunk.getMin(), chunk.getMax()), - Grid::get(opCtx)->getBalancerConfiguration()->getMaxChunkSizeBytes())); + const auto maxChunkSizeBytes = [&]() -> int64_t { + if (optMaxChunkSizeBytes.has_value()) { + return *optMaxChunkSizeBytes; + } + + auto coll = Grid::get(opCtx)->catalogClient()->getCollection( + opCtx, nss, repl::ReadConcernLevel::kMajorityReadConcern); + return coll.getMaxChunkSizeBytes().value_or( + 
Grid::get(opCtx)->getBalancerConfiguration()->getMaxChunkSizeBytes()); + }(); + + const auto splitPoints = uassertStatusOK( + shardutil::selectChunkSplitPoints(opCtx, + chunk.getShardId(), + nss, + cm.getShardKeyPattern(), + ChunkRange(chunk.getMin(), chunk.getMax()), + maxChunkSizeBytes)); if (splitPoints.empty()) { LOGV2(21873, diff --git a/src/mongo/db/s/query_analysis_op_observer.cpp b/src/mongo/db/s/query_analysis_op_observer.cpp index 84606c7f4c7..658c50c992d 100644 --- a/src/mongo/db/s/query_analysis_op_observer.cpp +++ b/src/mongo/db/s/query_analysis_op_observer.cpp @@ -31,6 +31,7 @@ #include "mongo/db/s/query_analysis_coordinator.h" #include "mongo/db/s/query_analysis_op_observer.h" +#include "mongo/db/s/query_analysis_writer.h" #include "mongo/logv2/log.h" #include "mongo/s/analyze_shard_key_util.h" #include "mongo/s/catalog/type_mongos.h" @@ -88,6 +89,17 @@ void QueryAnalysisOpObserver::onUpdate(OperationContext* opCtx, const OplogUpdat }); } } + + if (analyze_shard_key::supportsPersistingSampledQueries() && args.updateArgs->sampleId && + args.updateArgs->preImageDoc && opCtx->writesAreReplicated()) { + analyze_shard_key::QueryAnalysisWriter::get(opCtx) + .addDiff(*args.updateArgs->sampleId, + args.coll->ns(), + args.coll->uuid(), + *args.updateArgs->preImageDoc, + args.updateArgs->updatedDoc) + .getAsync([](auto) {}); + } } void QueryAnalysisOpObserver::aboutToDelete(OperationContext* opCtx, diff --git a/src/mongo/db/s/query_analysis_writer.cpp b/src/mongo/db/s/query_analysis_writer.cpp new file mode 100644 index 00000000000..19b942e1775 --- /dev/null +++ b/src/mongo/db/s/query_analysis_writer.cpp @@ -0,0 +1,695 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ +#include "mongo/platform/basic.h" + +#include "mongo/db/s/query_analysis_writer.h" + +#include "mongo/client/connpool.h" +#include "mongo/db/catalog/collection_catalog.h" +#include "mongo/db/dbdirectclient.h" +#include "mongo/db/ops/write_ops.h" +#include "mongo/db/server_options.h" +#include "mongo/db/update/document_diff_calculator.h" +#include "mongo/executor/network_interface_factory.h" +#include "mongo/executor/thread_pool_task_executor.h" +#include "mongo/logv2/log.h" +#include "mongo/s/analyze_shard_key_documents_gen.h" +#include "mongo/s/analyze_shard_key_server_parameters_gen.h" +#include "mongo/s/write_ops/batched_command_response.h" +#include "mongo/util/concurrency/thread_pool.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kDefault + +namespace mongo { +namespace analyze_shard_key { + +namespace { + +MONGO_FAIL_POINT_DEFINE(disableQueryAnalysisWriter); +MONGO_FAIL_POINT_DEFINE(hangQueryAnalysisWriterBeforeWritingLocally); +MONGO_FAIL_POINT_DEFINE(hangQueryAnalysisWriterBeforeWritingRemotely); + +const auto getQueryAnalysisWriter = ServiceContext::declareDecoration<QueryAnalysisWriter>(); + +constexpr int kMaxRetriesOnRetryableErrors = 5; +const WriteConcernOptions kMajorityWriteConcern{WriteConcernOptions::kMajority, + WriteConcernOptions::SyncMode::UNSET, + WriteConcernOptions::kWriteConcernTimeoutSystem}; + +// The size limit for the documents to an insert in a single batch. Leave some padding for other +// fields in the insert command. +constexpr int kMaxBSONObjSizeForInsert = BSONObjMaxUserSize - 500 * 1024; + +/* + * Returns true if this mongod can accept writes to the given collection. Unless the collection is + * in the "local" database, this will only return true if this mongod is a primary (or a + * standalone). 
+ */ +bool canAcceptWrites(OperationContext* opCtx, const NamespaceString& ns) { + ShouldNotConflictWithSecondaryBatchApplicationBlock noPBWMBlock(opCtx->lockState()); + Lock::DBLock lk(opCtx, ns.dbName(), MODE_IS); + Lock::CollectionLock lock(opCtx, ns, MODE_IS); + return mongo::repl::ReplicationCoordinator::get(opCtx)->canAcceptWritesForDatabase(opCtx, + ns.db()); +} + +/* + * Runs the given write command against the given database locally, asserts that the top-level + * command is OK, then asserts the write status using the given 'uassertWriteStatusCb' callback. + * Returns the command response. + */ +BSONObj executeWriteCommandLocal(OperationContext* opCtx, + const std::string dbName, + const BSONObj& cmdObj, + const std::function<void(const BSONObj&)>& uassertWriteStatusCb) { + DBDirectClient client(opCtx); + BSONObj resObj; + + if (!client.runCommand(dbName, cmdObj, resObj)) { + uassertStatusOK(getStatusFromCommandResult(resObj)); + } + uassertWriteStatusCb(resObj); + + return resObj; +} + +/* + * Runs the given write command against the given database on the (remote) primary, asserts that the + * top-level command is OK, then asserts the write status using the given 'uassertWriteStatusCb' + * callback. Throws a PrimarySteppedDown error if no primary is found. Returns the command response. 
+ */ +BSONObj executeWriteCommandRemote(OperationContext* opCtx, + const std::string dbName, + const BSONObj& cmdObj, + const std::function<void(const BSONObj&)>& uassertWriteStatusCb) { + auto hostAndPort = repl::ReplicationCoordinator::get(opCtx)->getCurrentPrimaryHostAndPort(); + + if (hostAndPort.empty()) { + uasserted(ErrorCodes::PrimarySteppedDown, "No primary exists currently"); + } + + auto conn = std::make_unique<ScopedDbConnection>(hostAndPort.toString()); + + if (auth::isInternalAuthSet()) { + uassertStatusOK(conn->get()->authenticateInternalUser()); + } + + DBClientBase* client = conn->get(); + ScopeGuard guard([&] { conn->done(); }); + try { + BSONObj resObj; + + if (!client->runCommand(dbName, cmdObj, resObj)) { + uassertStatusOK(getStatusFromCommandResult(resObj)); + } + uassertWriteStatusCb(resObj); + + return resObj; + } catch (...) { + guard.dismiss(); + conn->kill(); + throw; + } +} + +/* + * Runs the given write command against the given collection. If this mongod is currently the + * primary, runs the write command locally. Otherwise, runs the command on the remote primary. + * Internally asserts that the top-level command is OK, then asserts the write status using the + * given 'uassertWriteStatusCb' callback. Internally retries the write command on retryable errors + * (for kMaxRetriesOnRetryableErrors times) so the writes must be idempotent. Returns the + * command response. + */ +BSONObj executeWriteCommand(OperationContext* opCtx, + const NamespaceString& ns, + const BSONObj& cmdObj, + const std::function<void(const BSONObj&)>& uassertWriteStatusCb) { + const auto dbName = ns.db().toString(); + auto numRetries = 0; + + while (true) { + try { + if (canAcceptWrites(opCtx, ns)) { + // There is a window here where this mongod may step down after check above. In this + // case, a NotWritablePrimary error would be thrown. However, this is preferable to + // running the command while holding locks. 
+ hangQueryAnalysisWriterBeforeWritingLocally.pauseWhileSet(opCtx); + return executeWriteCommandLocal(opCtx, dbName, cmdObj, uassertWriteStatusCb); + } + + hangQueryAnalysisWriterBeforeWritingRemotely.pauseWhileSet(opCtx); + return executeWriteCommandRemote(opCtx, dbName, cmdObj, uassertWriteStatusCb); + } catch (DBException& ex) { + if (ErrorCodes::isRetriableError(ex) && numRetries < kMaxRetriesOnRetryableErrors) { + numRetries++; + continue; + } + throw; + } + } + + return {}; +} + +struct SampledWriteCommandRequest { + UUID sampleId; + NamespaceString nss; + BSONObj cmd; // the BSON for a {Update,Delete,FindAndModify}CommandRequest +}; + +/* + * Returns a sampled update command for the update at 'opIndex' in the given update command. + */ +SampledWriteCommandRequest makeSampledUpdateCommandRequest( + const write_ops::UpdateCommandRequest& originalCmd, int opIndex) { + auto op = originalCmd.getUpdates()[opIndex]; + auto sampleId = op.getSampleId(); + invariant(sampleId); + + write_ops::UpdateCommandRequest sampledCmd(originalCmd.getNamespace(), {std::move(op)}); + sampledCmd.setLet(originalCmd.getLet()); + + return {*sampleId, + sampledCmd.getNamespace(), + sampledCmd.toBSON(BSON("$db" << sampledCmd.getNamespace().db().toString()))}; +} + +/* + * Returns a sampled delete command for the delete at 'opIndex' in the given delete command. + */ +SampledWriteCommandRequest makeSampledDeleteCommandRequest( + const write_ops::DeleteCommandRequest& originalCmd, int opIndex) { + auto op = originalCmd.getDeletes()[opIndex]; + auto sampleId = op.getSampleId(); + invariant(sampleId); + + write_ops::DeleteCommandRequest sampledCmd(originalCmd.getNamespace(), {std::move(op)}); + sampledCmd.setLet(originalCmd.getLet()); + + return {*sampleId, + sampledCmd.getNamespace(), + sampledCmd.toBSON(BSON("$db" << sampledCmd.getNamespace().db().toString()))}; +} + +/* + * Returns a sampled findAndModify command for the given findAndModify command. 
+ */
+SampledWriteCommandRequest makeSampledFindAndModifyCommandRequest(
+    const write_ops::FindAndModifyCommandRequest& originalCmd) {
+    invariant(originalCmd.getSampleId());
+
+    // Only the fields copied below are carried over into the sampled command.
+    write_ops::FindAndModifyCommandRequest sampledCmd(originalCmd.getNamespace());
+    sampledCmd.setQuery(originalCmd.getQuery());
+    sampledCmd.setUpdate(originalCmd.getUpdate());
+    sampledCmd.setRemove(originalCmd.getRemove());
+    sampledCmd.setUpsert(originalCmd.getUpsert());
+    sampledCmd.setNew(originalCmd.getNew());
+    sampledCmd.setSort(originalCmd.getSort());
+    sampledCmd.setCollation(originalCmd.getCollation());
+    sampledCmd.setArrayFilters(originalCmd.getArrayFilters());
+    sampledCmd.setLet(originalCmd.getLet());
+    sampledCmd.setSampleId(originalCmd.getSampleId());
+
+    return {*sampledCmd.getSampleId(),
+            sampledCmd.getNamespace(),
+            sampledCmd.toBSON(BSON("$db" << sampledCmd.getNamespace().db().toString()))};
+}
+
+}  // namespace
+
+QueryAnalysisWriter& QueryAnalysisWriter::get(OperationContext* opCtx) {
+    return get(opCtx->getServiceContext());
+}
+
+/*
+ * Returns the writer decorating the given service context. Invariants that the feature flag is
+ * enabled and that this node is a shard server.
+ */
+QueryAnalysisWriter& QueryAnalysisWriter::get(ServiceContext* serviceContext) {
+    invariant(analyze_shard_key::isFeatureFlagEnabledIgnoreFCV(),
+              "Only support analyzing queries when the feature flag is enabled");
+    invariant(serverGlobalParams.clusterRole == ClusterRole::ShardServer,
+              "Only support analyzing queries on a sharded cluster");
+    return getQueryAnalysisWriter(serviceContext);
+}
+
+/*
+ * Starts the two periodic flush jobs (one for sampled queries, one for diffs) and the task
+ * executor on which the add*() methods schedule their work.
+ */
+void QueryAnalysisWriter::onStartup() {
+    auto serviceContext = getQueryAnalysisWriter.owner(this);
+    auto periodicRunner = serviceContext->getPeriodicRunner();
+    invariant(periodicRunner);
+
+    stdx::lock_guard<Latch> lk(_mutex);
+
+    PeriodicRunner::PeriodicJob queryWriterJob(
+        "QueryAnalysisQueryWriter",
+        [this](Client* client) {
+            // The fail point turns the periodic flush into a no-op.
+            if (MONGO_unlikely(disableQueryAnalysisWriter.shouldFail())) {
+                return;
+            }
+            auto opCtx = client->makeOperationContext();
+            _flushQueries(opCtx.get());
+        },
+        Seconds(gQueryAnalysisWriterIntervalSecs));
+    _periodicQueryWriter = periodicRunner->makeJob(std::move(queryWriterJob));
+    _periodicQueryWriter.start();
+
+    PeriodicRunner::PeriodicJob diffWriterJob(
+        "QueryAnalysisDiffWriter",
+        [this](Client* client) {
+            if (MONGO_unlikely(disableQueryAnalysisWriter.shouldFail())) {
+                return;
+            }
+            auto opCtx = client->makeOperationContext();
+            _flushDiffs(opCtx.get());
+        },
+        Seconds(gQueryAnalysisWriterIntervalSecs));
+    _periodicDiffWriter = periodicRunner->makeJob(std::move(diffWriterJob));
+    _periodicDiffWriter.start();
+
+    ThreadPool::Options threadPoolOptions;
+    threadPoolOptions.maxThreads = gQueryAnalysisWriterMaxThreadPoolSize;
+    threadPoolOptions.minThreads = gQueryAnalysisWriterMinThreadPoolSize;
+    threadPoolOptions.threadNamePrefix = "QueryAnalysisWriter-";
+    threadPoolOptions.poolName = "QueryAnalysisWriterThreadPool";
+    threadPoolOptions.onCreateThread = [](const std::string& threadName) {
+        Client::initThread(threadName.c_str());
+    };
+    _executor = std::make_shared<executor::ThreadPoolTaskExecutor>(
+        std::make_unique<ThreadPool>(threadPoolOptions),
+        executor::makeNetworkInterface("QueryAnalysisWriterNetwork"));
+    _executor->startup();
+}
+
+/*
+ * Shuts down and joins the executor (if it was started), then stops whichever periodic jobs were
+ * created.
+ */
+void QueryAnalysisWriter::onShutdown() {
+    if (_executor) {
+        _executor->shutdown();
+        _executor->join();
+    }
+    if (_periodicQueryWriter.isValid()) {
+        _periodicQueryWriter.stop();
+    }
+    if (_periodicDiffWriter.isValid()) {
+        _periodicDiffWriter.stop();
+    }
+}
+
+/*
+ * Flushes the query buffer. Failures are logged and swallowed; the documents that could not be
+ * inserted are put back into the buffer by '_flush' for the next attempt.
+ */
+void QueryAnalysisWriter::_flushQueries(OperationContext* opCtx) {
+    try {
+        _flush(opCtx, NamespaceString::kConfigSampledQueriesNamespace, &_queries);
+    } catch (DBException& ex) {
+        LOGV2(7047300,
+              "Failed to flush queries, will try again at the next interval",
+              "error"_attr = redact(ex));
+    }
+}
+
+/*
+ * Flushes the diff buffer. Failures are logged and swallowed; the documents that could not be
+ * inserted are put back into the buffer by '_flush' for the next attempt.
+ */
+void QueryAnalysisWriter::_flushDiffs(OperationContext* opCtx) {
+    try {
+        _flush(opCtx, NamespaceString::kConfigSampledQueriesDiffNamespace, &_diffs);
+    } catch (DBException& ex) {
+        LOGV2(7075400,
+              "Failed to flush diffs, will try again at the next interval",
+              "error"_attr = redact(ex));
+    }
+}
+
+/*
+ * Inserts the documents in 'buffer' into the collection 'ns' in batches. The contents of 'buffer'
+ * are first swapped into a local buffer under the mutex so the inserts run without holding it. If
+ * an insert throws, the documents that have not been inserted yet, minus the ones the server
+ * rejected with DuplicateKey or BadValue, are added back to 'buffer' before the exception
+ * propagates.
+ */
+void QueryAnalysisWriter::_flush(OperationContext* opCtx,
+                                 const NamespaceString& ns,
+                                 Buffer* buffer) {
+    Buffer tmpBuffer;
+    // The indices of invalid documents, e.g. documents that fail to insert with DuplicateKey errors
+    // (i.e. duplicates) and BadValue errors. Such documents should not get added back to the buffer
+    // when the inserts below fail.
+    std::set<int> invalid;
+    {
+        stdx::lock_guard<Latch> lk(_mutex);
+        if (buffer->isEmpty()) {
+            return;
+        }
+        std::swap(tmpBuffer, *buffer);
+    }
+    ScopeGuard backSwapper([&] {
+        stdx::lock_guard<Latch> lk(_mutex);
+        for (int i = 0; i < tmpBuffer.getCount(); i++) {
+            if (invalid.find(i) == invalid.end()) {
+                buffer->add(tmpBuffer.at(i));
+            }
+        }
+    });
+
+    // Insert the documents in batches from the back of the buffer so that we don't need to move the
+    // documents forward after each batch.
+    size_t maxBatchSize = gQueryAnalysisWriterMaxBatchSize.load();
+
+    while (!tmpBuffer.isEmpty()) {
+        std::vector<BSONObj> docsToInsert;
+        long long objSize = 0;
+
+        // The buffer index of the first document of this batch, i.e. the current back of
+        // 'tmpBuffer'. The document at position 'i' of the batch is 'tmpBuffer.at(baseIndex - i)'.
+        // This is recomputed for every batch because the buffer shrinks after each successful
+        // insert; carrying a running value across batches previously produced wrong indices (and
+        // could wrap around) from the second batch on.
+        const size_t baseIndex = tmpBuffer.getCount() - 1;
+
+        size_t lastIndex = tmpBuffer.getCount();  // inclusive
+        while (lastIndex > 0 && docsToInsert.size() < maxBatchSize) {
+            // Check if the next document can fit in the batch. A batch may be exactly as large as
+            // the limit, which matches Buffer::add() accepting documents up to and including the
+            // limit; otherwise a document of exactly that size could never be batched.
+            auto doc = tmpBuffer.at(lastIndex - 1);
+            if (doc.objsize() + objSize > kMaxBSONObjSizeForInsert) {
+                break;
+            }
+            lastIndex--;
+            objSize += doc.objsize();
+            docsToInsert.push_back(std::move(doc));
+        }
+        // We don't add a document that is above the size limit to the buffer so we should have
+        // added at least one document to 'docsToInsert'.
+        invariant(!docsToInsert.empty());
+
+        write_ops::InsertCommandRequest insertCmd(ns);
+        insertCmd.setDocuments(std::move(docsToInsert));
+        insertCmd.setWriteCommandRequestBase([&] {
+            write_ops::WriteCommandRequestBase wcb;
+            wcb.setOrdered(false);
+            wcb.setBypassDocumentValidation(false);
+            return wcb;
+        }());
+        auto insertCmdBson = insertCmd.toBSON(
+            {BSON(WriteConcernOptions::kWriteConcernField << kMajorityWriteConcern.toBSON())});
+
+        executeWriteCommand(opCtx, ns, std::move(insertCmdBson), [&](const BSONObj& resObj) {
+            BatchedCommandResponse response;
+            std::string errMsg;
+
+            if (!response.parseBSON(resObj, &errMsg)) {
+                uasserted(ErrorCodes::FailedToParse, errMsg);
+            }
+
+            if (response.isErrDetailsSet() && response.sizeErrDetails() > 0) {
+                boost::optional<write_ops::WriteError> firstWriteErr;
+
+                for (const auto& err : response.getErrDetails()) {
+                    if (err.getStatus() == ErrorCodes::DuplicateKey ||
+                        err.getStatus() == ErrorCodes::BadValue) {
+                        LOGV2(7075402,
+                              "Ignoring insert error",
+                              "error"_attr = redact(err.getStatus()));
+                        // Map the position within the batch back to the index in 'tmpBuffer'.
+                        invalid.insert(static_cast<int>(baseIndex) - err.getIndex());
+                        continue;
+                    }
+                    if (!firstWriteErr) {
+                        // Save the error for later. Go through the rest of the errors to see if
+                        // there are any invalid documents so they can be discarded from the buffer.
+                        firstWriteErr.emplace(err);
+                    }
+                }
+                if (firstWriteErr) {
+                    uassertStatusOK(firstWriteErr->getStatus());
+                }
+            } else {
+                uassertStatusOK(response.toStatus());
+            }
+        });
+
+        // The batch was inserted (or its failures were all ignorable), so drop it from the buffer.
+        tmpBuffer.truncate(lastIndex, objSize);
+    }
+
+    backSwapper.dismiss();
+}
+
+/*
+ * Appends 'doc' to the buffer and adds its size to the byte total. Documents larger than
+ * kMaxBSONObjSizeForInsert are silently dropped.
+ */
+void QueryAnalysisWriter::Buffer::add(BSONObj doc) {
+    if (doc.objsize() > kMaxBSONObjSizeForInsert) {
+        return;
+    }
+    _docs.push_back(std::move(doc));
+    _numBytes += _docs.back().objsize();
+}
+
+/*
+ * Drops the documents at 'index' onwards and subtracts 'numBytes' from the byte total. Invariants
+ * that at least one document is removed and that 'numBytes' is positive and not larger than the
+ * current total.
+ */
+void QueryAnalysisWriter::Buffer::truncate(size_t index, long long numBytes) {
+    // 'index' is unsigned, so a negative argument wraps around to a huge value and is rejected by
+    // the check below; a separate 'index >= 0' check would always be true.
+    invariant(index < _docs.size());
+    invariant(numBytes > 0);
+    invariant(numBytes <= _numBytes);
+    _docs.resize(index);
+    _numBytes -= numBytes;
+}
+
+bool QueryAnalysisWriter::_exceedsMaxSizeBytes() {
+    stdx::lock_guard<Latch> lk(_mutex);
+    return _queries.getSize() + _diffs.getSize() >= gQueryAnalysisWriterMaxMemoryUsageBytes.load();
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addFindQuery(const UUID& sampleId,
+                                                       const NamespaceString& nss,
+                                                       const BSONObj& filter,
+                                                       const BSONObj& collation) {
+    return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kFind, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addCountQuery(const UUID& sampleId,
+                                                        const NamespaceString& nss,
+                                                        const BSONObj& filter,
+                                                        const BSONObj& collation) {
+    return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kCount, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addDistinctQuery(const UUID& sampleId,
+                                                           const NamespaceString& nss,
+                                                           const BSONObj& filter,
+                                                           const BSONObj& collation) {
+    return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kDistinct, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addAggregateQuery(const UUID& sampleId,
+                                                            const NamespaceString& nss,
+                                                            const BSONObj& filter,
+                                                            const BSONObj& collation) {
+    return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kAggregate, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::_addReadQuery(const UUID& sampleId, + const NamespaceString& nss, + SampledReadCommandNameEnum cmdName, + const BSONObj& filter, + const BSONObj& collation) { + invariant(_executor); + return ExecutorFuture<void>(_executor) + .then([this, + sampleId, + nss, + cmdName, + filter = filter.getOwned(), + collation = collation.getOwned()] { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + + auto collUuid = CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, nss); + + if (!collUuid) { + LOGV2(7047301, "Found a sampled read query for non-existing collection"); + return; + } + + auto cmd = SampledReadCommand{filter.getOwned(), collation.getOwned()}; + auto doc = SampledReadQueryDocument{sampleId, nss, *collUuid, cmdName, cmd.toBSON()}; + stdx::lock_guard<Latch> lk(_mutex); + _queries.add(doc.toBSON()); + }) + .then([this] { + if (_exceedsMaxSizeBytes()) { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + _flushQueries(opCtx); + } + }) + .onError([this, nss](Status status) { + LOGV2(7047302, + "Failed to add read query", + "ns"_attr = nss, + "error"_attr = redact(status)); + }); +} + +ExecutorFuture<void> QueryAnalysisWriter::addUpdateQuery( + const write_ops::UpdateCommandRequest& updateCmd, int opIndex) { + invariant(updateCmd.getUpdates()[opIndex].getSampleId()); + invariant(_executor); + + return ExecutorFuture<void>(_executor) + .then([this, sampledUpdateCmd = makeSampledUpdateCommandRequest(updateCmd, opIndex)]() { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + + auto collUuid = + CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledUpdateCmd.nss); + + if (!collUuid) { + LOGV2_WARNING(7075300, + "Found a sampled update query for a non-existing collection"); + return; + } + + auto doc = SampledWriteQueryDocument{sampledUpdateCmd.sampleId, + sampledUpdateCmd.nss, + *collUuid, + 
SampledWriteCommandNameEnum::kUpdate, + std::move(sampledUpdateCmd.cmd)}; + stdx::lock_guard<Latch> lk(_mutex); + _queries.add(doc.toBSON()); + }) + .then([this] { + if (_exceedsMaxSizeBytes()) { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + _flushQueries(opCtx); + } + }) + .onError([this, nss = updateCmd.getNamespace()](Status status) { + LOGV2(7075301, + "Failed to add update query", + "ns"_attr = nss, + "error"_attr = redact(status)); + }); +} + +ExecutorFuture<void> QueryAnalysisWriter::addDeleteQuery( + const write_ops::DeleteCommandRequest& deleteCmd, int opIndex) { + invariant(deleteCmd.getDeletes()[opIndex].getSampleId()); + invariant(_executor); + + return ExecutorFuture<void>(_executor) + .then([this, sampledDeleteCmd = makeSampledDeleteCommandRequest(deleteCmd, opIndex)]() { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + + auto collUuid = + CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledDeleteCmd.nss); + + if (!collUuid) { + LOGV2_WARNING(7075302, + "Found a sampled delete query for a non-existing collection"); + return; + } + + auto doc = SampledWriteQueryDocument{sampledDeleteCmd.sampleId, + sampledDeleteCmd.nss, + *collUuid, + SampledWriteCommandNameEnum::kDelete, + std::move(sampledDeleteCmd.cmd)}; + stdx::lock_guard<Latch> lk(_mutex); + _queries.add(doc.toBSON()); + }) + .then([this] { + if (_exceedsMaxSizeBytes()) { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + _flushQueries(opCtx); + } + }) + .onError([this, nss = deleteCmd.getNamespace()](Status status) { + LOGV2(7075303, + "Failed to add delete query", + "ns"_attr = nss, + "error"_attr = redact(status)); + }); +} + +ExecutorFuture<void> QueryAnalysisWriter::addFindAndModifyQuery( + const write_ops::FindAndModifyCommandRequest& findAndModifyCmd) { + invariant(findAndModifyCmd.getSampleId()); + invariant(_executor); + + return ExecutorFuture<void>(_executor) + 
.then([this, + sampledFindAndModifyCmd = + makeSampledFindAndModifyCommandRequest(findAndModifyCmd)]() { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + + auto collUuid = + CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledFindAndModifyCmd.nss); + + if (!collUuid) { + LOGV2_WARNING(7075304, + "Found a sampled findAndModify query for a non-existing collection"); + return; + } + + auto doc = SampledWriteQueryDocument{sampledFindAndModifyCmd.sampleId, + sampledFindAndModifyCmd.nss, + *collUuid, + SampledWriteCommandNameEnum::kFindAndModify, + std::move(sampledFindAndModifyCmd.cmd)}; + stdx::lock_guard<Latch> lk(_mutex); + _queries.add(doc.toBSON()); + }) + .then([this] { + if (_exceedsMaxSizeBytes()) { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + _flushQueries(opCtx); + } + }) + .onError([this, nss = findAndModifyCmd.getNamespace()](Status status) { + LOGV2(7075305, + "Failed to add findAndModify query", + "ns"_attr = nss, + "error"_attr = redact(status)); + }); +} + +ExecutorFuture<void> QueryAnalysisWriter::addDiff(const UUID& sampleId, + const NamespaceString& nss, + const UUID& collUuid, + const BSONObj& preImage, + const BSONObj& postImage) { + invariant(_executor); + return ExecutorFuture<void>(_executor) + .then([this, + sampleId, + nss, + collUuid, + preImage = preImage.getOwned(), + postImage = postImage.getOwned()]() { + auto diff = doc_diff::computeInlineDiff(preImage, postImage); + + if (!diff || diff->isEmpty()) { + return; + } + + auto doc = SampledQueryDiffDocument{sampleId, nss, collUuid, std::move(*diff)}; + stdx::lock_guard<Latch> lk(_mutex); + _diffs.add(doc.toBSON()); + }) + .then([this] { + if (_exceedsMaxSizeBytes()) { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + _flushDiffs(opCtx); + } + }) + .onError([this, nss](Status status) { + LOGV2(7075401, "Failed to add diff", "ns"_attr = nss, "error"_attr = redact(status)); 
+ }); +} + +} // namespace analyze_shard_key +} // namespace mongo diff --git a/src/mongo/db/s/query_analysis_writer.h b/src/mongo/db/s/query_analysis_writer.h new file mode 100644 index 00000000000..508d4903b65 --- /dev/null +++ b/src/mongo/db/s/query_analysis_writer.h @@ -0,0 +1,214 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include "mongo/db/operation_context.h" +#include "mongo/db/service_context.h" +#include "mongo/executor/task_executor.h" +#include "mongo/s/analyze_shard_key_common_gen.h" +#include "mongo/s/analyze_shard_key_util.h" +#include "mongo/s/write_ops/batched_command_request.h" +#include "mongo/util/periodic_runner.h" + +namespace mongo { +namespace analyze_shard_key { + +/** + * Owns the machinery for persisting sampled queries. That consists of the following: + * - The buffer that stores sampled queries and the periodic background job that inserts those + * queries into the local config.sampledQueries collection. + * - The buffer that stores diffs for sampled update queries and the periodic background job that + * inserts those diffs into the local config.sampledQueriesDiff collection. + * + * Currently, query sampling is only supported on a sharded cluster. So a writer must be a shardsvr + * mongod. If the mongod is a primary, it will execute the insert commands locally. If it is a + * secondary, it will perform the insert commands against the primary. + * + * The memory usage of the buffers is controlled by the 'queryAnalysisWriterMaxMemoryUsageBytes' + * server parameter. Upon adding a query or a diff that causes the total size of buffers to exceed + * the limit, the writer will flush the corresponding buffer immediately instead of waiting for it + * to get flushed later by the periodic job. + */ +class QueryAnalysisWriter final : public std::enable_shared_from_this<QueryAnalysisWriter> { + QueryAnalysisWriter(const QueryAnalysisWriter&) = delete; + QueryAnalysisWriter& operator=(const QueryAnalysisWriter&) = delete; + +public: + /** + * Temporarily stores documents to be written to disk. + */ + struct Buffer { + public: + /** + * Adds the given document to the buffer if its size is below the limit (i.e. + * BSONObjMaxUserSize - some padding) and increments the total number of bytes accordingly. 
+ */ + void add(BSONObj doc); + + /** + * Removes the documents at 'index' onwards from the buffer and decrements the total number + * of the bytes by 'numBytes'. The caller must ensure that that 'numBytes' is indeed the + * total size of the documents being removed. + */ + void truncate(size_t index, long long numBytes); + + bool isEmpty() const { + return _docs.empty(); + } + + int getCount() const { + return _docs.size(); + } + + long long getSize() const { + return _numBytes; + } + + BSONObj at(size_t index) const { + return _docs[index]; + } + + private: + std::vector<BSONObj> _docs; + long long _numBytes = 0; + }; + + QueryAnalysisWriter() = default; + ~QueryAnalysisWriter() = default; + + QueryAnalysisWriter(QueryAnalysisWriter&& source) = delete; + QueryAnalysisWriter& operator=(QueryAnalysisWriter&& other) = delete; + + /** + * Obtains the service-wide QueryAnalysisWriter instance. + */ + static QueryAnalysisWriter& get(OperationContext* opCtx); + static QueryAnalysisWriter& get(ServiceContext* serviceContext); + + void onStartup(); + + void onShutdown(); + + ExecutorFuture<void> addFindQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation); + + ExecutorFuture<void> addCountQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation); + + ExecutorFuture<void> addDistinctQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation); + + ExecutorFuture<void> addAggregateQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation); + + ExecutorFuture<void> addUpdateQuery(const write_ops::UpdateCommandRequest& updateCmd, + int opIndex); + ExecutorFuture<void> addDeleteQuery(const write_ops::DeleteCommandRequest& deleteCmd, + int opIndex); + + ExecutorFuture<void> addFindAndModifyQuery( + const write_ops::FindAndModifyCommandRequest& findAndModifyCmd); + + 
ExecutorFuture<void> addDiff(const UUID& sampleId, + const NamespaceString& nss, + const UUID& collUuid, + const BSONObj& preImage, + const BSONObj& postImage); + + int getQueriesCountForTest() const { + stdx::lock_guard<Latch> lk(_mutex); + return _queries.getCount(); + } + + void flushQueriesForTest(OperationContext* opCtx) { + _flushQueries(opCtx); + } + + int getDiffsCountForTest() const { + stdx::lock_guard<Latch> lk(_mutex); + return _diffs.getCount(); + } + + void flushDiffsForTest(OperationContext* opCtx) { + _flushDiffs(opCtx); + } + +private: + ExecutorFuture<void> _addReadQuery(const UUID& sampleId, + const NamespaceString& nss, + SampledReadCommandNameEnum cmdName, + const BSONObj& filter, + const BSONObj& collation); + + void _flushQueries(OperationContext* opCtx); + void _flushDiffs(OperationContext* opCtx); + + /** + * The helper for '_flushQueries' and '_flushDiffs'. Inserts the documents in 'buffer' into the + * collection 'ns' in batches, and removes all the inserted documents from 'buffer'. Internally + * retries the inserts on retryable errors for a fixed number of times. Ignores DuplicateKey + * errors since they are expected for the following reasons: + * - For the query buffer, a sampled query that is idempotent (e.g. a read or retryable write) + * could get added to the buffer (across nodes) more than once due to retries. + * - For the diff buffer, a sampled multi-update query could end up generating multiple diffs + * and each diff is identified using the sample id of the sampled query that creates it. + * + * Throws an error if the inserts fail with any other error. + */ + void _flush(OperationContext* opCtx, const NamespaceString& nss, Buffer* buffer); + + /** + * Returns true if the total size of the buffered queries and diffs has exceeded the maximum + * amount of memory that the writer is allowed to use. 
+ */ + bool _exceedsMaxSizeBytes(); + + mutable Mutex _mutex = MONGO_MAKE_LATCH("QueryAnalysisWriter::_mutex"); + + PeriodicJobAnchor _periodicQueryWriter; + Buffer _queries; + + PeriodicJobAnchor _periodicDiffWriter; + Buffer _diffs; + + // Initialized on startup and joined on shutdown. + std::shared_ptr<executor::TaskExecutor> _executor; +}; + +} // namespace analyze_shard_key +} // namespace mongo diff --git a/src/mongo/db/s/query_analysis_writer_test.cpp b/src/mongo/db/s/query_analysis_writer_test.cpp new file mode 100644 index 00000000000..b17f39c82df --- /dev/null +++ b/src/mongo/db/s/query_analysis_writer_test.cpp @@ -0,0 +1,1281 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. 
If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/s/query_analysis_writer.h" + +#include "mongo/bson/unordered_fields_bsonobj_comparator.h" +#include "mongo/db/db_raii.h" +#include "mongo/db/dbdirectclient.h" +#include "mongo/db/s/shard_server_test_fixture.h" +#include "mongo/db/update/document_diff_calculator.h" +#include "mongo/idl/server_parameter_test_util.h" +#include "mongo/logv2/log.h" +#include "mongo/s/analyze_shard_key_documents_gen.h" +#include "mongo/unittest/death_test.h" +#include "mongo/util/fail_point.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest + +namespace mongo { +namespace analyze_shard_key { +namespace { + +TEST(QueryAnalysisWriterBufferTest, AddBasic) { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc0 = BSON("a" << 0); + buffer.add(doc0); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc0.objsize()); + + auto doc1 = BSON("a" << BSON_ARRAY(0 << 1 << 2)); + buffer.add(doc1); + ASSERT_EQ(buffer.getCount(), 2); + ASSERT_EQ(buffer.getSize(), doc0.objsize() + doc1.objsize()); + + ASSERT_BSONOBJ_EQ(buffer.at(0), doc0); + ASSERT_BSONOBJ_EQ(buffer.at(1), doc1); +} + +TEST(QueryAnalysisWriterBufferTest, AddTooLarge) { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON(std::string(BSONObjMaxUserSize, 'a') << 1); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 0); + ASSERT_EQ(buffer.getSize(), 0); +} + +TEST(QueryAnalysisWriterBufferTest, TruncateBasic) { + auto testTruncateCommon = [](int oldCount, int newCount) { + auto buffer = QueryAnalysisWriter::Buffer(); + + std::vector<BSONObj> docs; + for (auto i = 0; i < oldCount; i++) { + docs.push_back(BSON("a" << i)); + } + // The documents have the same size. 
+ auto docSize = docs.back().objsize(); + + for (const auto& doc : docs) { + buffer.add(doc); + } + ASSERT_EQ(buffer.getCount(), oldCount); + ASSERT_EQ(buffer.getSize(), oldCount * docSize); + + buffer.truncate(newCount, (oldCount - newCount) * docSize); + ASSERT_EQ(buffer.getCount(), newCount); + ASSERT_EQ(buffer.getSize(), newCount * docSize); + for (auto i = 0; i < newCount; i++) { + ASSERT_BSONOBJ_EQ(buffer.at(i), docs[i]); + } + }; + + testTruncateCommon(10 /* oldCount */, 6 /* newCount */); + testTruncateCommon(10 /* oldCount */, 0 /* newCount */); // Truncate all. + testTruncateCommon(10 /* oldCount */, 9 /* newCount */); // Truncate one. +} + +DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidIndex_Negative, "invariant") { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON("a" << 0); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc.objsize()); + + buffer.truncate(-1, doc.objsize()); +} + +DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidIndex_Positive, "invariant") { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON("a" << 0); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc.objsize()); + + buffer.truncate(2, doc.objsize()); +} + +DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Negative, "invariant") { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON("a" << 0); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc.objsize()); + + buffer.truncate(0, -doc.objsize()); +} + +DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Zero, "invariant") { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON("a" << 0); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc.objsize()); + + buffer.truncate(0, 0); +} + +DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Positive, "invariant") { + auto buffer = 
QueryAnalysisWriter::Buffer(); + + auto doc = BSON("a" << 0); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc.objsize()); + + buffer.truncate(0, doc.objsize() * 2); +} + +void assertBsonObjEqualUnordered(const BSONObj& lhs, const BSONObj& rhs) { + UnorderedFieldsBSONObjComparator comparator; + ASSERT_EQ(comparator.compare(lhs, rhs), 0); +} + +struct QueryAnalysisWriterTest : public ShardServerTestFixture { +public: + void setUp() { + ShardServerTestFixture::setUp(); + QueryAnalysisWriter::get(operationContext()).onStartup(); + + DBDirectClient client(operationContext()); + client.createCollection(nss0.toString()); + client.createCollection(nss1.toString()); + } + + void tearDown() { + QueryAnalysisWriter::get(operationContext()).onShutdown(); + ShardServerTestFixture::tearDown(); + } + +protected: + UUID getCollectionUUID(const NamespaceString& nss) const { + auto collectionCatalog = CollectionCatalog::get(operationContext()); + return *collectionCatalog->lookupUUIDByNSS(operationContext(), nss); + } + + BSONObj makeNonEmptyFilter() { + return BSON("_id" << UUID::gen()); + } + + BSONObj makeNonEmptyCollation() { + int strength = rand() % 5 + 1; + return BSON("locale" + << "en_US" + << "strength" << strength); + } + + /* + * Makes an UpdateCommandRequest for the collection 'nss' such that the command contains + * 'numOps' updates and the ones whose indices are in 'markForSampling' are marked for sampling. + * Returns the UpdateCommandRequest and the map storing the expected sampled + * UpdateCommandRequests by sample id, if any. + */ + std::pair<write_ops::UpdateCommandRequest, std::map<UUID, write_ops::UpdateCommandRequest>> + makeUpdateCommandRequest(const NamespaceString& nss, + int numOps, + std::set<int> markForSampling, + std::string filterFieldName = "a") { + write_ops::UpdateCommandRequest originalCmd(nss); + std::vector<write_ops::UpdateOpEntry> updateOps; // populated below. 
+ originalCmd.setLet(let); + originalCmd.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation); + + std::map<UUID, write_ops::UpdateCommandRequest> expectedSampledCmds; + + for (auto i = 0; i < numOps; i++) { + auto updateOp = write_ops::UpdateOpEntry( + BSON(filterFieldName << i), + write_ops::UpdateModification(BSON("$set" << BSON("b.$[element]" << i)))); + updateOp.setC(BSON("x" << i)); + updateOp.setArrayFilters(std::vector<BSONObj>{BSON("element" << BSON("$gt" << i))}); + updateOp.setMulti(_getRandomBool()); + updateOp.setUpsert(_getRandomBool()); + updateOp.setUpsertSupplied(_getRandomBool()); + updateOp.setCollation(makeNonEmptyCollation()); + + if (markForSampling.find(i) != markForSampling.end()) { + auto sampleId = UUID::gen(); + updateOp.setSampleId(sampleId); + + write_ops::UpdateCommandRequest expectedSampledCmd = originalCmd; + expectedSampledCmd.setUpdates({updateOp}); + expectedSampledCmd.getWriteCommandRequestBase().setEncryptionInformation( + boost::none); + expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd)); + } + updateOps.push_back(updateOp); + } + originalCmd.setUpdates(updateOps); + + return {originalCmd, expectedSampledCmds}; + } + + /* + * Makes a DeleteCommandRequest for the collection 'nss' such that the command contains + * 'numOps' deletes and the ones whose indices are in 'markForSampling' are marked for sampling. + * Returns the DeleteCommandRequest and the map storing the expected sampled + * DeleteCommandRequests by sample id, if any. + */ + std::pair<write_ops::DeleteCommandRequest, std::map<UUID, write_ops::DeleteCommandRequest>> + makeDeleteCommandRequest(const NamespaceString& nss, + int numOps, + std::set<int> markForSampling, + std::string filterFieldName = "a") { + write_ops::DeleteCommandRequest originalCmd(nss); + std::vector<write_ops::DeleteOpEntry> deleteOps; // populated and set below. 
+ originalCmd.setLet(let); + originalCmd.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation); + + std::map<UUID, write_ops::DeleteCommandRequest> expectedSampledCmds; + + for (auto i = 0; i < numOps; i++) { + auto deleteOp = + write_ops::DeleteOpEntry(BSON(filterFieldName << i), _getRandomBool() /* multi */); + deleteOp.setCollation(makeNonEmptyCollation()); + + if (markForSampling.find(i) != markForSampling.end()) { + auto sampleId = UUID::gen(); + deleteOp.setSampleId(sampleId); + + write_ops::DeleteCommandRequest expectedSampledCmd = originalCmd; + expectedSampledCmd.setDeletes({deleteOp}); + expectedSampledCmd.getWriteCommandRequestBase().setEncryptionInformation( + boost::none); + expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd)); + } + deleteOps.push_back(deleteOp); + } + originalCmd.setDeletes(deleteOps); + + return {originalCmd, expectedSampledCmds}; + } + + /* + * Makes a FindAndModifyCommandRequest for the collection 'nss'. The findAndModify is an update + * if 'isUpdate' is true, and a remove otherwise. If 'markForSampling' is true, it is marked for + * sampling. Returns the FindAndModifyCommandRequest and the map storing the expected sampled + * FindAndModifyCommandRequests by sample id, if any. 
+ */ + std::pair<write_ops::FindAndModifyCommandRequest, + std::map<UUID, write_ops::FindAndModifyCommandRequest>> + makeFindAndModifyCommandRequest(const NamespaceString& nss, + bool isUpdate, + bool markForSampling, + std::string filterFieldName = "a") { + write_ops::FindAndModifyCommandRequest originalCmd(nss); + originalCmd.setQuery(BSON(filterFieldName << 0)); + originalCmd.setUpdate( + write_ops::UpdateModification(BSON("$set" << BSON("b.$[element]" << 0)))); + originalCmd.setArrayFilters(std::vector<BSONObj>{BSON("element" << BSON("$gt" << 10))}); + originalCmd.setSort(BSON("_id" << 1)); + if (isUpdate) { + originalCmd.setUpsert(_getRandomBool()); + originalCmd.setNew(_getRandomBool()); + } + originalCmd.setCollation(makeNonEmptyCollation()); + originalCmd.setLet(let); + originalCmd.setEncryptionInformation(encryptionInformation); + + std::map<UUID, write_ops::FindAndModifyCommandRequest> expectedSampledCmds; + if (markForSampling) { + auto sampleId = UUID::gen(); + originalCmd.setSampleId(sampleId); + + auto expectedSampledCmd = originalCmd; + expectedSampledCmd.setEncryptionInformation(boost::none); + expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd)); + } + + return {originalCmd, expectedSampledCmds}; + } + + void deleteSampledQueryDocuments() const { + DBDirectClient client(operationContext()); + client.remove(NamespaceString::kConfigSampledQueriesNamespace.toString(), BSONObj()); + } + + /** + * Returns the number of the documents for the collection 'nss' in the config.sampledQueries + * collection. + */ + int getSampledQueryDocumentsCount(const NamespaceString& nss) { + return _getConfigDocumentsCount(NamespaceString::kConfigSampledQueriesNamespace, nss); + } + + /* + * Asserts that there is a sampled read query document with the given sample id and that it has + * the given fields. 
+ */ + void assertSampledReadQueryDocument(const UUID& sampleId, + const NamespaceString& nss, + SampledReadCommandNameEnum cmdName, + const BSONObj& filter, + const BSONObj& collation) { + auto doc = _getConfigDocument(NamespaceString::kConfigSampledQueriesNamespace, sampleId); + auto parsedQueryDoc = + SampledReadQueryDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc); + + ASSERT_EQ(parsedQueryDoc.getNs(), nss); + ASSERT_EQ(parsedQueryDoc.getCollectionUuid(), getCollectionUUID(nss)); + ASSERT_EQ(parsedQueryDoc.getSampleId(), sampleId); + ASSERT(parsedQueryDoc.getCmdName() == cmdName); + auto parsedCmd = SampledReadCommand::parse(IDLParserContext("QueryAnalysisWriterTest"), + parsedQueryDoc.getCmd()); + ASSERT_BSONOBJ_EQ(parsedCmd.getFilter(), filter); + ASSERT_BSONOBJ_EQ(parsedCmd.getCollation(), collation); + } + + /* + * Asserts that there is a sampled write query document with the given sample id and that it has + * the given fields. + */ + template <typename CommandRequestType> + void assertSampledWriteQueryDocument(const UUID& sampleId, + const NamespaceString& nss, + SampledWriteCommandNameEnum cmdName, + const CommandRequestType& expectedCmd) { + auto doc = _getConfigDocument(NamespaceString::kConfigSampledQueriesNamespace, sampleId); + auto parsedQueryDoc = + SampledWriteQueryDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc); + + ASSERT_EQ(parsedQueryDoc.getNs(), nss); + ASSERT_EQ(parsedQueryDoc.getCollectionUuid(), getCollectionUUID(nss)); + ASSERT_EQ(parsedQueryDoc.getSampleId(), sampleId); + ASSERT(parsedQueryDoc.getCmdName() == cmdName); + auto parsedCmd = CommandRequestType::parse(IDLParserContext("QueryAnalysisWriterTest"), + parsedQueryDoc.getCmd()); + ASSERT_BSONOBJ_EQ(parsedCmd.toBSON({}), expectedCmd.toBSON({})); + } + + /* + * Returns the number of the documents for the collection 'nss' in the config.sampledQueriesDiff + * collection. 
+ */ + int getDiffDocumentsCount(const NamespaceString& nss) { + return _getConfigDocumentsCount(NamespaceString::kConfigSampledQueriesDiffNamespace, nss); + } + + /* + * Asserts that there is a sampled diff document with the given sample id and that it has + * the given fields. + */ + void assertDiffDocument(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& expectedDiff) { + auto doc = + _getConfigDocument(NamespaceString::kConfigSampledQueriesDiffNamespace, sampleId); + auto parsedDiffDoc = + SampledQueryDiffDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc); + + ASSERT_EQ(parsedDiffDoc.getNs(), nss); + ASSERT_EQ(parsedDiffDoc.getCollectionUuid(), getCollectionUUID(nss)); + ASSERT_EQ(parsedDiffDoc.getSampleId(), sampleId); + assertBsonObjEqualUnordered(parsedDiffDoc.getDiff(), expectedDiff); + } + + const NamespaceString nss0{"testDb", "testColl0"}; + const NamespaceString nss1{"testDb", "testColl1"}; + + // Test with both empty and non-empty filter and collation to verify that the + // QueryAnalysisWriter doesn't require filter or collation to be non-empty. + const BSONObj emptyFilter{}; + const BSONObj emptyCollation{}; + + const BSONObj let = BSON("x" << 1); + // Test with EncryptionInformation to verify that QueryAnalysisWriter does not persist the + // WriteCommandRequestBase fields, especially this sensitive field. + const EncryptionInformation encryptionInformation{BSON("foo" + << "bar")}; + +private: + bool _getRandomBool() { + return rand() % 2 == 0; + } + + /** + * Returns the number of the documents for the collection 'collNss' in the config collection + * 'configNss'. + */ + int _getConfigDocumentsCount(const NamespaceString& configNss, + const NamespaceString& collNss) const { + DBDirectClient client(operationContext()); + return client.count(configNss, BSON("ns" << collNss.toString())); + } + + /** + * Returns the document with the given _id in the config collection 'configNss'. 
+ * Asserts that a matching document exists. + */ + BSONObj _getConfigDocument(const NamespaceString& configNss, const UUID& id) const { + DBDirectClient client(operationContext()); + + FindCommandRequest findRequest{configNss}; + findRequest.setFilter(BSON("_id" << id)); + auto cursor = client.find(std::move(findRequest)); + ASSERT(cursor->more()); + return cursor->next(); + } + + RAIIServerParameterControllerForTest _featureFlagController{"featureFlagAnalyzeShardKey", true}; + FailPointEnableBlock _fp{"disableQueryAnalysisWriter"}; +}; + +DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetIfFeatureFlagNotEnabled, "invariant") { + RAIIServerParameterControllerForTest _featureFlagController{"featureFlagAnalyzeShardKey", + false}; + QueryAnalysisWriter::get(operationContext()); +} + +DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetOnConfigServer, "invariant") { + serverGlobalParams.clusterRole = ClusterRole::ConfigServer; + QueryAnalysisWriter::get(operationContext()); +} + +DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetOnNonShardServer, "invariant") { + serverGlobalParams.clusterRole = ClusterRole::None; + QueryAnalysisWriter::get(operationContext()); +} + +TEST_F(QueryAnalysisWriterTest, NoQueries) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + writer.flushQueriesForTest(operationContext()); +} + +TEST_F(QueryAnalysisWriterTest, FindQuery) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto testFindCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) { + auto sampleId = UUID::gen(); + + writer.addFindQuery(sampleId, nss0, filter, collation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kFind, filter, collation); + + deleteSampledQueryDocuments(); + }; + + testFindCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation()); 
+ testFindCmdCommon(makeNonEmptyFilter(), emptyCollation); + testFindCmdCommon(emptyFilter, makeNonEmptyCollation()); + testFindCmdCommon(emptyFilter, emptyCollation); +} + +TEST_F(QueryAnalysisWriterTest, CountQuery) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto testCountCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) { + auto sampleId = UUID::gen(); + + writer.addCountQuery(sampleId, nss0, filter, collation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kCount, filter, collation); + + deleteSampledQueryDocuments(); + }; + + testCountCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation()); + testCountCmdCommon(makeNonEmptyFilter(), emptyCollation); + testCountCmdCommon(emptyFilter, makeNonEmptyCollation()); + testCountCmdCommon(emptyFilter, emptyCollation); +} + +TEST_F(QueryAnalysisWriterTest, DistinctQuery) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto testDistinctCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) { + auto sampleId = UUID::gen(); + + writer.addDistinctQuery(sampleId, nss0, filter, collation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kDistinct, filter, collation); + + deleteSampledQueryDocuments(); + }; + + testDistinctCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation()); + testDistinctCmdCommon(makeNonEmptyFilter(), emptyCollation); + testDistinctCmdCommon(emptyFilter, makeNonEmptyCollation()); + testDistinctCmdCommon(emptyFilter, emptyCollation); +} + 
+TEST_F(QueryAnalysisWriterTest, AggregateQuery) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto testAggregateCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) { + auto sampleId = UUID::gen(); + + writer.addAggregateQuery(sampleId, nss0, filter, collation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation); + + deleteSampledQueryDocuments(); + }; + + testAggregateCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation()); + testAggregateCmdCommon(makeNonEmptyFilter(), emptyCollation); + testAggregateCmdCommon(emptyFilter, makeNonEmptyCollation()); + testAggregateCmdCommon(emptyFilter, emptyCollation); +} + +DEATH_TEST_F(QueryAnalysisWriterTest, UpdateQueryNotMarkedForSampling, "invariant") { + auto& writer = QueryAnalysisWriter::get(operationContext()); + auto [originalCmd, _] = makeUpdateCommandRequest(nss0, 1, {} /* markForSampling */); + writer.addUpdateQuery(originalCmd, 0).get(); +} + +TEST_F(QueryAnalysisWriterTest, UpdateQueriesMarkedForSampling) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto [originalCmd, expectedSampledCmds] = + makeUpdateCommandRequest(nss0, 3, {0, 2} /* markForSampling */); + ASSERT_EQ(expectedSampledCmds.size(), 2U); + + writer.addUpdateQuery(originalCmd, 0).get(); + writer.addUpdateQuery(originalCmd, 2).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 2); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2); + for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) { + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kUpdate, + 
expectedSampledCmd); + } +} + +DEATH_TEST_F(QueryAnalysisWriterTest, DeleteQueryNotMarkedForSampling, "invariant") { + auto& writer = QueryAnalysisWriter::get(operationContext()); + auto [originalCmd, _] = makeDeleteCommandRequest(nss0, 1, {} /* markForSampling */); + writer.addDeleteQuery(originalCmd, 0).get(); +} + +TEST_F(QueryAnalysisWriterTest, DeleteQueriesMarkedForSampling) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto [originalCmd, expectedSampledCmds] = + makeDeleteCommandRequest(nss0, 3, {1, 2} /* markForSampling */); + ASSERT_EQ(expectedSampledCmds.size(), 2U); + + writer.addDeleteQuery(originalCmd, 1).get(); + writer.addDeleteQuery(originalCmd, 2).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 2); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2); + for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) { + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kDelete, + expectedSampledCmd); + } +} + +DEATH_TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryNotMarkedForSampling, "invariant") { + auto& writer = QueryAnalysisWriter::get(operationContext()); + auto [originalCmd, _] = + makeFindAndModifyCommandRequest(nss0, true /* isUpdate */, false /* markForSampling */); + writer.addFindAndModifyQuery(originalCmd).get(); +} + +TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryUpdateMarkedForSampling) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto [originalCmd, expectedSampledCmds] = + makeFindAndModifyCommandRequest(nss0, true /* isUpdate */, true /* markForSampling */); + ASSERT_EQ(expectedSampledCmds.size(), 1U); + auto [sampleId, expectedSampledCmd] = *expectedSampledCmds.begin(); + + writer.addFindAndModifyQuery(originalCmd).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + 
writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kFindAndModify, + expectedSampledCmd); +} + +TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryRemoveMarkedForSampling) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto [originalCmd, expectedSampledCmds] = + makeFindAndModifyCommandRequest(nss0, false /* isUpdate */, true /* markForSampling */); + ASSERT_EQ(expectedSampledCmds.size(), 1U); + auto [sampleId, expectedSampledCmd] = *expectedSampledCmds.begin(); + + writer.addFindAndModifyQuery(originalCmd).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kFindAndModify, + expectedSampledCmd); +} + +TEST_F(QueryAnalysisWriterTest, MultipleQueriesAndCollections) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + // Sample one delete on nss1; the count query added below also targets nss1, giving nss1 two + // sampled queries. + auto [originalDeleteCmd, expectedSampledDeleteCmds] = + makeDeleteCommandRequest(nss1, 3, {1} /* markForSampling */); + ASSERT_EQ(expectedSampledDeleteCmds.size(), 1U); + auto [deleteSampleId, expectedSampledDeleteCmd] = *expectedSampledDeleteCmds.begin(); + + // Sample one update on nss0, so that nss0 has exactly one sampled query. 
+ auto [originalUpdateCmd, expectedSampledUpdateCmds] = + makeUpdateCommandRequest(nss0, 1, {0} /* markForSampling */); + ASSERT_EQ(expectedSampledUpdateCmds.size(), 1U); + auto [updateSampleId, expectedSampledUpdateCmd] = *expectedSampledUpdateCmds.begin(); + + auto countSampleId = UUID::gen(); + auto originalCountFilter = makeNonEmptyFilter(); + auto originalCountCollation = makeNonEmptyCollation(); + + writer.addDeleteQuery(originalDeleteCmd, 1).get(); + writer.addUpdateQuery(originalUpdateCmd, 0).get(); + writer.addCountQuery(countSampleId, nss1, originalCountFilter, originalCountCollation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 3); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledWriteQueryDocument(deleteSampleId, + expectedSampledDeleteCmd.getNamespace(), + SampledWriteCommandNameEnum::kDelete, + expectedSampledDeleteCmd); + ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 2); + assertSampledWriteQueryDocument(updateSampleId, + expectedSampledUpdateCmd.getNamespace(), + SampledWriteCommandNameEnum::kUpdate, + expectedSampledUpdateCmd); + assertSampledReadQueryDocument(countSampleId, + nss1, + SampledReadCommandNameEnum::kCount, + originalCountFilter, + originalCountCollation); +} + +TEST_F(QueryAnalysisWriterTest, DuplicateQueries) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto findSampleId = UUID::gen(); + auto originalFindFilter = makeNonEmptyFilter(); + auto originalFindCollation = makeNonEmptyCollation(); + + auto [originalUpdateCmd, expectedSampledUpdateCmds] = + makeUpdateCommandRequest(nss0, 1, {0} /* markForSampling */); + ASSERT_EQ(expectedSampledUpdateCmds.size(), 1U); + auto [updateSampleId, expectedSampledUpdateCmd] = *expectedSampledUpdateCmds.begin(); + + auto countSampleId = UUID::gen(); + auto originalCountFilter = makeNonEmptyFilter(); + auto originalCountCollation = 
makeNonEmptyCollation(); + + writer.addFindQuery(findSampleId, nss0, originalFindFilter, originalFindCollation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument(findSampleId, + nss0, + SampledReadCommandNameEnum::kFind, + originalFindFilter, + originalFindCollation); + + writer.addUpdateQuery(originalUpdateCmd, 0).get(); + writer.addFindQuery(findSampleId, nss0, originalFindFilter, originalFindCollation) + .get(); // This is a duplicate. + writer.addCountQuery(countSampleId, nss0, originalCountFilter, originalCountCollation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 3); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 3); + assertSampledWriteQueryDocument(updateSampleId, + expectedSampledUpdateCmd.getNamespace(), + SampledWriteCommandNameEnum::kUpdate, + expectedSampledUpdateCmd); + assertSampledReadQueryDocument(findSampleId, + nss0, + SampledReadCommandNameEnum::kFind, + originalFindFilter, + originalFindCollation); + assertSampledReadQueryDocument(countSampleId, + nss0, + SampledReadCommandNameEnum::kCount, + originalCountFilter, + originalCountCollation); +} + +TEST_F(QueryAnalysisWriterTest, QueriesMultipleBatches_MaxBatchSize) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 2}; + auto numQueries = 5; + + std::vector<std::tuple<UUID, BSONObj, BSONObj>> expectedSampledCmds; + for (auto i = 0; i < numQueries; i++) { + auto sampleId = UUID::gen(); + auto filter = makeNonEmptyFilter(); + auto collation = makeNonEmptyCollation(); + writer.addAggregateQuery(sampleId, nss0, filter, collation).get(); + expectedSampledCmds.push_back({sampleId, filter, collation}); + 
} + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries); + for (const auto& [sampleId, filter, collation] : expectedSampledCmds) { + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation); + } +} + +TEST_F(QueryAnalysisWriterTest, QueriesMultipleBatches_MaxBSONObjSize) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto numQueries = 3; + std::vector<std::tuple<UUID, BSONObj, BSONObj>> expectedSampledCmds; + for (auto i = 0; i < numQueries; i++) { + auto sampleId = UUID::gen(); + auto filter = BSON(std::string(BSONObjMaxUserSize / 2, 'a') << 1); + auto collation = makeNonEmptyCollation(); + writer.addAggregateQuery(sampleId, nss0, filter, collation).get(); + expectedSampledCmds.push_back({sampleId, filter, collation}); + } + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries); + for (const auto& [sampleId, filter, collation] : expectedSampledCmds) { + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation); + } +} + +TEST_F(QueryAnalysisWriterTest, FlushAfterAddReadIfExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto maxMemoryUsageBytes = 1024; + RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes", + maxMemoryUsageBytes}; + + auto sampleId0 = UUID::gen(); + auto filter0 = BSON(std::string(maxMemoryUsageBytes / 2, 'a') << 1); + auto collation0 = makeNonEmptyCollation(); + + auto sampleId1 = UUID::gen(); + auto filter1 = BSON(std::string(maxMemoryUsageBytes / 2, 'b') << 1); + auto collation1 = makeNonEmptyCollation(); + + 
writer.addFindQuery(sampleId0, nss0, filter0, collation0).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + // Adding the next query causes the size to exceed the limit. + writer.addAggregateQuery(sampleId1, nss1, filter1, collation1).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument( + sampleId0, nss0, SampledReadCommandNameEnum::kFind, filter0, collation0); + ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 1); + assertSampledReadQueryDocument( + sampleId1, nss1, SampledReadCommandNameEnum::kAggregate, filter1, collation1); +} + +TEST_F(QueryAnalysisWriterTest, FlushAfterAddUpdateIfExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto maxMemoryUsageBytes = 1024; + RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes", + maxMemoryUsageBytes}; + auto [originalCmd, expectedSampledCmds] = + makeUpdateCommandRequest(nss0, + 3, + {0, 2} /* markForSampling */, + std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */); + ASSERT_EQ(expectedSampledCmds.size(), 2U); + + writer.addUpdateQuery(originalCmd, 0).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + // Adding the next query causes the size to exceed the limit. 
+ writer.addUpdateQuery(originalCmd, 2).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2); + for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) { + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kUpdate, + expectedSampledCmd); + } +} + +TEST_F(QueryAnalysisWriterTest, FlushAfterAddDeleteIfExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto maxMemoryUsageBytes = 1024; + RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes", + maxMemoryUsageBytes}; + auto [originalCmd, expectedSampledCmds] = + makeDeleteCommandRequest(nss0, + 3, + {0, 1} /* markForSampling */, + std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */); + ASSERT_EQ(expectedSampledCmds.size(), 2U); + + writer.addDeleteQuery(originalCmd, 0).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + // Adding the next query causes the size to exceed the limit. 
+ writer.addDeleteQuery(originalCmd, 1).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2); + for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) { + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kDelete, + expectedSampledCmd); + } +} + +TEST_F(QueryAnalysisWriterTest, FlushAfterAddFindAndModifyIfExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto maxMemoryUsageBytes = 1024; + RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes", + maxMemoryUsageBytes}; + + auto [originalCmd0, expectedSampledCmds0] = makeFindAndModifyCommandRequest( + nss0, + true /* isUpdate */, + true /* markForSampling */, + std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */); + ASSERT_EQ(expectedSampledCmds0.size(), 1U); + auto [sampleId0, expectedSampledCmd0] = *expectedSampledCmds0.begin(); + + auto [originalCmd1, expectedSampledCmds1] = makeFindAndModifyCommandRequest( + nss1, + false /* isUpdate */, + true /* markForSampling */, + std::string(maxMemoryUsageBytes / 2, 'b') /* filterFieldName */); + // Check the second map (not the first again) before dereferencing its begin() below. + ASSERT_EQ(expectedSampledCmds1.size(), 1U); + auto [sampleId1, expectedSampledCmd1] = *expectedSampledCmds1.begin(); + + writer.addFindAndModifyQuery(originalCmd0).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + // Adding the next query causes the size to exceed the limit. 
+ writer.addFindAndModifyQuery(originalCmd1).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledWriteQueryDocument(sampleId0, + expectedSampledCmd0.getNamespace(), + SampledWriteCommandNameEnum::kFindAndModify, + expectedSampledCmd0); + ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 1); + assertSampledWriteQueryDocument(sampleId1, + expectedSampledCmd1.getNamespace(), + SampledWriteCommandNameEnum::kFindAndModify, + expectedSampledCmd1); +} + +TEST_F(QueryAnalysisWriterTest, AddQueriesBackAfterWriteError) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto originalFilter = makeNonEmptyFilter(); + auto originalCollation = makeNonEmptyCollation(); + auto numQueries = 8; + + std::vector<UUID> sampleIds0; + for (auto i = 0; i < numQueries; i++) { + sampleIds0.push_back(UUID::gen()); + writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get(); + } + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries); + + // Force the documents to get inserted in three batches of size 3, 3 and 2, respectively. + RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 3}; + + // Hang after inserting the documents in the first batch. + auto hangFp = globalFailPointRegistry().find("hangAfterCollectionInserts"); + auto hangTimesEntered = hangFp->setMode(FailPoint::alwaysOn, 0); + + auto future = stdx::async(stdx::launch::async, [&] { + ThreadClient tc(getServiceContext()); + auto opCtx = makeOperationContext(); + writer.flushQueriesForTest(opCtx.get()); + }); + + hangFp->waitForTimesEntered(hangTimesEntered + 1); + // Force the second batch to fail so that it falls back to inserting one document at a time in + // order, and then force the first and second document in the batch to fail. 
+ auto failFp = globalFailPointRegistry().find("failCollectionInserts"); + failFp->setMode(FailPoint::nTimes, 3); + hangFp->setMode(FailPoint::off, 0); + + future.get(); + // Verify that all the documents other than the ones in the first batch got added back to the + // buffer after the error. That is, the last document in the second batch also got added back + // to the buffer even though it was successfully inserted, since the writer did not have a way + // to tell if the error caused the entire command to fail early. + ASSERT_EQ(writer.getQueriesCountForTest(), 5); + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 4); + + // Flush the remaining documents. If the documents were not added back correctly, some + // documents would be missing and the checks below would fail. + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries); + for (const auto& sampleId : sampleIds0) { + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation); + } +} + +TEST_F(QueryAnalysisWriterTest, RemoveDuplicatesFromBufferAfterWriteError) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto originalFilter = makeNonEmptyFilter(); + auto originalCollation = makeNonEmptyCollation(); + + auto numQueries0 = 3; + + std::vector<UUID> sampleIds0; + for (auto i = 0; i < numQueries0; i++) { + sampleIds0.push_back(UUID::gen()); + writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get(); + } + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries0); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries0); + for (const auto& sampleId : sampleIds0) { + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation); + } + + 
auto numQueries1 = 5; + + std::vector<UUID> sampleIds1; + for (auto i = 0; i < numQueries1; i++) { + sampleIds1.push_back(UUID::gen()); + writer.addFindQuery(sampleIds1[i], nss1, originalFilter, originalCollation).get(); + // This is a duplicate. + if (i < numQueries0) { + writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get(); + } + } + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries0 + numQueries1); + + // Force the batch to fail so that it falls back to inserting one document at a time in order. + auto failFp = globalFailPointRegistry().find("failCollectionInserts"); + failFp->setMode(FailPoint::nTimes, 1); + + // Hang after inserting the first non-duplicate document. + auto hangFp = globalFailPointRegistry().find("hangAfterCollectionInserts"); + auto hangTimesEntered = hangFp->setMode(FailPoint::alwaysOn, 0); + + auto future = stdx::async(stdx::launch::async, [&] { + ThreadClient tc(getServiceContext()); + auto opCtx = makeOperationContext(); + writer.flushQueriesForTest(opCtx.get()); + }); + + hangFp->waitForTimesEntered(hangTimesEntered + 1); + // Force the next non-duplicate document to fail to insert. + failFp->setMode(FailPoint::nTimes, 1); + hangFp->setMode(FailPoint::off, 0); + + future.get(); + // Verify that the duplicate documents did not get added back to the buffer after the error. + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries1); + ASSERT_EQ(getSampledQueryDocumentsCount(nss1), numQueries1 - 1); + + // Flush that remaining documents. If the documents were not added back correctly, the document + // that previously failed to insert would be missing and the checks below would fail. 
+ failFp->setMode(FailPoint::off, 0); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss1), numQueries1); + for (const auto& sampleId : sampleIds1) { + assertSampledReadQueryDocument( + sampleId, nss1, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation); + } +} + +TEST_F(QueryAnalysisWriterTest, NoDiffs) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + writer.flushQueriesForTest(operationContext()); +} + +TEST_F(QueryAnalysisWriterTest, DiffsBasic) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto collUuid0 = getCollectionUUID(nss0); + auto sampleId = UUID::gen(); + auto preImage = BSON("a" << 0); + auto postImage = BSON("a" << 1); + + writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 1); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 1); + assertDiffDocument(sampleId, nss0, *doc_diff::computeInlineDiff(preImage, postImage)); +} + +TEST_F(QueryAnalysisWriterTest, DiffsMultipleQueriesAndCollections) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + // Make nss0 have a diff for one query. + auto collUuid0 = getCollectionUUID(nss0); + + auto sampleId0 = UUID::gen(); + auto preImage0 = BSON("a" << 0 << "b" << 0 << "c" << 0); + auto postImage0 = BSON("a" << 0 << "b" << 1 << "d" << 1); + + // Make nss1 have diffs for two queries. 
+ auto collUuid1 = getCollectionUUID(nss1); + + auto sampleId1 = UUID::gen(); + auto preImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1) << "d" << BSON("e" << 1)); + auto postImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1 << 2) << "d" << BSON("e" << 2)); + + auto sampleId2 = UUID::gen(); + auto preImage2 = BSON("a" << BSONObj()); + auto postImage2 = BSON("a" << BSON("b" << 2)); + + writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get(); + writer.addDiff(sampleId1, nss1, collUuid1, preImage1, postImage1).get(); + writer.addDiff(sampleId2, nss1, collUuid1, preImage2, postImage2).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 3); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 1); + assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0)); + + ASSERT_EQ(getDiffDocumentsCount(nss1), 2); + assertDiffDocument(sampleId1, nss1, *doc_diff::computeInlineDiff(preImage1, postImage1)); + assertDiffDocument(sampleId2, nss1, *doc_diff::computeInlineDiff(preImage2, postImage2)); +} + +TEST_F(QueryAnalysisWriterTest, DuplicateDiffs) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto collUuid0 = getCollectionUUID(nss0); + + auto sampleId0 = UUID::gen(); + auto preImage0 = BSON("a" << 0); + auto postImage0 = BSON("a" << 1); + + auto sampleId1 = UUID::gen(); + auto preImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1)); + auto postImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1 << 2)); + + auto sampleId2 = UUID::gen(); + auto preImage2 = BSON("a" << BSONObj()); + auto postImage2 = BSON("a" << BSON("b" << 2)); + + writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 1); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 1); + assertDiffDocument(sampleId0, nss0, 
*doc_diff::computeInlineDiff(preImage0, postImage0)); + + writer.addDiff(sampleId1, nss0, collUuid0, preImage1, postImage1).get(); + writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0) + .get(); // This is a duplicate. + writer.addDiff(sampleId2, nss0, collUuid0, preImage2, postImage2).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 3); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 3); + assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0)); + assertDiffDocument(sampleId1, nss0, *doc_diff::computeInlineDiff(preImage1, postImage1)); + assertDiffDocument(sampleId2, nss0, *doc_diff::computeInlineDiff(preImage2, postImage2)); +} + +TEST_F(QueryAnalysisWriterTest, DiffsMultipleBatches_MaxBatchSize) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 2}; + auto numDiffs = 5; + auto collUuid0 = getCollectionUUID(nss0); + + std::vector<std::pair<UUID, BSONObj>> expectedSampledDiffs; + for (auto i = 0; i < numDiffs; i++) { + auto sampleId = UUID::gen(); + auto preImage = BSON("a" << 0); + auto postImage = BSON(("a" + std::to_string(i)) << 1); + writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get(); + expectedSampledDiffs.push_back( + {sampleId, *doc_diff::computeInlineDiff(preImage, postImage)}); + } + ASSERT_EQ(writer.getDiffsCountForTest(), numDiffs); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), numDiffs); + for (const auto& [sampleId, diff] : expectedSampledDiffs) { + assertDiffDocument(sampleId, nss0, diff); + } +} + +TEST_F(QueryAnalysisWriterTest, DiffsMultipleBatches_MaxBSONObjSize) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto numDiffs = 3; + auto collUuid0 = getCollectionUUID(nss0); + + 
std::vector<std::pair<UUID, BSONObj>> expectedSampledDiffs; + for (auto i = 0; i < numDiffs; i++) { + auto sampleId = UUID::gen(); + auto preImage = BSON("a" << 0); + auto postImage = BSON(std::string(BSONObjMaxUserSize / 2, 'a') << 1); + writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get(); + expectedSampledDiffs.push_back( + {sampleId, *doc_diff::computeInlineDiff(preImage, postImage)}); + } + ASSERT_EQ(writer.getDiffsCountForTest(), numDiffs); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), numDiffs); + for (const auto& [sampleId, diff] : expectedSampledDiffs) { + assertDiffDocument(sampleId, nss0, diff); + } +} + +TEST_F(QueryAnalysisWriterTest, FlushAfterAddDiffIfExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto maxMemoryUsageBytes = 1024; + RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes", + maxMemoryUsageBytes}; + + auto collUuid0 = getCollectionUUID(nss0); + auto sampleId0 = UUID::gen(); + auto preImage0 = BSON("a" << 0); + auto postImage0 = BSON(std::string(maxMemoryUsageBytes / 2, 'a') << 1); + + auto collUuid1 = getCollectionUUID(nss1); + auto sampleId1 = UUID::gen(); + auto preImage1 = BSON("a" << 0); + auto postImage1 = BSON(std::string(maxMemoryUsageBytes / 2, 'b') << 1); + + writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 1); + // Adding the next diff causes the size to exceed the limit. 
+ writer.addDiff(sampleId1, nss1, collUuid1, preImage1, postImage1).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 1); + assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0)); + ASSERT_EQ(getDiffDocumentsCount(nss1), 1); + assertDiffDocument(sampleId1, nss1, *doc_diff::computeInlineDiff(preImage1, postImage1)); +} + +TEST_F(QueryAnalysisWriterTest, DiffEmpty) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto collUuid0 = getCollectionUUID(nss0); + auto sampleId = UUID::gen(); + auto preImage = BSON("a" << 1); + auto postImage = preImage; + + writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 0); +} + +TEST_F(QueryAnalysisWriterTest, DiffExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto collUuid0 = getCollectionUUID(nss0); + auto sampleId = UUID::gen(); + auto preImage = BSON(std::string(BSONObjMaxUserSize, 'a') << 1); + auto postImage = BSONObj(); + + writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 0); +} + +} // namespace +} // namespace analyze_shard_key +} // namespace mongo diff --git a/src/mongo/db/s/range_deleter_service.cpp b/src/mongo/db/s/range_deleter_service.cpp index acabcc40bea..ca80b24aa3e 100644 --- a/src/mongo/db/s/range_deleter_service.cpp +++ b/src/mongo/db/s/range_deleter_service.cpp @@ -86,6 +86,50 @@ RangeDeleterService* RangeDeleterService::get(OperationContext* opCtx) { return get(opCtx->getServiceContext()); } +RangeDeleterService::ReadyRangeDeletionsProcessor::ReadyRangeDeletionsProcessor( 
+ OperationContext* opCtx) + : _thread([this] { _runRangeDeletions(); }) {} + +RangeDeleterService::ReadyRangeDeletionsProcessor::~ReadyRangeDeletionsProcessor() { + shutdown(); + invariant(_thread.joinable()); + _thread.join(); + invariant(!_threadOpCtxHolder, + "Thread operation context is still alive after joining main thread"); +} + +void RangeDeleterService::ReadyRangeDeletionsProcessor::shutdown() { + stdx::lock_guard<Latch> lock(_mutex); + if (_state == kStopped) + return; + + _state = kStopped; + + if (_threadOpCtxHolder) { + stdx::lock_guard<Client> scopedClientLock(*_threadOpCtxHolder->getClient()); + _threadOpCtxHolder->markKilled(ErrorCodes::Interrupted); + } +} + +bool RangeDeleterService::ReadyRangeDeletionsProcessor::_stopRequested() const { + stdx::unique_lock<Latch> lock(_mutex); + return _state == kStopped; +} + +void RangeDeleterService::ReadyRangeDeletionsProcessor::emplaceRangeDeletion( + const RangeDeletionTask& rdt) { + stdx::unique_lock<Latch> lock(_mutex); + invariant(_state == kRunning); + _queue.push(rdt); + _condVar.notify_all(); +} + +void RangeDeleterService::ReadyRangeDeletionsProcessor::_completedRangeDeletion() { + stdx::unique_lock<Latch> lock(_mutex); + dassert(!_queue.empty()); + _queue.pop(); +} + void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() { Client::initThread(kRangeDeletionThreadName); { @@ -93,21 +137,22 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() { cc().setSystemOperationKillableByStepdown(lk); } - auto opCtx = [&] { - stdx::unique_lock<Latch> lock(_mutex); + { + stdx::lock_guard<Latch> lock(_mutex); + if (_state != kRunning) { + return; + } _threadOpCtxHolder = cc().makeOperationContext(); - _condVar.notify_all(); - return (*_threadOpCtxHolder).get(); - }(); + } - opCtx->setAlwaysInterruptAtStepDownOrUp_UNSAFE(); + auto opCtx = _threadOpCtxHolder.get(); ON_BLOCK_EXIT([this]() { - stdx::unique_lock<Latch> lock(_mutex); + stdx::lock_guard<Latch> 
lock(_mutex); _threadOpCtxHolder.reset(); }); - while (opCtx->checkForInterruptNoAssert().isOK()) { + while (!_stopRequested()) { { stdx::unique_lock<Latch> lock(_mutex); try { @@ -223,7 +268,7 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() { // Release the thread only in case the operation context has been interrupted, as // interruption only happens on shutdown/stepdown (this is fine because range // deletions will be resumed on the next step up) - if (!opCtx->checkForInterruptNoAssert().isOK()) { + if (_stopRequested()) { break; } @@ -237,6 +282,17 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() { } } +void RangeDeleterService::onStartup(OperationContext* opCtx) { + if (disableResumableRangeDeleter.load() || + !feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) { + return; + } + + auto opObserverRegistry = + checked_cast<OpObserverRegistry*>(opCtx->getServiceContext()->getOpObserver()); + opObserverRegistry->addObserver(std::make_unique<RangeDeleterServiceOpObserver>()); +} + void RangeDeleterService::onStepUpComplete(OperationContext* opCtx, long long term) { if (!feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) { return; @@ -249,22 +305,14 @@ void RangeDeleterService::onStepUpComplete(OperationContext* opCtx, long long te return; } + // Wait until all tasks and thread from previous term drain + _joinAndResetState(); + auto lock = _acquireMutexUnconditionally(); dassert(_state == kDown, "Service expected to be down before stepping up"); _state = kInitializing; - if (_executor) { - // Join previously shutted down executor before reinstantiating it - _executor->join(); - _executor.reset(); - } else { - // Initializing the op observer, only executed once at the first step-up - auto opObserverRegistry = - checked_cast<OpObserverRegistry*>(opCtx->getServiceContext()->getOpObserver()); - opObserverRegistry->addObserver(std::make_unique<RangeDeleterServiceOpObserver>()); - } - const 
std::string kExecName("RangeDeleterServiceExecutor"); auto net = executor::makeNetworkInterface(kExecName); auto pool = std::make_unique<executor::NetworkInterfaceThreadPool>(net.get()); @@ -303,9 +351,14 @@ void RangeDeleterService::_recoverRangeDeletionsOnStepUp(OperationContext* opCtx LOGV2(6834800, "Resubmitting range deletion tasks"); - ScopedRangeDeleterLock rangeDeleterLock(opCtx); - DBDirectClient client(opCtx); + // The Scoped lock is needed to serialize with concurrent range deletions + ScopedRangeDeleterLock rangeDeleterLock(opCtx, MODE_S); + // The collection lock is needed to serialize with donors trying to + // schedule local range deletions by updating the 'pending' field + AutoGetCollection rangeDeletionLock( + opCtx, NamespaceString::kRangeDeletionNamespace, MODE_S); + DBDirectClient client(opCtx); int nRescheduledTasks = 0; // (1) register range deletion tasks marked as "processing" @@ -370,47 +423,55 @@ void RangeDeleterService::_recoverRangeDeletionsOnStepUp(OperationContext* opCtx .semi(); } -void RangeDeleterService::_stopService(bool joinExecutor) { - if (!feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) { - return; - } +void RangeDeleterService::_joinAndResetState() { + invariant(_state == kDown); + // Join the thread spawned on step-up to resume range deletions + _stepUpCompletedFuture.getNoThrow().ignore(); - { - auto lock = _acquireMutexUnconditionally(); - _state = kDown; - if (_initOpCtxHolder) { - stdx::lock_guard<Client> lk(*_initOpCtxHolder->getClient()); - _initOpCtxHolder->markKilled(ErrorCodes::Interrupted); - } + // Join and destruct the executor + if (_executor) { + _executor->join(); + _executor.reset(); } - // Join the thread spawned on step-up to resume range deletions - _stepUpCompletedFuture.getNoThrow().ignore(); + // Join and destruct the processor + _readyRangeDeletionsProcessorPtr.reset(); + + // Clear range deletions potentially created during recovery + _rangeDeletionTasks.clear(); +} +void 
RangeDeleterService::_stopService() { auto lock = _acquireMutexUnconditionally(); + if (_state == kDown) + return; + + _state = kDown; + if (_initOpCtxHolder) { + stdx::lock_guard<Client> lk(*_initOpCtxHolder->getClient()); + _initOpCtxHolder->markKilled(ErrorCodes::Interrupted); + } - // It may happen for the `onStepDown` hook to be invoked on a SECONDARY node transitioning - // to ROLLBACK, hence the executor may have never been initialized if (_executor) { _executor->shutdown(); - if (joinExecutor) { - _executor->join(); - } } - // Destroy the range deletion processor in order to stop range deletions - _readyRangeDeletionsProcessorPtr.reset(); + // Shutdown the range deletion processor to interrupt range deletions + if (_readyRangeDeletionsProcessorPtr) { + _readyRangeDeletionsProcessorPtr->shutdown(); + } // Clear range deletion tasks map in order to notify potential waiters on completion futures _rangeDeletionTasks.clear(); } void RangeDeleterService::onStepDown() { - _stopService(false /* joinExecutor */); + _stopService(); } void RangeDeleterService::onShutdown() { - _stopService(true /* joinExecutor */); + _stopService(); + _joinAndResetState(); } BSONObj RangeDeleterService::dumpState() { @@ -472,10 +533,9 @@ SharedSemiFuture<void> RangeDeleterService::registerTask( .then([this, rdt = rdt]() { // Step 3: schedule the actual range deletion task auto lock = _acquireMutexUnconditionally(); - invariant( - _readyRangeDeletionsProcessorPtr || _state == kDown, - "The range deletions processor must be instantiated if the state != kDown"); if (_state != kDown) { + invariant(_readyRangeDeletionsProcessorPtr, + "The range deletions processor is not initialized"); _readyRangeDeletionsProcessorPtr->emplaceRangeDeletion(rdt); } }); diff --git a/src/mongo/db/s/range_deleter_service.h b/src/mongo/db/s/range_deleter_service.h index 68cca97454c..c816c8b9db2 100644 --- a/src/mongo/db/s/range_deleter_service.h +++ b/src/mongo/db/s/range_deleter_service.h @@ -110,58 +110,37 @@ 
private: */ class ReadyRangeDeletionsProcessor { public: - ReadyRangeDeletionsProcessor(OperationContext* opCtx) { - _thread = stdx::thread([this] { _runRangeDeletions(); }); - stdx::unique_lock<Latch> lock(_mutex); - opCtx->waitForConditionOrInterrupt( - _condVar, lock, [&] { return _threadOpCtxHolder.is_initialized(); }); - } - - ~ReadyRangeDeletionsProcessor() { - { - stdx::unique_lock<Latch> lock(_mutex); - // The `_threadOpCtxHolder` may have been already reset/interrupted in case the - // thread got interrupted due to stepdown - if (_threadOpCtxHolder) { - stdx::lock_guard<Client> scopedClientLock(*(*_threadOpCtxHolder)->getClient()); - if ((*_threadOpCtxHolder)->checkForInterruptNoAssert().isOK()) { - (*_threadOpCtxHolder)->markKilled(ErrorCodes::Interrupted); - } - } - _condVar.notify_all(); - } + ReadyRangeDeletionsProcessor(OperationContext* opCtx); + ~ReadyRangeDeletionsProcessor(); - if (_thread.joinable()) { - _thread.join(); - } - } + /* + * Interrupt ongoing range deletions + */ + void shutdown(); /* * Schedule a range deletion at the end of the queue */ - void emplaceRangeDeletion(const RangeDeletionTask& rdt) { - stdx::unique_lock<Latch> lock(_mutex); - _queue.push(rdt); - _condVar.notify_all(); - } + void emplaceRangeDeletion(const RangeDeletionTask& rdt); private: /* + * Return true if this processor have been shutted down + */ + bool _stopRequested() const; + + /* * Remove a range deletion from the head of the queue. Supposed to be called only once a * range deletion successfully finishes. 
*/ - void _completedRangeDeletion() { - stdx::unique_lock<Latch> lock(_mutex); - dassert(!_queue.empty()); - _queue.pop(); - } + void _completedRangeDeletion(); /* * Code executed by the internal thread */ void _runRangeDeletions(); - Mutex _mutex = MONGO_MAKE_LATCH("ReadyRangeDeletionsProcessor"); + mutable Mutex _mutex = MONGO_MAKE_LATCH("ReadyRangeDeletionsProcessor"); /* * Condition variable notified when: @@ -175,10 +154,13 @@ private: std::queue<RangeDeletionTask> _queue; /* Pointer to the (one and only) operation context used by the thread */ - boost::optional<ServiceContext::UniqueOperationContext> _threadOpCtxHolder; + ServiceContext::UniqueOperationContext _threadOpCtxHolder; /* Thread consuming the range deletions queue */ stdx::thread _thread; + + enum State { kRunning, kStopped }; + State _state{kRunning}; }; // Keeping track of per-collection registered range deletion tasks @@ -253,6 +235,7 @@ public: const ChunkRange& range); /* ReplicaSetAwareServiceShardSvr implemented methods */ + void onStartup(OperationContext* opCtx) override; void onStepUpComplete(OperationContext* opCtx, long long term) override; void onStepDown() override; void onShutdown() override; @@ -276,17 +259,21 @@ public: std::unique_ptr<ReadyRangeDeletionsProcessor> _readyRangeDeletionsProcessorPtr; private: + /* Join all threads and executor and reset the in memory state of the service + * Used for onStartUpBegin and on onShutdown + */ + void _joinAndResetState(); + /* Asynchronously register range deletions on the service. 
To be called on on step-up */ void _recoverRangeDeletionsOnStepUp(OperationContext* opCtx); - /* Called by shutdown/stepdown hooks to reset the service */ - void _stopService(bool joinExecutor); + /* Called by shutdown/stepdown hooks to interrupt the service */ + void _stopService(); /* ReplicaSetAwareServiceShardSvr "empty implemented" methods */ - void onStartup(OperationContext* opCtx) override final{}; void onInitialDataAvailable(OperationContext* opCtx, bool isMajorityDataAvailable) override final {} - void onStepUpBegin(OperationContext* opCtx, long long term) override final {} + void onStepUpBegin(OperationContext* opCtx, long long term) override final{}; void onBecomeArbiter() override final {} }; diff --git a/src/mongo/db/s/range_deleter_service_op_observer.cpp b/src/mongo/db/s/range_deleter_service_op_observer.cpp index 9c1e51f9a6f..10bbbb6a924 100644 --- a/src/mongo/db/s/range_deleter_service_op_observer.cpp +++ b/src/mongo/db/s/range_deleter_service_op_observer.cpp @@ -35,6 +35,9 @@ #include "mongo/db/s/range_deleter_service.h" #include "mongo/db/s/range_deletion_task_gen.h" #include "mongo/db/update/update_oplog_entry_serialization.h" +#include "mongo/logv2/log.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kShardingRangeDeleter namespace mongo { namespace { @@ -54,10 +57,15 @@ void registerTaskWithOngoingQueriesOnOpLogEntryCommit(OperationContext* opCtx, (void)RangeDeleterService::get(opCtx)->registerTask( rdt, std::move(waitForActiveQueriesToComplete)); } catch (const DBException& ex) { - dassert(ex.code() == ErrorCodes::NotYetInitialized, - str::stream() << "No error different from `NotYetInitialized` is expected " - "to be propagated to the range deleter observer. 
Got error: " - << ex.toStatus()); + if (ex.code() != ErrorCodes::NotYetInitialized && + !ErrorCodes::isA<ErrorCategory::NotPrimaryError>(ex.code())) { + LOGV2_WARNING(7092800, + "No error different from `NotYetInitialized` or `NotPrimaryError` " + "category is expected to be propagated to the range deleter " + "observer. Range deletion task not registered.", + "error"_attr = redact(ex), + "task"_attr = rdt); + } } }); } diff --git a/src/mongo/db/s/range_deleter_service_test.cpp b/src/mongo/db/s/range_deleter_service_test.cpp index d2406120483..0807867da2d 100644 --- a/src/mongo/db/s/range_deleter_service_test.cpp +++ b/src/mongo/db/s/range_deleter_service_test.cpp @@ -45,6 +45,7 @@ void RangeDeleterServiceTest::setUp() { ShardServerTestFixture::setUp(); WaitForMajorityService::get(getServiceContext()).startup(getServiceContext()); opCtx = operationContext(); + RangeDeleterService::get(opCtx)->onStartup(opCtx); RangeDeleterService::get(opCtx)->onStepUpComplete(opCtx, 0L); RangeDeleterService::get(opCtx)->_waitForRangeDeleterServiceUp_FOR_TESTING(); @@ -275,12 +276,6 @@ TEST_F(RangeDeleterServiceTest, ScheduledTaskInvalidatedOnStepDown) { // Manually trigger disabling of the service rds->onStepDown(); - ON_BLOCK_EXIT([&] { - // Re-enable the service for clean teardown - rds->onStepUpComplete(opCtx, 0L); - rds->_waitForRangeDeleterServiceUp_FOR_TESTING(); - }); - try { completionFuture.get(opCtx); } catch (const ExceptionForCat<ErrorCategory::Interruption>&) { @@ -293,12 +288,6 @@ TEST_F(RangeDeleterServiceTest, NoActionPossibleIfServiceIsDown) { // Manually trigger disabling of the service rds->onStepDown(); - ON_BLOCK_EXIT([&] { - // Re-enable the service for clean teardown - rds->onStepUpComplete(opCtx, 0L); - rds->_waitForRangeDeleterServiceUp_FOR_TESTING(); - }); - auto taskWithOngoingQueries = createRangeDeletionTaskWithOngoingQueries( uuidCollA, BSON(kShardKey << 0), BSON(kShardKey << 10), CleanWhenEnum::kDelayed); @@ -886,10 +875,6 @@ 
TEST_F(RangeDeleterServiceTest, WaitForOngoingQueriesInvalidatedOnStepDown) { // Manually trigger disabling of the service rds->onStepDown(); - ON_BLOCK_EXIT([&] { - rds->onStepUpComplete(opCtx, 0L); // Re-enable the service - }); - try { completionFuture.get(opCtx); } catch (const ExceptionForCat<ErrorCategory::Interruption>&) { diff --git a/src/mongo/db/s/range_deletion_util.cpp b/src/mongo/db/s/range_deletion_util.cpp index 3a8b512ddba..707516db64f 100644 --- a/src/mongo/db/s/range_deletion_util.cpp +++ b/src/mongo/db/s/range_deletion_util.cpp @@ -322,7 +322,7 @@ Status deleteRangeInBatches(OperationContext* opCtx, const ChunkRange& range) { suspendRangeDeletion.pauseWhileSet(opCtx); - SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow); + SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow); bool allDocsRemoved = false; // Delete all batches in this range unless a stepdown error occurs. Do not yield the @@ -584,7 +584,7 @@ void persistUpdatedNumOrphans(OperationContext* opCtx, const auto query = getQueryFilterForRangeDeletionTask(collectionUuid, range); try { PersistentTaskStore<RangeDeletionTask> store(NamespaceString::kRangeDeletionNamespace); - ScopedRangeDeleterLock rangeDeleterLock(opCtx, collectionUuid); + ScopedRangeDeleterLock rangeDeleterLock(opCtx, LockMode::MODE_IX); // The DBDirectClient will not retry WriteConflictExceptions internally while holding an X // mode lock, so we need to retry at this level. 
writeConflictRetry( diff --git a/src/mongo/db/s/shard_server_op_observer.cpp b/src/mongo/db/s/shard_server_op_observer.cpp index 744c79a1eca..ae059d19512 100644 --- a/src/mongo/db/s/shard_server_op_observer.cpp +++ b/src/mongo/db/s/shard_server_op_observer.cpp @@ -509,20 +509,24 @@ void ShardServerOpObserver::onModifyShardedCollectionGlobalIndexCatalogEntry( IDLParserContext("onModifyShardedCollectionGlobalIndexCatalogEntry"), indexDoc["entry"].Obj()); auto indexVersion = indexDoc["entry"][IndexCatalogType::kLastmodFieldName].timestamp(); - opCtx->recoveryUnit()->onCommit([opCtx, nss, indexVersion, indexEntry](auto _) { + auto uuid = uassertStatusOK( + UUID::parse(indexDoc["entry"][IndexCatalogType::kCollectionUUIDFieldName])); + opCtx->recoveryUnit()->onCommit([opCtx, nss, indexVersion, indexEntry, uuid](auto _) { AutoGetCollection autoColl(opCtx, nss, MODE_IX); CollectionShardingRuntime::assertCollectionLockedAndAcquire( opCtx, nss, CSRAcquisitionMode::kExclusive) - ->addIndex(opCtx, indexEntry, indexVersion); + ->addIndex(opCtx, indexEntry, {uuid, indexVersion}); }); } else { auto indexName = indexDoc["entry"][IndexCatalogType::kNameFieldName].str(); auto indexVersion = indexDoc["entry"][IndexCatalogType::kLastmodFieldName].timestamp(); - opCtx->recoveryUnit()->onCommit([opCtx, nss, indexName, indexVersion](auto _) { + auto uuid = uassertStatusOK( + UUID::parse(indexDoc["entry"][IndexCatalogType::kCollectionUUIDFieldName])); + opCtx->recoveryUnit()->onCommit([opCtx, nss, indexName, indexVersion, uuid](auto _) { AutoGetCollection autoColl(opCtx, nss, MODE_IX); CollectionShardingRuntime::assertCollectionLockedAndAcquire( opCtx, nss, CSRAcquisitionMode::kExclusive) - ->removeIndex(opCtx, indexName, indexVersion); + ->removeIndex(opCtx, indexName, {uuid, indexVersion}); }); } } diff --git a/src/mongo/db/s/sharding_recovery_service.cpp b/src/mongo/db/s/sharding_recovery_service.cpp index d47e4056534..02dab8f3773 100644 --- 
a/src/mongo/db/s/sharding_recovery_service.cpp +++ b/src/mongo/db/s/sharding_recovery_service.cpp @@ -532,7 +532,7 @@ void ShardingRecoveryService::recoverIndexesCatalog(OperationContext* opCtx) { AutoGetCollection collLock(opCtx, nss, MODE_X); CollectionShardingRuntime::assertCollectionLockedAndAcquire( opCtx, collLock->ns(), CSRAcquisitionMode::kExclusive) - ->addIndex(opCtx, indexEntry, indexVersion); + ->addIndex(opCtx, indexEntry, {indexEntry.getCollectionUUID(), indexVersion}); } } LOGV2_DEBUG(6686502, 2, "Recovered all index versions"); diff --git a/src/mongo/db/s/sharding_write_router_bm.cpp b/src/mongo/db/s/sharding_write_router_bm.cpp index e92d8292052..89fa3ded04d 100644 --- a/src/mongo/db/s/sharding_write_router_bm.cpp +++ b/src/mongo/db/s/sharding_write_router_bm.cpp @@ -38,6 +38,7 @@ #include "mongo/db/s/collection_metadata.h" #include "mongo/db/s/collection_sharding_runtime.h" #include "mongo/db/s/collection_sharding_state_factory_shard.h" +#include "mongo/db/s/collection_sharding_state_factory_standalone.h" #include "mongo/db/s/operation_sharding_state.h" #include "mongo/db/s/sharding_state.h" #include "mongo/db/s/sharding_write_router.h" @@ -214,6 +215,10 @@ void BM_UnshardedDestinedRecipient(benchmark::State& state) { serviceContext->registerClientObserver(std::make_unique<LockerNoopClientObserver>()); const auto opCtx = client->makeOperationContext(); + CollectionShardingStateFactory::set( + opCtx->getServiceContext(), + std::make_unique<CollectionShardingStateFactoryStandalone>(opCtx->getServiceContext())); + const auto catalogCache = CatalogCacheMock::make(); for (auto keepRunning : state) { diff --git a/src/mongo/db/s/transaction_coordinator_util.cpp b/src/mongo/db/s/transaction_coordinator_util.cpp index 736e5fb1cf4..9bd3114a84f 100644 --- a/src/mongo/db/s/transaction_coordinator_util.cpp +++ b/src/mongo/db/s/transaction_coordinator_util.cpp @@ -445,7 +445,7 @@ Future<repl::OpTime> persistDecision(txn::AsyncWorkScheduler& scheduler, // Do not 
acquire a storage ticket in order to avoid unnecessary serialization // with other prepared transactions that are holding a storage ticket // themselves; see SERVER-60682. - SetTicketAquisitionPriorityForLock setTicketAquisition( + SetAdmissionPriorityForLock setTicketAquisition( opCtx, AdmissionContext::Priority::kImmediate); getTransactionCoordinatorWorkerCurOpRepository()->set( opCtx, lsid, txnNumberAndRetryCounter, CoordinatorAction::kWritingDecision); |