summaryrefslogtreecommitdiff
path: root/src/mongo/db/s
diff options
context:
space:
mode:
author: Jennifer Peshansky <jennifer.peshansky@mongodb.com> 2022-11-03 16:13:20 +0000
committer: Jennifer Peshansky <jennifer.peshansky@mongodb.com> 2022-11-03 16:13:20 +0000
commit: e74d2910bbe76790ad131d53fee277829cd95982 (patch)
tree: cabe148764529c9623652374fbc36323a550cd44 /src/mongo/db/s
parent: 280145e9940729480bb8a35453d4056afac87641 (diff)
parent: ba467f46cc1bc49965e1d72b541eff0cf1d7b22e (diff)
download: mongo-jenniferpeshansky/SERVER-70854.tar.gz
Merge branch 'master' into jenniferpeshansky/SERVER-70854 (branch: jenniferpeshansky/SERVER-70854)
Diffstat (limited to 'src/mongo/db/s')
-rw-r--r--src/mongo/db/s/SConscript37
-rw-r--r--src/mongo/db/s/balancer/balancer.cpp11
-rw-r--r--src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp19
-rw-r--r--src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h7
-rw-r--r--src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp6
-rw-r--r--src/mongo/db/s/balancer/balancer_policy.cpp31
-rw-r--r--src/mongo/db/s/balancer/balancer_policy.h5
-rw-r--r--src/mongo/db/s/balancer_stats_registry.cpp28
-rw-r--r--src/mongo/db/s/balancer_stats_registry.h15
-rw-r--r--src/mongo/db/s/cluster_count_cmd_d.cpp5
-rw-r--r--src/mongo/db/s/cluster_find_cmd_d.cpp5
-rw-r--r--src/mongo/db/s/cluster_pipeline_cmd_d.cpp5
-rw-r--r--src/mongo/db/s/cluster_write_cmd_d.cpp14
-rw-r--r--src/mongo/db/s/collection_sharding_runtime.cpp16
-rw-r--r--src/mongo/db/s/collection_sharding_runtime.h8
-rw-r--r--src/mongo/db/s/collection_sharding_runtime_test.cpp9
-rw-r--r--src/mongo/db/s/config/sharding_catalog_manager.h3
-rw-r--r--src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp28
-rw-r--r--src/mongo/db/s/query_analysis_op_observer.cpp12
-rw-r--r--src/mongo/db/s/query_analysis_writer.cpp695
-rw-r--r--src/mongo/db/s/query_analysis_writer.h214
-rw-r--r--src/mongo/db/s/query_analysis_writer_test.cpp1281
-rw-r--r--src/mongo/db/s/range_deleter_service.cpp154
-rw-r--r--src/mongo/db/s/range_deleter_service.h67
-rw-r--r--src/mongo/db/s/range_deleter_service_op_observer.cpp16
-rw-r--r--src/mongo/db/s/range_deleter_service_test.cpp17
-rw-r--r--src/mongo/db/s/range_deletion_util.cpp4
-rw-r--r--src/mongo/db/s/shard_server_op_observer.cpp12
-rw-r--r--src/mongo/db/s/sharding_recovery_service.cpp2
-rw-r--r--src/mongo/db/s/sharding_write_router_bm.cpp5
-rw-r--r--src/mongo/db/s/transaction_coordinator_util.cpp2
31 files changed, 2543 insertions, 190 deletions
diff --git a/src/mongo/db/s/SConscript b/src/mongo/db/s/SConscript
index fa54218910d..3248c1b296d 100644
--- a/src/mongo/db/s/SConscript
+++ b/src/mongo/db/s/SConscript
@@ -11,7 +11,6 @@ env = env.Clone()
env.Library(
target='sharding_api_d',
source=[
- 'balancer_stats_registry.cpp',
'collection_metadata.cpp',
'collection_sharding_state_factory_standalone.cpp',
'collection_sharding_state.cpp',
@@ -36,14 +35,25 @@ env.Library(
LIBDEPS_PRIVATE=[
'$BUILD_DIR/mongo/db/catalog/index_catalog',
'$BUILD_DIR/mongo/db/concurrency/lock_manager',
- '$BUILD_DIR/mongo/db/dbdirectclient',
- '$BUILD_DIR/mongo/db/repl/replica_set_aware_service',
'$BUILD_DIR/mongo/db/server_base',
'$BUILD_DIR/mongo/db/write_block_bypass',
],
)
env.Library(
+ target='balancer_stats_registry',
+ source=[
+ 'balancer_stats_registry.cpp',
+ ],
+ LIBDEPS_PRIVATE=[
+ '$BUILD_DIR/mongo/db/dbdirectclient',
+ '$BUILD_DIR/mongo/db/repl/replica_set_aware_service',
+ '$BUILD_DIR/mongo/db/shard_role',
+ '$BUILD_DIR/mongo/s/grid',
+ ],
+)
+
+env.Library(
target='sharding_catalog',
source=[
'global_index_ddl_util.cpp',
@@ -64,6 +74,22 @@ env.Library(
)
env.Library(
+ target="query_analysis_writer",
+ source=[
+ "query_analysis_writer.cpp",
+ ],
+ LIBDEPS_PRIVATE=[
+ "$BUILD_DIR/mongo/db/dbdirectclient",
+ '$BUILD_DIR/mongo/db/ops/write_ops_parsers',
+ '$BUILD_DIR/mongo/db/server_base',
+ '$BUILD_DIR/mongo/db/service_context',
+ '$BUILD_DIR/mongo/db/shard_role',
+ '$BUILD_DIR/mongo/idl/idl_parser',
+ '$BUILD_DIR/mongo/s/analyze_shard_key_common',
+ ],
+)
+
+env.Library(
target='sharding_runtime_d',
source=[
'active_migrations_registry.cpp',
@@ -228,6 +254,8 @@ env.Library(
'$BUILD_DIR/mongo/db/transaction/transaction_operations',
'$BUILD_DIR/mongo/s/common_s',
'$BUILD_DIR/mongo/util/future_util',
+ 'balancer_stats_registry',
+ 'query_analysis_writer',
'sharding_catalog',
'sharding_logging',
],
@@ -550,6 +578,7 @@ env.Library(
'$BUILD_DIR/mongo/s/sharding_initialization',
'$BUILD_DIR/mongo/s/sharding_router_api',
'$BUILD_DIR/mongo/s/startup_initialization',
+ 'balancer_stats_registry',
'forwardable_operation_metadata',
'sharding_catalog',
'sharding_logging',
@@ -653,6 +682,7 @@ env.CppUnitTest(
'op_observer_sharding_test.cpp',
'operation_sharding_state_test.cpp',
'persistent_task_queue_test.cpp',
+ 'query_analysis_writer_test.cpp',
'range_deleter_service_test.cpp',
'range_deleter_service_test_util.cpp',
'range_deleter_service_op_observer_test.cpp',
@@ -734,6 +764,7 @@ env.CppUnitTest(
'$BUILD_DIR/mongo/executor/thread_pool_task_executor_test_fixture',
'$BUILD_DIR/mongo/s/catalog/sharding_catalog_client_mock',
'$BUILD_DIR/mongo/s/sharding_router_test_fixture',
+ 'query_analysis_writer',
'shard_server_test_fixture',
'sharding_catalog',
'sharding_commands_d',
diff --git a/src/mongo/db/s/balancer/balancer.cpp b/src/mongo/db/s/balancer/balancer.cpp
index f1c96cc97cf..9103ec7f04c 100644
--- a/src/mongo/db/s/balancer/balancer.cpp
+++ b/src/mongo/db/s/balancer/balancer.cpp
@@ -281,7 +281,7 @@ Balancer::Balancer()
std::make_unique<BalancerChunkSelectionPolicyImpl>(_clusterStats.get(), _random)),
_commandScheduler(std::make_unique<BalancerCommandsSchedulerImpl>()),
_defragmentationPolicy(std::make_unique<BalancerDefragmentationPolicyImpl>(
- _clusterStats.get(), _random, [this]() { _onActionsStreamPolicyStateUpdate(); })),
+ _clusterStats.get(), [this]() { _onActionsStreamPolicyStateUpdate(); })),
_clusterChunksResizePolicy(std::make_unique<ClusterChunksResizePolicyImpl>(
[this] { _onActionsStreamPolicyStateUpdate(); })) {}
@@ -1085,7 +1085,7 @@ int Balancer::_moveChunks(OperationContext* opCtx,
opCtx, migrateInfo.uuid, repl::ReadConcernLevel::kMajorityReadConcern);
ShardingCatalogManager::get(opCtx)->splitOrMarkJumbo(
- opCtx, collection.getNss(), migrateInfo.minKey);
+ opCtx, collection.getNss(), migrateInfo.minKey, migrateInfo.getMaxChunkSizeBytes());
continue;
}
@@ -1151,7 +1151,12 @@ BalancerCollectionStatusResponse Balancer::getBalancerStatusForNs(OperationConte
uasserted(ErrorCodes::NamespaceNotSharded, "Collection unsharded or undefined");
}
- const auto maxChunkSizeMB = getMaxChunkSizeMB(opCtx, coll);
+
+ const auto maxChunkSizeBytes = getMaxChunkSizeBytes(opCtx, coll);
+ double maxChunkSizeMB = (double)maxChunkSizeBytes / (1024 * 1024);
+ // Keep only 2 decimal digits to return a readable value
+ maxChunkSizeMB = std::ceil(maxChunkSizeMB * 100.0) / 100.0;
+
BalancerCollectionStatusResponse response(maxChunkSizeMB, true /*balancerCompliant*/);
auto setViolationOnResponse = [&response](const StringData& reason,
const boost::optional<BSONObj>& details =
diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp
index 1c5a88a6100..c8d51184bc3 100644
--- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp
+++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp
@@ -398,7 +398,8 @@ public:
std::move(collectionChunks),
std::move(shardInfos),
std::move(collectionZones),
- smallChunkSizeThresholdBytes));
+ smallChunkSizeThresholdBytes,
+ maxChunkSizeBytes));
}
DefragmentationPhaseEnum getType() const override {
@@ -462,7 +463,7 @@ public:
auto smallChunkVersion = getShardVersion(opCtx, nextSmallChunk->shard, _nss);
_outstandingMigrations.emplace_back(nextSmallChunk, targetSibling);
return _outstandingMigrations.back().asMigrateInfo(
- _uuid, _nss, smallChunkVersion.placementVersion());
+ _uuid, _nss, smallChunkVersion.placementVersion(), _maxChunkSizeBytes);
}
return boost::none;
@@ -701,7 +702,8 @@ private:
MigrateInfo asMigrateInfo(const UUID& collUuid,
const NamespaceString& nss,
- const ChunkVersion& version) const {
+ const ChunkVersion& version,
+ uint64_t maxChunkSizeBytes) const {
return MigrateInfo(chunkToMergeWith->shard,
chunkToMove->shard,
nss,
@@ -709,7 +711,8 @@ private:
chunkToMove->range.getMin(),
chunkToMove->range.getMax(),
version,
- ForceJumbo::kForceBalancer);
+ ForceJumbo::kForceBalancer,
+ maxChunkSizeBytes);
}
ChunkRange asMergedRange() const {
@@ -774,6 +777,8 @@ private:
const int64_t _smallChunkSizeThresholdBytes;
+ const uint64_t _maxChunkSizeBytes;
+
bool _aborted{false};
DefragmentationPhaseEnum _nextPhase{DefragmentationPhaseEnum::kMergeChunks};
@@ -783,7 +788,8 @@ private:
std::vector<ChunkType>&& collectionChunks,
stdx::unordered_map<ShardId, ShardInfo>&& shardInfos,
ZoneInfo&& collectionZones,
- uint64_t smallChunkSizeThresholdBytes)
+ uint64_t smallChunkSizeThresholdBytes,
+ uint64_t maxChunkSizeBytes)
: _nss(nss),
_uuid(uuid),
_collectionChunks(),
@@ -794,7 +800,8 @@ private:
_actionableMerges(),
_outstandingMerges(),
_zoneInfo(std::move(collectionZones)),
- _smallChunkSizeThresholdBytes(smallChunkSizeThresholdBytes) {
+ _smallChunkSizeThresholdBytes(smallChunkSizeThresholdBytes),
+ _maxChunkSizeBytes(maxChunkSizeBytes) {
// Load the collection routing table in a std::list to ease later manipulation
for (auto&& chunk : collectionChunks) {
diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h
index bc41346ca7f..bab4a6c58cf 100644
--- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h
+++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h
@@ -72,9 +72,10 @@ class BalancerDefragmentationPolicyImpl : public BalancerDefragmentationPolicy {
public:
BalancerDefragmentationPolicyImpl(ClusterStatistics* clusterStats,
- BalancerRandomSource& random,
const std::function<void()>& onStateUpdated)
- : _clusterStats(clusterStats), _random(random), _onStateUpdated(onStateUpdated) {}
+ : _clusterStats(clusterStats),
+ _random(std::random_device{}()),
+ _onStateUpdated(onStateUpdated) {}
~BalancerDefragmentationPolicyImpl() {}
@@ -145,7 +146,7 @@ private:
ClusterStatistics* const _clusterStats;
- BalancerRandomSource& _random;
+ BalancerRandomSource _random;
const std::function<void()> _onStateUpdated;
diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp b/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp
index 8646b5d965a..e9370bbb4d3 100644
--- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp
+++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp
@@ -29,7 +29,6 @@
#include "mongo/db/dbdirectclient.h"
#include "mongo/db/s/balancer/balancer_defragmentation_policy_impl.h"
-#include "mongo/db/s/balancer/balancer_random.h"
#include "mongo/db/s/balancer/cluster_statistics_mock.h"
#include "mongo/db/s/config/config_server_test_fixture.h"
#include "mongo/idl/server_parameter_test_util.h"
@@ -75,9 +74,7 @@ protected:
const std::function<void()> onDefragmentationStateUpdated = [] {};
BalancerDefragmentationPolicyTest()
- : _clusterStats(),
- _random(std::random_device{}()),
- _defragmentationPolicy(&_clusterStats, _random, onDefragmentationStateUpdated) {}
+ : _clusterStats(), _defragmentationPolicy(&_clusterStats, onDefragmentationStateUpdated) {}
CollectionType setupCollectionWithPhase(
const std::vector<ChunkType>& chunkList,
@@ -137,7 +134,6 @@ protected:
}
ClusterStatisticsMock _clusterStats;
- BalancerRandomSource _random;
BalancerDefragmentationPolicyImpl _defragmentationPolicy;
ShardStatistics buildShardStats(ShardId id,
diff --git a/src/mongo/db/s/balancer/balancer_policy.cpp b/src/mongo/db/s/balancer/balancer_policy.cpp
index 4e0ff4c3ccb..a8ad96ab2c7 100644
--- a/src/mongo/db/s/balancer/balancer_policy.cpp
+++ b/src/mongo/db/s/balancer/balancer_policy.cpp
@@ -36,6 +36,7 @@
#include "mongo/db/s/balancer/type_migration.h"
#include "mongo/logv2/log.h"
+#include "mongo/s/balancer_configuration.h"
#include "mongo/s/catalog/type_shard.h"
#include "mongo/s/catalog/type_tags.h"
#include "mongo/s/grid.h"
@@ -450,7 +451,16 @@ MigrateInfosWithReason BalancerPolicy::balance(
}
invariant(to != stat.shardId);
- migrations.emplace_back(to, distribution.nss(), chunk, ForceJumbo::kForceBalancer);
+
+ auto maxChunkSizeBytes = [&]() -> boost::optional<int64_t> {
+ if (collDataSizeInfo.has_value()) {
+ return collDataSizeInfo->maxChunkSizeBytes;
+ }
+ return boost::none;
+ }();
+
+ migrations.emplace_back(
+ to, distribution.nss(), chunk, ForceJumbo::kForceBalancer, maxChunkSizeBytes);
if (firstReason == MigrationReason::none) {
firstReason = MigrationReason::drain;
}
@@ -513,11 +523,20 @@ MigrateInfosWithReason BalancerPolicy::balance(
}
invariant(to != stat.shardId);
+
+ auto maxChunkSizeBytes = [&]() -> boost::optional<int64_t> {
+ if (collDataSizeInfo.has_value()) {
+ return collDataSizeInfo->maxChunkSizeBytes;
+ }
+ return boost::none;
+ }();
+
migrations.emplace_back(to,
distribution.nss(),
chunk,
forceJumbo ? ForceJumbo::kForceBalancer
- : ForceJumbo::kDoNotForce);
+ : ForceJumbo::kDoNotForce,
+ maxChunkSizeBytes);
if (firstReason == MigrationReason::none) {
firstReason = MigrationReason::zoneViolation;
}
@@ -796,7 +815,8 @@ string ZoneRange::toString() const {
MigrateInfo::MigrateInfo(const ShardId& a_to,
const NamespaceString& a_nss,
const ChunkType& a_chunk,
- const ForceJumbo a_forceJumbo)
+ const ForceJumbo a_forceJumbo,
+ boost::optional<int64_t> maxChunkSizeBytes)
: nss(a_nss), uuid(a_chunk.getCollectionUUID()) {
invariant(a_to.isValid());
@@ -807,6 +827,7 @@ MigrateInfo::MigrateInfo(const ShardId& a_to,
maxKey = a_chunk.getMax();
version = a_chunk.getVersion();
forceJumbo = a_forceJumbo;
+ optMaxChunkSizeBytes = maxChunkSizeBytes;
}
MigrateInfo::MigrateInfo(const ShardId& a_to,
@@ -858,6 +879,10 @@ string MigrateInfo::toString() const {
<< ", to " << to;
}
+boost::optional<int64_t> MigrateInfo::getMaxChunkSizeBytes() const {
+ return optMaxChunkSizeBytes;
+}
+
SplitInfo::SplitInfo(const ShardId& inShardId,
const NamespaceString& inNss,
const ChunkVersion& inCollectionVersion,
diff --git a/src/mongo/db/s/balancer/balancer_policy.h b/src/mongo/db/s/balancer/balancer_policy.h
index bd464ca9499..a9a1f307310 100644
--- a/src/mongo/db/s/balancer/balancer_policy.h
+++ b/src/mongo/db/s/balancer/balancer_policy.h
@@ -60,7 +60,8 @@ struct MigrateInfo {
MigrateInfo(const ShardId& a_to,
const NamespaceString& a_nss,
const ChunkType& a_chunk,
- ForceJumbo a_forceJumbo);
+ ForceJumbo a_forceJumbo,
+ boost::optional<int64_t> maxChunkSizeBytes = boost::none);
MigrateInfo(const ShardId& a_to,
const ShardId& a_from,
@@ -78,6 +79,8 @@ struct MigrateInfo {
std::string toString() const;
+ boost::optional<int64_t> getMaxChunkSizeBytes() const;
+
NamespaceString nss;
UUID uuid;
ShardId to;
diff --git a/src/mongo/db/s/balancer_stats_registry.cpp b/src/mongo/db/s/balancer_stats_registry.cpp
index bfd789acb52..0f11a578c23 100644
--- a/src/mongo/db/s/balancer_stats_registry.cpp
+++ b/src/mongo/db/s/balancer_stats_registry.cpp
@@ -29,6 +29,7 @@
#include "mongo/db/s/balancer_stats_registry.h"
+#include "mongo/db/catalog_raii.h"
#include "mongo/db/dbdirectclient.h"
#include "mongo/db/pipeline/aggregate_command_gen.h"
#include "mongo/db/repl/replication_coordinator.h"
@@ -58,23 +59,6 @@ ThreadPool::Options makeDefaultThreadPoolOptions() {
}
} // namespace
-ScopedRangeDeleterLock::ScopedRangeDeleterLock(OperationContext* opCtx)
- // TODO SERVER-62491 Use system tenantId for DBLock
- : _configLock(opCtx, DatabaseName(boost::none, NamespaceString::kConfigDb), MODE_IX),
- _rangeDeletionLock(opCtx, NamespaceString::kRangeDeletionNamespace, MODE_X) {}
-
-// Take DB and Collection lock in mode IX as well as collection UUID lock to serialize with
-// operations that take the above version of the ScopedRangeDeleterLock such as FCV downgrade and
-// BalancerStatsRegistry initialization.
-ScopedRangeDeleterLock::ScopedRangeDeleterLock(OperationContext* opCtx, const UUID& collectionUuid)
- // TODO SERVER-62491 Use system tenantId for DBLock
- : _configLock(opCtx, DatabaseName(boost::none, NamespaceString::kConfigDb), MODE_IX),
- _rangeDeletionLock(opCtx, NamespaceString::kRangeDeletionNamespace, MODE_IX),
- _collectionUuidLock(Lock::ResourceLock(
- opCtx,
- ResourceId(RESOURCE_MUTEX, "RangeDeleterCollLock::" + collectionUuid.toString()),
- MODE_X)) {}
-
const ReplicaSetAwareServiceRegistry::Registerer<BalancerStatsRegistry>
balancerStatsRegistryRegisterer("BalancerStatsRegistry");
@@ -131,9 +115,13 @@ void BalancerStatsRegistry::initializeAsync(OperationContext* opCtx) {
LOGV2_DEBUG(6419601, 2, "Initializing BalancerStatsRegistry");
try {
- // Lock the range deleter to prevent
- // concurrent modifications of orphans count
- ScopedRangeDeleterLock rangeDeleterLock(opCtx);
+ // Lock the range deleter to prevent concurrent modifications of orphans count
+ ScopedRangeDeleterLock rangeDeleterLock(opCtx, LockMode::MODE_S);
+ // The collection lock is needed to serialize with direct writes to
+ // config.rangeDeletions
+ AutoGetCollection rangeDeletionLock(
+ opCtx, NamespaceString::kRangeDeletionNamespace, MODE_S);
+
// Load current orphans count from disk
_loadOrphansCount(opCtx);
LOGV2_DEBUG(6419602, 2, "Completed BalancerStatsRegistry initialization");
diff --git a/src/mongo/db/s/balancer_stats_registry.h b/src/mongo/db/s/balancer_stats_registry.h
index 473285b627c..fdc68ab1021 100644
--- a/src/mongo/db/s/balancer_stats_registry.h
+++ b/src/mongo/db/s/balancer_stats_registry.h
@@ -38,18 +38,19 @@
namespace mongo {
/**
- * Acquires the config db lock in IX mode and the collection lock for config.rangeDeletions in X
- * mode.
+ * Scoped lock to synchronize with the execution of range deletions.
+ * The range-deleter acquires a scoped lock in IX mode while orphans are being deleted.
+ * Acquiring the scoped lock in MODE_X ensures that no orphan counter in `config.rangeDeletions`
+ * entries is going to be updated concurrently.
*/
class ScopedRangeDeleterLock {
public:
- ScopedRangeDeleterLock(OperationContext* opCtx);
- ScopedRangeDeleterLock(OperationContext* opCtx, const UUID& collectionUuid);
+ ScopedRangeDeleterLock(OperationContext* opCtx, LockMode mode)
+ : _resourceLock(opCtx, _mutex.getRid(), mode) {}
private:
- Lock::DBLock _configLock;
- Lock::CollectionLock _rangeDeletionLock;
- boost::optional<Lock::ResourceLock> _collectionUuidLock;
+ const Lock::ResourceLock _resourceLock;
+ static inline const Lock::ResourceMutex _mutex{"ScopedRangeDeleterLock"};
};
/**
diff --git a/src/mongo/db/s/cluster_count_cmd_d.cpp b/src/mongo/db/s/cluster_count_cmd_d.cpp
index 27c64148cc9..a593a299b4f 100644
--- a/src/mongo/db/s/cluster_count_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_count_cmd_d.cpp
@@ -62,6 +62,11 @@ struct ClusterCountCmdD {
// which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster count command on a mongod");
+ }
};
ClusterCountCmdBase<ClusterCountCmdD> clusterCountCmdD;
diff --git a/src/mongo/db/s/cluster_find_cmd_d.cpp b/src/mongo/db/s/cluster_find_cmd_d.cpp
index e97f07b70e7..e98315c06ad 100644
--- a/src/mongo/db/s/cluster_find_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_find_cmd_d.cpp
@@ -61,6 +61,11 @@ struct ClusterFindCmdD {
// which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster find command on a mongod");
+ }
};
ClusterFindCmdBase<ClusterFindCmdD> clusterFindCmdD;
diff --git a/src/mongo/db/s/cluster_pipeline_cmd_d.cpp b/src/mongo/db/s/cluster_pipeline_cmd_d.cpp
index 437e7418355..1b3eed4242a 100644
--- a/src/mongo/db/s/cluster_pipeline_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_pipeline_cmd_d.cpp
@@ -61,6 +61,11 @@ struct ClusterPipelineCommandD {
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster aggregate command on a mongod");
+ }
+
static AggregateCommandRequest parseAggregationRequest(
OperationContext* opCtx,
const OpMsgRequest& opMsgRequest,
diff --git a/src/mongo/db/s/cluster_write_cmd_d.cpp b/src/mongo/db/s/cluster_write_cmd_d.cpp
index c8f75d69e15..66167ab2b30 100644
--- a/src/mongo/db/s/cluster_write_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_write_cmd_d.cpp
@@ -57,6 +57,11 @@ struct ClusterInsertCmdD {
// which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster insert command on a mongod");
+ }
};
ClusterInsertCmdBase<ClusterInsertCmdD> clusterInsertCmdD;
@@ -83,6 +88,10 @@ struct ClusterUpdateCmdD {
// which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+
+ // Explain is not supported for the cluster update command on a mongod: this hook
+ // unconditionally raises CommandNotSupported via uasserted().
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster update command on a mongod");
+ }
};
ClusterUpdateCmdBase<ClusterUpdateCmdD> clusterUpdateCmdD;
@@ -109,6 +118,11 @@ struct ClusterDeleteCmdD {
// which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster delete command on a mongod");
+ }
};
ClusterDeleteCmdBase<ClusterDeleteCmdD> clusterDeleteCmdD;
diff --git a/src/mongo/db/s/collection_sharding_runtime.cpp b/src/mongo/db/s/collection_sharding_runtime.cpp
index 89ceb0cfd4c..ee530f6ea04 100644
--- a/src/mongo/db/s/collection_sharding_runtime.cpp
+++ b/src/mongo/db/s/collection_sharding_runtime.cpp
@@ -486,8 +486,10 @@ void CollectionShardingRuntime::resetShardVersionRecoverRefreshFuture() {
_shardVersionInRecoverOrRefresh = boost::none;
}
-boost::optional<Timestamp> CollectionShardingRuntime::getIndexVersion(OperationContext* opCtx) {
- return _globalIndexesInfo ? _globalIndexesInfo->getVersion() : boost::none;
+boost::optional<CollectionIndexes> CollectionShardingRuntime::getCollectionIndexes(
+ OperationContext* opCtx) {
+ return _globalIndexesInfo ? boost::make_optional(_globalIndexesInfo->getCollectionIndexes())
+ : boost::none;
}
boost::optional<GlobalIndexesCache>& CollectionShardingRuntime::getIndexes(
@@ -497,22 +499,22 @@ boost::optional<GlobalIndexesCache>& CollectionShardingRuntime::getIndexes(
void CollectionShardingRuntime::addIndex(OperationContext* opCtx,
const IndexCatalogType& index,
- const Timestamp& indexVersion) {
+ const CollectionIndexes& collectionIndexes) {
if (_globalIndexesInfo) {
- _globalIndexesInfo->add(index, indexVersion);
+ _globalIndexesInfo->add(index, collectionIndexes);
} else {
IndexCatalogTypeMap indexMap;
indexMap.emplace(index.getName(), index);
- _globalIndexesInfo.emplace(indexVersion, std::move(indexMap));
+ _globalIndexesInfo.emplace(collectionIndexes, std::move(indexMap));
}
}
void CollectionShardingRuntime::removeIndex(OperationContext* opCtx,
const std::string& name,
- const Timestamp& indexVersion) {
+ const CollectionIndexes& collectionIndexes) {
tassert(
7019500, "Index information does not exist on CSR", _globalIndexesInfo.is_initialized());
- _globalIndexesInfo->remove(name, indexVersion);
+ _globalIndexesInfo->remove(name, collectionIndexes);
}
void CollectionShardingRuntime::clearIndexes(OperationContext* opCtx) {
diff --git a/src/mongo/db/s/collection_sharding_runtime.h b/src/mongo/db/s/collection_sharding_runtime.h
index 44e46d0bc4a..786425d8e92 100644
--- a/src/mongo/db/s/collection_sharding_runtime.h
+++ b/src/mongo/db/s/collection_sharding_runtime.h
@@ -229,7 +229,7 @@ public:
/**
* Gets the collection indexes under a lock.
*/
- boost::optional<Timestamp> getIndexVersion(OperationContext* opCtx);
+ boost::optional<CollectionIndexes> getCollectionIndexes(OperationContext* opCtx);
/**
* Gets the index list under a lock.
@@ -241,17 +241,17 @@ public:
*/
void addIndex(OperationContext* opCtx,
const IndexCatalogType& index,
- const Timestamp& indexVersion);
+ const CollectionIndexes& collectionIndexes);
/**
* Removes an index from the shard-role index info under a lock.
*/
void removeIndex(OperationContext* opCtx,
const std::string& name,
- const Timestamp& indexVersion);
+ const CollectionIndexes& collectionIndexes);
/**
- * Clears the shard-role index info and set the indexVersion to boost::none.
+ * Clears the shard-role index info and sets the collectionIndexes to boost::none.
*/
void clearIndexes(OperationContext* opCtx);
diff --git a/src/mongo/db/s/collection_sharding_runtime_test.cpp b/src/mongo/db/s/collection_sharding_runtime_test.cpp
index 6e792a8db0e..334225663e3 100644
--- a/src/mongo/db/s/collection_sharding_runtime_test.cpp
+++ b/src/mongo/db/s/collection_sharding_runtime_test.cpp
@@ -377,8 +377,10 @@ public:
AutoGetCollection autoColl(operationContext(), kTestNss, MODE_IX);
_uuid = autoColl.getCollection()->uuid();
- RangeDeleterService::get(operationContext())->onStepUpComplete(operationContext(), 0L);
- RangeDeleterService::get(operationContext())->_waitForRangeDeleterServiceUp_FOR_TESTING();
+ auto opCtx = operationContext();
+ RangeDeleterService::get(opCtx)->onStartup(opCtx);
+ RangeDeleterService::get(opCtx)->onStepUpComplete(opCtx, 0L);
+ RangeDeleterService::get(opCtx)->_waitForRangeDeleterServiceUp_FOR_TESTING();
}
void tearDown() override {
@@ -611,8 +613,7 @@ TEST_F(CollectionShardingRuntimeWithCatalogTest, TestGlobalIndexesCache) {
opCtx, kTestNss, "x_1", BSON("x" << 1), BSONObj(), uuid(), indexVersion, boost::none);
ASSERT_EQ(true, csr()->getIndexes(opCtx).is_initialized());
- ASSERT_EQ(indexVersion, *csr()->getIndexes(opCtx)->getVersion());
- ASSERT_EQ(indexVersion, *csr()->getIndexVersion(opCtx));
+ ASSERT_EQ(CollectionIndexes(uuid(), indexVersion), *csr()->getCollectionIndexes(opCtx));
}
} // namespace
} // namespace mongo
diff --git a/src/mongo/db/s/config/sharding_catalog_manager.h b/src/mongo/db/s/config/sharding_catalog_manager.h
index 4e8a64f59a6..fba732dce7a 100644
--- a/src/mongo/db/s/config/sharding_catalog_manager.h
+++ b/src/mongo/db/s/config/sharding_catalog_manager.h
@@ -359,7 +359,8 @@ public:
*/
void splitOrMarkJumbo(OperationContext* opCtx,
const NamespaceString& nss,
- const BSONObj& minKey);
+ const BSONObj& minKey,
+ boost::optional<int64_t> optMaxChunkSizeBytes);
/**
* In a transaction, sets the 'allowMigrations' to the requested state and bumps the collection
diff --git a/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp b/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp
index 2b0692acb5d..93ac3b56099 100644
--- a/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp
+++ b/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp
@@ -1775,19 +1775,31 @@ void ShardingCatalogManager::bumpMultipleCollectionVersionsAndChangeMetadataInTx
void ShardingCatalogManager::splitOrMarkJumbo(OperationContext* opCtx,
const NamespaceString& nss,
- const BSONObj& minKey) {
+ const BSONObj& minKey,
+ boost::optional<int64_t> optMaxChunkSizeBytes) {
const auto cm = uassertStatusOK(
Grid::get(opCtx)->catalogCache()->getShardedCollectionRoutingInfoWithRefresh(opCtx, nss));
auto chunk = cm.findIntersectingChunkWithSimpleCollation(minKey);
try {
- const auto splitPoints = uassertStatusOK(shardutil::selectChunkSplitPoints(
- opCtx,
- chunk.getShardId(),
- nss,
- cm.getShardKeyPattern(),
- ChunkRange(chunk.getMin(), chunk.getMax()),
- Grid::get(opCtx)->getBalancerConfiguration()->getMaxChunkSizeBytes()));
+ const auto maxChunkSizeBytes = [&]() -> int64_t {
+ if (optMaxChunkSizeBytes.has_value()) {
+ return *optMaxChunkSizeBytes;
+ }
+
+ auto coll = Grid::get(opCtx)->catalogClient()->getCollection(
+ opCtx, nss, repl::ReadConcernLevel::kMajorityReadConcern);
+ return coll.getMaxChunkSizeBytes().value_or(
+ Grid::get(opCtx)->getBalancerConfiguration()->getMaxChunkSizeBytes());
+ }();
+
+ const auto splitPoints = uassertStatusOK(
+ shardutil::selectChunkSplitPoints(opCtx,
+ chunk.getShardId(),
+ nss,
+ cm.getShardKeyPattern(),
+ ChunkRange(chunk.getMin(), chunk.getMax()),
+ maxChunkSizeBytes));
if (splitPoints.empty()) {
LOGV2(21873,
diff --git a/src/mongo/db/s/query_analysis_op_observer.cpp b/src/mongo/db/s/query_analysis_op_observer.cpp
index 84606c7f4c7..658c50c992d 100644
--- a/src/mongo/db/s/query_analysis_op_observer.cpp
+++ b/src/mongo/db/s/query_analysis_op_observer.cpp
@@ -31,6 +31,7 @@
#include "mongo/db/s/query_analysis_coordinator.h"
#include "mongo/db/s/query_analysis_op_observer.h"
+#include "mongo/db/s/query_analysis_writer.h"
#include "mongo/logv2/log.h"
#include "mongo/s/analyze_shard_key_util.h"
#include "mongo/s/catalog/type_mongos.h"
@@ -88,6 +89,17 @@ void QueryAnalysisOpObserver::onUpdate(OperationContext* opCtx, const OplogUpdat
});
}
}
+
+ if (analyze_shard_key::supportsPersistingSampledQueries() && args.updateArgs->sampleId &&
+ args.updateArgs->preImageDoc && opCtx->writesAreReplicated()) {
+ analyze_shard_key::QueryAnalysisWriter::get(opCtx)
+ .addDiff(*args.updateArgs->sampleId,
+ args.coll->ns(),
+ args.coll->uuid(),
+ *args.updateArgs->preImageDoc,
+ args.updateArgs->updatedDoc)
+ .getAsync([](auto) {});
+ }
}
void QueryAnalysisOpObserver::aboutToDelete(OperationContext* opCtx,
diff --git a/src/mongo/db/s/query_analysis_writer.cpp b/src/mongo/db/s/query_analysis_writer.cpp
new file mode 100644
index 00000000000..19b942e1775
--- /dev/null
+++ b/src/mongo/db/s/query_analysis_writer.cpp
@@ -0,0 +1,695 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/s/query_analysis_writer.h"
+
+#include "mongo/client/connpool.h"
+#include "mongo/db/catalog/collection_catalog.h"
+#include "mongo/db/dbdirectclient.h"
+#include "mongo/db/ops/write_ops.h"
+#include "mongo/db/server_options.h"
+#include "mongo/db/update/document_diff_calculator.h"
+#include "mongo/executor/network_interface_factory.h"
+#include "mongo/executor/thread_pool_task_executor.h"
+#include "mongo/logv2/log.h"
+#include "mongo/s/analyze_shard_key_documents_gen.h"
+#include "mongo/s/analyze_shard_key_server_parameters_gen.h"
+#include "mongo/s/write_ops/batched_command_response.h"
+#include "mongo/util/concurrency/thread_pool.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kDefault
+
+namespace mongo {
+namespace analyze_shard_key {
+
+namespace {
+
+MONGO_FAIL_POINT_DEFINE(disableQueryAnalysisWriter);  // when set, the periodic flush jobs return without flushing
+MONGO_FAIL_POINT_DEFINE(hangQueryAnalysisWriterBeforeWritingLocally);  // pauses executeWriteCommand() right before the local write
+MONGO_FAIL_POINT_DEFINE(hangQueryAnalysisWriterBeforeWritingRemotely);  // pauses executeWriteCommand() right before the remote write
+
+const auto getQueryAnalysisWriter = ServiceContext::declareDecoration<QueryAnalysisWriter>();  // the writer is a ServiceContext decoration
+
+constexpr int kMaxRetriesOnRetryableErrors = 5;  // retry budget used by executeWriteCommand()
+const WriteConcernOptions kMajorityWriteConcern{WriteConcernOptions::kMajority,
+ WriteConcernOptions::SyncMode::UNSET,
+ WriteConcernOptions::kWriteConcernTimeoutSystem};
+
+// The maximum total size of the documents in a single insert batch. This is BSONObjMaxUserSize
+// minus 500 KiB of padding that is reserved for the other fields in the insert command.
+constexpr int kMaxBSONObjSizeForInsert = BSONObjMaxUserSize - 500 * 1024;
+
+/*
+ * Reports whether this mongod is currently able to accept writes for the database that the given
+ * collection belongs to. The check is made while holding intent-shared locks on the database and
+ * on the collection. Unless the collection is in the "local" database, the answer is only true
+ * when this mongod is a primary (or a standalone).
+ */
+bool canAcceptWrites(OperationContext* opCtx, const NamespaceString& ns) {
+    ShouldNotConflictWithSecondaryBatchApplicationBlock skipPBWMBlock(opCtx->lockState());
+    Lock::DBLock dbLock(opCtx, ns.dbName(), MODE_IS);
+    Lock::CollectionLock collLock(opCtx, ns, MODE_IS);
+
+    auto* const replCoord = mongo::repl::ReplicationCoordinator::get(opCtx);
+    const bool writable = replCoord->canAcceptWritesForDatabase(opCtx, ns.db());
+    return writable;
+}
+
+/*
+ * Runs the given write command against the given database locally, asserts that the top-level
+ * command is OK, then asserts the write status using the given 'uassertWriteStatusCb' callback.
+ * Returns the command response.
+ *
+ * 'dbName' is taken by const reference so that the database name is not copied on every call
+ * (this function may be invoked repeatedly by the retry loop in executeWriteCommand()).
+ */
+BSONObj executeWriteCommandLocal(OperationContext* opCtx,
+                                 const std::string& dbName,
+                                 const BSONObj& cmdObj,
+                                 const std::function<void(const BSONObj&)>& uassertWriteStatusCb) {
+    DBDirectClient client(opCtx);
+    BSONObj resObj;
+
+    // A false return means the command did not succeed; turn the status embedded in the response
+    // into an exception.
+    if (!client.runCommand(dbName, cmdObj, resObj)) {
+        uassertStatusOK(getStatusFromCommandResult(resObj));
+    }
+    // The top-level command is OK at this point, let the caller validate the write result itself.
+    uassertWriteStatusCb(resObj);
+
+    return resObj;
+}
+
+/*
+ * Runs the given write command against the given database on the (remote) primary, asserts that the
+ * top-level command is OK, then asserts the write status using the given 'uassertWriteStatusCb'
+ * callback. Throws a PrimarySteppedDown error if no primary is found. Returns the command response.
+ *
+ * 'dbName' is taken by const reference so that the database name is not copied on every call
+ * (this function may be invoked repeatedly by the retry loop in executeWriteCommand()).
+ */
+BSONObj executeWriteCommandRemote(OperationContext* opCtx,
+                                  const std::string& dbName,
+                                  const BSONObj& cmdObj,
+                                  const std::function<void(const BSONObj&)>& uassertWriteStatusCb) {
+    auto hostAndPort = repl::ReplicationCoordinator::get(opCtx)->getCurrentPrimaryHostAndPort();
+
+    if (hostAndPort.empty()) {
+        uasserted(ErrorCodes::PrimarySteppedDown, "No primary exists currently");
+    }
+
+    auto conn = std::make_unique<ScopedDbConnection>(hostAndPort.toString());
+
+    if (auth::isInternalAuthSet()) {
+        uassertStatusOK(conn->get()->authenticateInternalUser());
+    }
+
+    DBClientBase* client = conn->get();
+    // On the success path the connection is released through done() when the guard fires.
+    ScopeGuard guard([&] { conn->done(); });
+    try {
+        BSONObj resObj;
+
+        if (!client->runCommand(dbName, cmdObj, resObj)) {
+            uassertStatusOK(getStatusFromCommandResult(resObj));
+        }
+        uassertWriteStatusCb(resObj);
+
+        return resObj;
+    } catch (...) {
+        // The command (or the write status check) failed: do not call done() on this connection,
+        // kill it instead and let the error propagate to the caller.
+        guard.dismiss();
+        conn->kill();
+        throw;
+    }
+}
+
+/*
+ * Executes the given write command for the given collection: locally when this mongod can
+ * currently accept writes for it, otherwise against the remote primary. In both cases the
+ * top-level command is asserted to be OK and the write status is then asserted through the given
+ * 'uassertWriteStatusCb' callback. Retryable errors are retried up to
+ * kMaxRetriesOnRetryableErrors times, so the writes must be idempotent. Returns the command
+ * response.
+ */
+BSONObj executeWriteCommand(OperationContext* opCtx,
+                            const NamespaceString& ns,
+                            const BSONObj& cmdObj,
+                            const std::function<void(const BSONObj&)>& uassertWriteStatusCb) {
+    const std::string dbName = ns.db().toString();
+
+    for (int numRetries = 0;; ++numRetries) {
+        try {
+            if (!canAcceptWrites(opCtx, ns)) {
+                hangQueryAnalysisWriterBeforeWritingRemotely.pauseWhileSet(opCtx);
+                return executeWriteCommandRemote(opCtx, dbName, cmdObj, uassertWriteStatusCb);
+            }
+
+            // This mongod may still step down between the check above and the write below, in
+            // which case a NotWritablePrimary error is thrown. That is preferable to running the
+            // command while holding locks.
+            hangQueryAnalysisWriterBeforeWritingLocally.pauseWhileSet(opCtx);
+            return executeWriteCommandLocal(opCtx, dbName, cmdObj, uassertWriteStatusCb);
+        } catch (DBException& ex) {
+            const bool canRetry =
+                ErrorCodes::isRetriableError(ex) && numRetries < kMaxRetriesOnRetryableErrors;
+            if (!canRetry) {
+                throw;
+            }
+            // Fall through to the next loop iteration, which counts this retry.
+        }
+    }
+
+    return {};
+}
+
+struct SampledWriteCommandRequest {
+ UUID sampleId;  // the sample id of the sampled write
+ NamespaceString nss;  // the namespace that the sampled command targets
+ BSONObj cmd; // the BSON for a {Update,Delete,FindAndModify}CommandRequest
+};
+
+/*
+ * Builds the sampled update command for the update statement at position 'opIndex' of the given
+ * update command. The result contains only that statement plus the original 'let' parameters. The
+ * statement must have a sample id.
+ */
+SampledWriteCommandRequest makeSampledUpdateCommandRequest(
+    const write_ops::UpdateCommandRequest& originalCmd, int opIndex) {
+    auto updateOp = originalCmd.getUpdates()[opIndex];
+    // Copy the sample id out before the statement is moved into the new command below.
+    auto sampleId = updateOp.getSampleId();
+    invariant(sampleId);
+
+    write_ops::UpdateCommandRequest singleOpCmd(originalCmd.getNamespace(),
+                                                {std::move(updateOp)});
+    singleOpCmd.setLet(originalCmd.getLet());
+
+    const auto& sampledNss = singleOpCmd.getNamespace();
+    auto cmdBson = singleOpCmd.toBSON(BSON("$db" << sampledNss.db().toString()));
+
+    return {*sampleId, sampledNss, std::move(cmdBson)};
+}
+
+/*
+ * Builds the sampled delete command for the delete statement at position 'opIndex' of the given
+ * delete command. The result contains only that statement plus the original 'let' parameters. The
+ * statement must have a sample id.
+ */
+SampledWriteCommandRequest makeSampledDeleteCommandRequest(
+    const write_ops::DeleteCommandRequest& originalCmd, int opIndex) {
+    auto deleteOp = originalCmd.getDeletes()[opIndex];
+    // Copy the sample id out before the statement is moved into the new command below.
+    auto sampleId = deleteOp.getSampleId();
+    invariant(sampleId);
+
+    write_ops::DeleteCommandRequest singleOpCmd(originalCmd.getNamespace(),
+                                                {std::move(deleteOp)});
+    singleOpCmd.setLet(originalCmd.getLet());
+
+    const auto& sampledNss = singleOpCmd.getNamespace();
+    auto cmdBson = singleOpCmd.toBSON(BSON("$db" << sampledNss.db().toString()));
+
+    return {*sampleId, sampledNss, std::move(cmdBson)};
+}
+
+/*
+ * Builds the sampled findAndModify command for the given findAndModify command. Only the query,
+ * update, remove, upsert, new, sort, collation, arrayFilters, let and sampleId fields are carried
+ * over. The original command must have a sample id.
+ */
+SampledWriteCommandRequest makeSampledFindAndModifyCommandRequest(
+    const write_ops::FindAndModifyCommandRequest& originalCmd) {
+    invariant(originalCmd.getSampleId());
+
+    write_ops::FindAndModifyCommandRequest trimmedCmd(originalCmd.getNamespace());
+    trimmedCmd.setQuery(originalCmd.getQuery());
+    trimmedCmd.setUpdate(originalCmd.getUpdate());
+    trimmedCmd.setRemove(originalCmd.getRemove());
+    trimmedCmd.setUpsert(originalCmd.getUpsert());
+    trimmedCmd.setNew(originalCmd.getNew());
+    trimmedCmd.setSort(originalCmd.getSort());
+    trimmedCmd.setCollation(originalCmd.getCollation());
+    trimmedCmd.setArrayFilters(originalCmd.getArrayFilters());
+    trimmedCmd.setLet(originalCmd.getLet());
+    trimmedCmd.setSampleId(originalCmd.getSampleId());
+
+    const auto& sampledNss = trimmedCmd.getNamespace();
+    auto cmdBson = trimmedCmd.toBSON(BSON("$db" << sampledNss.db().toString()));
+
+    return {*trimmedCmd.getSampleId(), sampledNss, std::move(cmdBson)};
+}
+
+} // namespace
+
+QueryAnalysisWriter& QueryAnalysisWriter::get(OperationContext* opCtx) {
+    // Forward to the ServiceContext overload, which performs the precondition checks.
+    ServiceContext* const serviceContext = opCtx->getServiceContext();
+    return get(serviceContext);
+}
+
+QueryAnalysisWriter& QueryAnalysisWriter::get(ServiceContext* serviceContext) {
+    // The writer may only be used when the feature flag is on and this node is a shard server.
+    invariant(analyze_shard_key::isFeatureFlagEnabledIgnoreFCV(),
+              "Only support analyzing queries when the feature flag is enabled");
+    invariant(serverGlobalParams.clusterRole == ClusterRole::ShardServer,
+              "Only support analyzing queries on a sharded cluster");
+
+    auto& writer = getQueryAnalysisWriter(serviceContext);
+    return writer;
+}
+
+// Starts the two periodic flush jobs (one for the query buffer, one for the diff buffer) and the
+// task executor on which the add*() methods schedule their work.
+void QueryAnalysisWriter::onStartup() {
+    auto serviceContext = getQueryAnalysisWriter.owner(this);
+    auto periodicRunner = serviceContext->getPeriodicRunner();
+    invariant(periodicRunner);
+
+    stdx::lock_guard<Latch> lk(_mutex);
+
+    // Periodic job that flushes the buffered queries, unless the writer is disabled by fail point.
+    auto flushQueriesFn = [this](Client* client) {
+        if (MONGO_unlikely(disableQueryAnalysisWriter.shouldFail())) {
+            return;
+        }
+        auto opCtx = client->makeOperationContext();
+        _flushQueries(opCtx.get());
+    };
+    _periodicQueryWriter = periodicRunner->makeJob(
+        PeriodicRunner::PeriodicJob("QueryAnalysisQueryWriter",
+                                    std::move(flushQueriesFn),
+                                    Seconds(gQueryAnalysisWriterIntervalSecs)));
+    _periodicQueryWriter.start();
+
+    // Periodic job that flushes the buffered diffs, unless the writer is disabled by fail point.
+    auto flushDiffsFn = [this](Client* client) {
+        if (MONGO_unlikely(disableQueryAnalysisWriter.shouldFail())) {
+            return;
+        }
+        auto opCtx = client->makeOperationContext();
+        _flushDiffs(opCtx.get());
+    };
+    _periodicDiffWriter = periodicRunner->makeJob(
+        PeriodicRunner::PeriodicJob("QueryAnalysisDiffWriter",
+                                    std::move(flushDiffsFn),
+                                    Seconds(gQueryAnalysisWriterIntervalSecs)));
+    _periodicDiffWriter.start();
+
+    // Thread pool backing the executor; every pool thread gets its own Client.
+    ThreadPool::Options poolOptions;
+    poolOptions.poolName = "QueryAnalysisWriterThreadPool";
+    poolOptions.threadNamePrefix = "QueryAnalysisWriter-";
+    poolOptions.minThreads = gQueryAnalysisWriterMinThreadPoolSize;
+    poolOptions.maxThreads = gQueryAnalysisWriterMaxThreadPoolSize;
+    poolOptions.onCreateThread = [](const std::string& threadName) {
+        Client::initThread(threadName.c_str());
+    };
+
+    _executor = std::make_shared<executor::ThreadPoolTaskExecutor>(
+        std::make_unique<ThreadPool>(poolOptions),
+        executor::makeNetworkInterface("QueryAnalysisWriterNetwork"));
+    _executor->startup();
+}
+
+// Shuts down and joins the executor (if it was started), then stops whichever periodic flush
+// jobs are valid.
+void QueryAnalysisWriter::onShutdown() {
+    if (_executor) {
+        _executor->shutdown();
+        _executor->join();
+    }
+
+    const auto stopIfValid = [](PeriodicJobAnchor& job) {
+        if (job.isValid()) {
+            job.stop();
+        }
+    };
+    stopIfValid(_periodicQueryWriter);
+    stopIfValid(_periodicDiffWriter);
+}
+
+// Flushes the query buffer into the sampled queries collection. A failure is logged and
+// swallowed here; the next flush attempt retries.
+void QueryAnalysisWriter::_flushQueries(OperationContext* opCtx) {
+    const auto& sampledQueriesNss = NamespaceString::kConfigSampledQueriesNamespace;
+    try {
+        _flush(opCtx, sampledQueriesNss, &_queries);
+    } catch (const DBException& ex) {
+        LOGV2(7047300,
+              "Failed to flush queries, will try again at the next interval",
+              "error"_attr = redact(ex));
+    }
+}
+
+// Flushes the diff buffer into the sampled queries diff collection. A failure is logged and
+// swallowed here; the next flush attempt retries.
+void QueryAnalysisWriter::_flushDiffs(OperationContext* opCtx) {
+    const auto& sampledDiffsNss = NamespaceString::kConfigSampledQueriesDiffNamespace;
+    try {
+        _flush(opCtx, sampledDiffsNss, &_diffs);
+    } catch (const DBException& ex) {
+        LOGV2(7075400,
+              "Failed to flush diffs, will try again at the next interval",
+              "error"_attr = redact(ex));
+    }
+}
+
+// Inserts the documents in 'buffer' into the collection 'ns' in batches and removes the inserted
+// documents from 'buffer'.
+//
+// The buffer is first swapped into a local buffer under the mutex so that new documents can keep
+// being added while the inserts are in progress. If an insert throws, the documents that are
+// still in the local buffer (minus the ones known to be invalid) are added back to 'buffer' and
+// the error propagates to the caller.
+void QueryAnalysisWriter::_flush(OperationContext* opCtx,
+                                 const NamespaceString& ns,
+                                 Buffer* buffer) {
+    Buffer tmpBuffer;
+    // The indices of invalid documents, e.g. documents that fail to insert with DuplicateKey errors
+    // (i.e. duplicates) and BadValue errors. Such documents should not get added back to the buffer
+    // when the inserts below fail.
+    std::set<int> invalid;
+    {
+        stdx::lock_guard<Latch> lk(_mutex);
+        if (buffer->isEmpty()) {
+            return;
+        }
+        std::swap(tmpBuffer, *buffer);
+    }
+    ScopeGuard backSwapper([&] {
+        stdx::lock_guard<Latch> lk(_mutex);
+        for (int i = 0; i < tmpBuffer.getCount(); i++) {
+            if (invalid.find(i) == invalid.end()) {
+                buffer->add(tmpBuffer.at(i));
+            }
+        }
+    });
+
+    const size_t maxBatchSize = gQueryAnalysisWriterMaxBatchSize.load();
+
+    // Insert the documents in batches from the back of the buffer so that we don't need to move the
+    // documents forward after each batch.
+    while (!tmpBuffer.isEmpty()) {
+        std::vector<BSONObj> docsToInsert;
+        long long objSize = 0;
+
+        // The batch is filled from the back, so the document at position 'i' of the batch is the
+        // document at position 'baseIndex - i' of 'tmpBuffer'. The base must be recomputed for
+        // every batch because 'tmpBuffer' shrinks after each successful insert.
+        const size_t baseIndex = tmpBuffer.getCount() - 1;
+
+        // After the loop below, 'lastIndex' is the index of the first buffered document that is
+        // part of this batch.
+        size_t lastIndex = tmpBuffer.getCount();
+        while (lastIndex > 0 && docsToInsert.size() < maxBatchSize) {
+            // Check if the next document can fit in the batch. The comparison is strict so that a
+            // document of exactly kMaxBSONObjSizeForInsert bytes, which Buffer::add() accepts, can
+            // still form a batch on its own.
+            auto doc = tmpBuffer.at(lastIndex - 1);
+            if (doc.objsize() + objSize > kMaxBSONObjSizeForInsert) {
+                break;
+            }
+            lastIndex--;
+            objSize += doc.objsize();
+            docsToInsert.push_back(std::move(doc));
+        }
+        // We don't add a document that is above the size limit to the buffer so we should have
+        // added at least one document to 'docsToInsert'.
+        invariant(!docsToInsert.empty());
+
+        write_ops::InsertCommandRequest insertCmd(ns);
+        insertCmd.setDocuments(std::move(docsToInsert));
+        insertCmd.setWriteCommandRequestBase([&] {
+            write_ops::WriteCommandRequestBase wcb;
+            wcb.setOrdered(false);
+            wcb.setBypassDocumentValidation(false);
+            return wcb;
+        }());
+        auto insertCmdBson = insertCmd.toBSON(
+            {BSON(WriteConcernOptions::kWriteConcernField << kMajorityWriteConcern.toBSON())});
+
+        executeWriteCommand(opCtx, ns, std::move(insertCmdBson), [&](const BSONObj& resObj) {
+            BatchedCommandResponse response;
+            std::string errMsg;
+
+            if (!response.parseBSON(resObj, &errMsg)) {
+                uasserted(ErrorCodes::FailedToParse, errMsg);
+            }
+
+            if (response.isErrDetailsSet() && response.sizeErrDetails() > 0) {
+                boost::optional<write_ops::WriteError> firstWriteErr;
+
+                for (const auto& err : response.getErrDetails()) {
+                    if (err.getStatus() == ErrorCodes::DuplicateKey ||
+                        err.getStatus() == ErrorCodes::BadValue) {
+                        LOGV2(7075402,
+                              "Ignoring insert error",
+                              "error"_attr = redact(err.getStatus()));
+                        // Translate the position in the batch into the position in 'tmpBuffer'.
+                        invalid.insert(static_cast<int>(baseIndex - err.getIndex()));
+                        continue;
+                    }
+                    if (!firstWriteErr) {
+                        // Save the error for later. Go through the rest of the errors to see if
+                        // there are any invalid documents so they can be discarded from the buffer.
+                        firstWriteErr.emplace(err);
+                    }
+                }
+                if (firstWriteErr) {
+                    uassertStatusOK(firstWriteErr->getStatus());
+                }
+            } else {
+                uassertStatusOK(response.toStatus());
+            }
+        });
+
+        // The batch went through, drop its documents from the local buffer.
+        tmpBuffer.truncate(lastIndex, objSize);
+    }
+
+    backSwapper.dismiss();
+}
+
+// Appends 'doc' to the buffer and accounts for its size. A document that is larger than
+// kMaxBSONObjSizeForInsert is silently dropped.
+void QueryAnalysisWriter::Buffer::add(BSONObj doc) {
+    const auto docSize = doc.objsize();
+    if (docSize > kMaxBSONObjSizeForInsert) {
+        return;
+    }
+    _docs.push_back(std::move(doc));
+    _numBytes += docSize;
+}
+
+// Drops the documents at position 'index' onwards and subtracts 'numBytes' from the tracked
+// total size. 'index' must refer to an existing document and 'numBytes' must be positive and no
+// larger than the tracked total; the caller is responsible for passing the actual total size of
+// the dropped documents.
+void QueryAnalysisWriter::Buffer::truncate(size_t index, long long numBytes) {
+    // 'index' is unsigned, so only the upper bound needs to be checked.
+    invariant(index < _docs.size());
+    invariant(numBytes > 0);
+    invariant(numBytes <= _numBytes);
+    _docs.resize(index);
+    _numBytes -= numBytes;
+}
+
+// Returns true when the combined size of the query buffer and the diff buffer has reached the
+// configured memory limit.
+bool QueryAnalysisWriter::_exceedsMaxSizeBytes() {
+    stdx::lock_guard<Latch> lk(_mutex);
+    const auto bufferedBytes = _queries.getSize() + _diffs.getSize();
+    const auto maxBytes = gQueryAnalysisWriterMaxMemoryUsageBytes.load();
+    return bufferedBytes >= maxBytes;
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addFindQuery(const UUID& sampleId,
+                                                       const NamespaceString& nss,
+                                                       const BSONObj& filter,
+                                                       const BSONObj& collation) {
+    // Buffered as a sampled read query with the kFind command name.
+    const auto cmdName = SampledReadCommandNameEnum::kFind;
+    return _addReadQuery(sampleId, nss, cmdName, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addCountQuery(const UUID& sampleId,
+                                                        const NamespaceString& nss,
+                                                        const BSONObj& filter,
+                                                        const BSONObj& collation) {
+    // Buffered as a sampled read query with the kCount command name.
+    const auto cmdName = SampledReadCommandNameEnum::kCount;
+    return _addReadQuery(sampleId, nss, cmdName, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addDistinctQuery(const UUID& sampleId,
+                                                           const NamespaceString& nss,
+                                                           const BSONObj& filter,
+                                                           const BSONObj& collation) {
+    // Buffered as a sampled read query with the kDistinct command name.
+    const auto cmdName = SampledReadCommandNameEnum::kDistinct;
+    return _addReadQuery(sampleId, nss, cmdName, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addAggregateQuery(const UUID& sampleId,
+                                                            const NamespaceString& nss,
+                                                            const BSONObj& filter,
+                                                            const BSONObj& collation) {
+    // Buffered as a sampled read query with the kAggregate command name.
+    const auto cmdName = SampledReadCommandNameEnum::kAggregate;
+    return _addReadQuery(sampleId, nss, cmdName, filter, collation);
+}
+
+// Schedules on the writer's executor the work to buffer a sampled read query of the given command
+// type. The scheduled work:
+//  1. looks up the collection UUID for 'nss' and, if the collection does not exist, logs a
+//     warning (like the write query variants do) and buffers nothing;
+//  2. otherwise adds the sampled query document to the query buffer;
+//  3. flushes the query buffer right away if the buffers have reached the memory limit.
+// Any error is logged by the onError handler instead of being surfaced through the future.
+ExecutorFuture<void> QueryAnalysisWriter::_addReadQuery(const UUID& sampleId,
+                                                        const NamespaceString& nss,
+                                                        SampledReadCommandNameEnum cmdName,
+                                                        const BSONObj& filter,
+                                                        const BSONObj& collation) {
+    invariant(_executor);
+    return ExecutorFuture<void>(_executor)
+        .then([this,
+               sampleId,
+               nss,
+               cmdName,
+               // The work runs asynchronously, so it must own the BSON it captures.
+               filter = filter.getOwned(),
+               collation = collation.getOwned()] {
+            auto opCtxHolder = cc().makeOperationContext();
+            auto opCtx = opCtxHolder.get();
+
+            auto collUuid = CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, nss);
+
+            if (!collUuid) {
+                LOGV2_WARNING(7047301, "Found a sampled read query for non-existing collection");
+                return;
+            }
+
+            // 'filter' and 'collation' are already owned by the capture above.
+            auto cmd = SampledReadCommand{filter, collation};
+            auto doc = SampledReadQueryDocument{sampleId, nss, *collUuid, cmdName, cmd.toBSON()};
+            stdx::lock_guard<Latch> lk(_mutex);
+            _queries.add(doc.toBSON());
+        })
+        .then([this] {
+            if (_exceedsMaxSizeBytes()) {
+                auto opCtxHolder = cc().makeOperationContext();
+                auto opCtx = opCtxHolder.get();
+                _flushQueries(opCtx);
+            }
+        })
+        .onError([this, nss](Status status) {
+            LOGV2(7047302,
+                  "Failed to add read query",
+                  "ns"_attr = nss,
+                  "error"_attr = redact(status));
+        });
+}
+
+// Schedules on the writer's executor the work to buffer the sampled update statement at position
+// 'opIndex' of 'updateCmd'; the statement must have a sample id. The sampled command is built
+// eagerly (before the work is scheduled). The scheduled work buffers nothing if the collection
+// does not exist, and flushes the query buffer right away if the buffers have reached the memory
+// limit. Any error is logged by the onError handler.
+ExecutorFuture<void> QueryAnalysisWriter::addUpdateQuery(
+    const write_ops::UpdateCommandRequest& updateCmd, int opIndex) {
+    invariant(updateCmd.getUpdates()[opIndex].getSampleId());
+    invariant(_executor);
+
+    return ExecutorFuture<void>(_executor)
+        .then([this, sampledUpdateCmd = makeSampledUpdateCommandRequest(updateCmd, opIndex)]() {
+            auto opCtxHolder = cc().makeOperationContext();
+            auto opCtx = opCtxHolder.get();
+
+            auto collUuid =
+                CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledUpdateCmd.nss);
+
+            if (!collUuid) {
+                LOGV2_WARNING(7075300,
+                              "Found a sampled update query for a non-existing collection");
+                return;
+            }
+
+            // The capture is const inside this non-mutable lambda, so 'cmd' can only be copied;
+            // wrapping it in std::move() would silently copy it anyway.
+            auto doc = SampledWriteQueryDocument{sampledUpdateCmd.sampleId,
+                                                 sampledUpdateCmd.nss,
+                                                 *collUuid,
+                                                 SampledWriteCommandNameEnum::kUpdate,
+                                                 sampledUpdateCmd.cmd};
+            stdx::lock_guard<Latch> lk(_mutex);
+            _queries.add(doc.toBSON());
+        })
+        .then([this] {
+            if (_exceedsMaxSizeBytes()) {
+                auto opCtxHolder = cc().makeOperationContext();
+                auto opCtx = opCtxHolder.get();
+                _flushQueries(opCtx);
+            }
+        })
+        .onError([this, nss = updateCmd.getNamespace()](Status status) {
+            LOGV2(7075301,
+                  "Failed to add update query",
+                  "ns"_attr = nss,
+                  "error"_attr = redact(status));
+        });
+}
+
+// Schedules on the writer's executor the work to buffer the sampled delete statement at position
+// 'opIndex' of 'deleteCmd'; the statement must have a sample id. The sampled command is built
+// eagerly (before the work is scheduled). The scheduled work buffers nothing if the collection
+// does not exist, and flushes the query buffer right away if the buffers have reached the memory
+// limit. Any error is logged by the onError handler.
+ExecutorFuture<void> QueryAnalysisWriter::addDeleteQuery(
+    const write_ops::DeleteCommandRequest& deleteCmd, int opIndex) {
+    invariant(deleteCmd.getDeletes()[opIndex].getSampleId());
+    invariant(_executor);
+
+    return ExecutorFuture<void>(_executor)
+        .then([this, sampledDeleteCmd = makeSampledDeleteCommandRequest(deleteCmd, opIndex)]() {
+            auto opCtxHolder = cc().makeOperationContext();
+            auto opCtx = opCtxHolder.get();
+
+            auto collUuid =
+                CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledDeleteCmd.nss);
+
+            if (!collUuid) {
+                LOGV2_WARNING(7075302,
+                              "Found a sampled delete query for a non-existing collection");
+                return;
+            }
+
+            // The capture is const inside this non-mutable lambda, so 'cmd' can only be copied;
+            // wrapping it in std::move() would silently copy it anyway.
+            auto doc = SampledWriteQueryDocument{sampledDeleteCmd.sampleId,
+                                                 sampledDeleteCmd.nss,
+                                                 *collUuid,
+                                                 SampledWriteCommandNameEnum::kDelete,
+                                                 sampledDeleteCmd.cmd};
+            stdx::lock_guard<Latch> lk(_mutex);
+            _queries.add(doc.toBSON());
+        })
+        .then([this] {
+            if (_exceedsMaxSizeBytes()) {
+                auto opCtxHolder = cc().makeOperationContext();
+                auto opCtx = opCtxHolder.get();
+                _flushQueries(opCtx);
+            }
+        })
+        .onError([this, nss = deleteCmd.getNamespace()](Status status) {
+            LOGV2(7075303,
+                  "Failed to add delete query",
+                  "ns"_attr = nss,
+                  "error"_attr = redact(status));
+        });
+}
+
+// Schedules on the writer's executor the work to buffer the given sampled findAndModify command,
+// which must have a sample id. The sampled command is built eagerly (before the work is
+// scheduled). The scheduled work buffers nothing if the collection does not exist, and flushes
+// the query buffer right away if the buffers have reached the memory limit. Any error is logged
+// by the onError handler.
+ExecutorFuture<void> QueryAnalysisWriter::addFindAndModifyQuery(
+    const write_ops::FindAndModifyCommandRequest& findAndModifyCmd) {
+    invariant(findAndModifyCmd.getSampleId());
+    invariant(_executor);
+
+    return ExecutorFuture<void>(_executor)
+        .then([this,
+               sampledFindAndModifyCmd =
+                   makeSampledFindAndModifyCommandRequest(findAndModifyCmd)]() {
+            auto opCtxHolder = cc().makeOperationContext();
+            auto opCtx = opCtxHolder.get();
+
+            auto collUuid =
+                CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledFindAndModifyCmd.nss);
+
+            if (!collUuid) {
+                LOGV2_WARNING(7075304,
+                              "Found a sampled findAndModify query for a non-existing collection");
+                return;
+            }
+
+            // The capture is const inside this non-mutable lambda, so 'cmd' can only be copied;
+            // wrapping it in std::move() would silently copy it anyway.
+            auto doc = SampledWriteQueryDocument{sampledFindAndModifyCmd.sampleId,
+                                                 sampledFindAndModifyCmd.nss,
+                                                 *collUuid,
+                                                 SampledWriteCommandNameEnum::kFindAndModify,
+                                                 sampledFindAndModifyCmd.cmd};
+            stdx::lock_guard<Latch> lk(_mutex);
+            _queries.add(doc.toBSON());
+        })
+        .then([this] {
+            if (_exceedsMaxSizeBytes()) {
+                auto opCtxHolder = cc().makeOperationContext();
+                auto opCtx = opCtxHolder.get();
+                _flushQueries(opCtx);
+            }
+        })
+        .onError([this, nss = findAndModifyCmd.getNamespace()](Status status) {
+            LOGV2(7075305,
+                  "Failed to add findAndModify query",
+                  "ns"_attr = nss,
+                  "error"_attr = redact(status));
+        });
+}
+
+// Schedules on the writer's executor the work to compute the diff between 'preImage' and
+// 'postImage' and to buffer it under the given sample id. Nothing is buffered when no diff could
+// be computed or when the diff is empty. The diff buffer is flushed right away if the buffers
+// have reached the memory limit. Any error is logged by the onError handler.
+ExecutorFuture<void> QueryAnalysisWriter::addDiff(const UUID& sampleId,
+                                                  const NamespaceString& nss,
+                                                  const UUID& collUuid,
+                                                  const BSONObj& preImage,
+                                                  const BSONObj& postImage) {
+    invariant(_executor);
+
+    // The work runs asynchronously, so it must own the images it captures.
+    auto computeAndBufferDiff = [this,
+                                 sampleId,
+                                 nss,
+                                 collUuid,
+                                 preImage = preImage.getOwned(),
+                                 postImage = postImage.getOwned()]() {
+        auto diff = doc_diff::computeInlineDiff(preImage, postImage);
+
+        const bool hasDiff = diff && !diff->isEmpty();
+        if (!hasDiff) {
+            return;
+        }
+
+        auto diffDoc = SampledQueryDiffDocument{sampleId, nss, collUuid, std::move(*diff)};
+        stdx::lock_guard<Latch> lk(_mutex);
+        _diffs.add(diffDoc.toBSON());
+    };
+
+    auto flushIfOverLimit = [this] {
+        if (_exceedsMaxSizeBytes()) {
+            auto opCtxHolder = cc().makeOperationContext();
+            auto opCtx = opCtxHolder.get();
+            _flushDiffs(opCtx);
+        }
+    };
+
+    auto logFailure = [this, nss](Status status) {
+        LOGV2(7075401, "Failed to add diff", "ns"_attr = nss, "error"_attr = redact(status));
+    };
+
+    return ExecutorFuture<void>(_executor)
+        .then(std::move(computeAndBufferDiff))
+        .then(std::move(flushIfOverLimit))
+        .onError(std::move(logFailure));
+}
+
+} // namespace analyze_shard_key
+} // namespace mongo
diff --git a/src/mongo/db/s/query_analysis_writer.h b/src/mongo/db/s/query_analysis_writer.h
new file mode 100644
index 00000000000..508d4903b65
--- /dev/null
+++ b/src/mongo/db/s/query_analysis_writer.h
@@ -0,0 +1,214 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/operation_context.h"
+#include "mongo/db/service_context.h"
+#include "mongo/executor/task_executor.h"
+#include "mongo/s/analyze_shard_key_common_gen.h"
+#include "mongo/s/analyze_shard_key_util.h"
+#include "mongo/s/write_ops/batched_command_request.h"
+#include "mongo/util/periodic_runner.h"
+
+namespace mongo {
+namespace analyze_shard_key {
+
+/**
+ * Owns the machinery for persisting sampled queries. That consists of the following:
+ * - The buffer that stores sampled queries and the periodic background job that inserts those
+ * queries into the local config.sampledQueries collection.
+ * - The buffer that stores diffs for sampled update queries and the periodic background job that
+ * inserts those diffs into the local config.sampledQueriesDiff collection.
+ *
+ * Currently, query sampling is only supported on a sharded cluster. So a writer must be a shardsvr
+ * mongod. If the mongod is a primary, it will execute the insert commands locally. If it is a
+ * secondary, it will perform the insert commands against the primary.
+ *
+ * The memory usage of the buffers is controlled by the 'queryAnalysisWriterMaxMemoryUsageBytes'
+ * server parameter. Upon adding a query or a diff that causes the total size of buffers to exceed
+ * the limit, the writer will flush the corresponding buffer immediately instead of waiting for it
+ * to get flushed later by the periodic job.
+ *
+ * The service-wide instance is stored as a ServiceContext decoration and is obtained through
+ * get(). Each add*() method schedules its work on the writer's executor and returns the
+ * corresponding future; errors raised by that work are logged by the writer.
+ */
+class QueryAnalysisWriter final : public std::enable_shared_from_this<QueryAnalysisWriter> {
+ QueryAnalysisWriter(const QueryAnalysisWriter&) = delete;
+ QueryAnalysisWriter& operator=(const QueryAnalysisWriter&) = delete;
+
+public:
+ /**
+ * Temporarily stores documents to be written to disk.
+ */
+ struct Buffer {
+ public:
+ /**
+ * Adds the given document to the buffer if its size is below the limit (i.e.
+ * BSONObjMaxUserSize - some padding) and increments the total number of bytes accordingly.
+ * A document above the limit is dropped without any error.
+ */
+ void add(BSONObj doc);
+
+ /**
+ * Removes the documents at 'index' onwards from the buffer and decrements the total number
+ * of the bytes by 'numBytes'. The caller must ensure that 'numBytes' is indeed the
+ * total size of the documents being removed.
+ */
+ void truncate(size_t index, long long numBytes);
+
+ bool isEmpty() const {  // true if no document is buffered
+ return _docs.empty();
+ }
+
+ int getCount() const {  // the number of buffered documents
+ return _docs.size();
+ }
+
+ long long getSize() const {  // the total size in bytes of the buffered documents
+ return _numBytes;
+ }
+
+ BSONObj at(size_t index) const {  // the document at 'index'; no bounds check is performed
+ return _docs[index];
+ }
+
+ private:
+ std::vector<BSONObj> _docs;
+ long long _numBytes = 0;
+ };
+
+ QueryAnalysisWriter() = default;
+ ~QueryAnalysisWriter() = default;
+
+ QueryAnalysisWriter(QueryAnalysisWriter&& source) = delete;
+ QueryAnalysisWriter& operator=(QueryAnalysisWriter&& other) = delete;
+
+ /**
+ * Obtains the service-wide QueryAnalysisWriter instance. Invariants that the feature flag is
+ * enabled and that this mongod is a shard server.
+ */
+ static QueryAnalysisWriter& get(OperationContext* opCtx);
+ static QueryAnalysisWriter& get(ServiceContext* serviceContext);
+
+ void onStartup();  // starts the two periodic flush jobs and the executor
+
+ void onShutdown();  // shuts down and joins the executor, then stops the periodic flush jobs
+
+ ExecutorFuture<void> addFindQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation);  // buffers a sampled find query
+
+ ExecutorFuture<void> addCountQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation);  // buffers a sampled count query
+
+ ExecutorFuture<void> addDistinctQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation);  // buffers a sampled distinct query
+
+ ExecutorFuture<void> addAggregateQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation);  // buffers a sampled aggregate query
+
+ ExecutorFuture<void> addUpdateQuery(const write_ops::UpdateCommandRequest& updateCmd,
+ int opIndex);  // buffers the update at 'opIndex', which must have a sample id
+ ExecutorFuture<void> addDeleteQuery(const write_ops::DeleteCommandRequest& deleteCmd,
+ int opIndex);  // buffers the delete at 'opIndex', which must have a sample id
+
+ ExecutorFuture<void> addFindAndModifyQuery(
+ const write_ops::FindAndModifyCommandRequest& findAndModifyCmd);  // the command must have a sample id
+
+ ExecutorFuture<void> addDiff(const UUID& sampleId,
+ const NamespaceString& nss,
+ const UUID& collUuid,
+ const BSONObj& preImage,
+ const BSONObj& postImage);  // buffers the pre/post image diff, unless there is none or it is empty
+
+ int getQueriesCountForTest() const {  // the number of buffered queries
+ stdx::lock_guard<Latch> lk(_mutex);
+ return _queries.getCount();
+ }
+
+ void flushQueriesForTest(OperationContext* opCtx) {  // flushes the query buffer synchronously
+ _flushQueries(opCtx);
+ }
+
+ int getDiffsCountForTest() const {  // the number of buffered diffs
+ stdx::lock_guard<Latch> lk(_mutex);
+ return _diffs.getCount();
+ }
+
+ void flushDiffsForTest(OperationContext* opCtx) {  // flushes the diff buffer synchronously
+ _flushDiffs(opCtx);
+ }
+
+private:
+ ExecutorFuture<void> _addReadQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ SampledReadCommandNameEnum cmdName,
+ const BSONObj& filter,
+ const BSONObj& collation);  // shared implementation of the add{Find,Count,Distinct,Aggregate}Query methods
+
+ void _flushQueries(OperationContext* opCtx);  // flushes '_queries'; logs (does not throw) DBExceptions
+ void _flushDiffs(OperationContext* opCtx);  // flushes '_diffs'; logs (does not throw) DBExceptions
+
+ /**
+ * The helper for '_flushQueries' and '_flushDiffs'. Inserts the documents in 'buffer' into the
+ * collection 'ns' in batches, and removes all the inserted documents from 'buffer'. Internally
+ * retries the inserts on retryable errors for a fixed number of times. Ignores DuplicateKey
+ * errors since they are expected for the following reasons:
+ * - For the query buffer, a sampled query that is idempotent (e.g. a read or retryable write)
+ * could get added to the buffer (across nodes) more than once due to retries.
+ * - For the diff buffer, a sampled multi-update query could end up generating multiple diffs
+ * and each diff is identified using the sample id of the sampled query that creates it.
+ *
+ * BadValue errors are ignored as well. The documents that fail with either of these two errors
+ * are considered invalid and are not added back to 'buffer' when the inserts fail.
+ *
+ * Throws an error if the inserts fail with any other error.
+ */
+ void _flush(OperationContext* opCtx, const NamespaceString& nss, Buffer* buffer);
+
+ /**
+ * Returns true if the total size of the buffered queries and diffs has exceeded the maximum
+ * amount of memory that the writer is allowed to use.
+ */
+ bool _exceedsMaxSizeBytes();
+
+ mutable Mutex _mutex = MONGO_MAKE_LATCH("QueryAnalysisWriter::_mutex");  // guards '_queries' and '_diffs'
+
+ PeriodicJobAnchor _periodicQueryWriter;
+ Buffer _queries;
+
+ PeriodicJobAnchor _periodicDiffWriter;
+ Buffer _diffs;
+
+ // Initialized on startup and joined on shutdown.
+ std::shared_ptr<executor::TaskExecutor> _executor;
+};
+
+} // namespace analyze_shard_key
+} // namespace mongo
diff --git a/src/mongo/db/s/query_analysis_writer_test.cpp b/src/mongo/db/s/query_analysis_writer_test.cpp
new file mode 100644
index 00000000000..b17f39c82df
--- /dev/null
+++ b/src/mongo/db/s/query_analysis_writer_test.cpp
@@ -0,0 +1,1281 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/s/query_analysis_writer.h"
+
+#include "mongo/bson/unordered_fields_bsonobj_comparator.h"
+#include "mongo/db/db_raii.h"
+#include "mongo/db/dbdirectclient.h"
+#include "mongo/db/s/shard_server_test_fixture.h"
+#include "mongo/db/update/document_diff_calculator.h"
+#include "mongo/idl/server_parameter_test_util.h"
+#include "mongo/logv2/log.h"
+#include "mongo/s/analyze_shard_key_documents_gen.h"
+#include "mongo/unittest/death_test.h"
+#include "mongo/util/fail_point.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest
+
+namespace mongo {
+namespace analyze_shard_key {
+namespace {
+
+TEST(QueryAnalysisWriterBufferTest, AddBasic) {
+    // Every added document must bump the count by one and the size by the document's size, and
+    // the documents must be retrievable by their insertion index.
+    QueryAnalysisWriter::Buffer buffer{};
+
+    const auto firstDoc = BSON("a" << 0);
+    buffer.add(firstDoc);
+    ASSERT_EQ(buffer.getCount(), 1);
+    ASSERT_EQ(buffer.getSize(), firstDoc.objsize());
+
+    const auto secondDoc = BSON("a" << BSON_ARRAY(0 << 1 << 2));
+    buffer.add(secondDoc);
+    ASSERT_EQ(buffer.getCount(), 2);
+    ASSERT_EQ(buffer.getSize(), firstDoc.objsize() + secondDoc.objsize());
+
+    ASSERT_BSONOBJ_EQ(buffer.at(0), firstDoc);
+    ASSERT_BSONOBJ_EQ(buffer.at(1), secondDoc);
+}
+
+TEST(QueryAnalysisWriterBufferTest, AddTooLarge) {
+    QueryAnalysisWriter::Buffer buffer{};
+
+    // The field name alone is BSONObjMaxUserSize characters long. The buffer is expected to stay
+    // empty after attempting to add such a document.
+    const std::string hugeFieldName(BSONObjMaxUserSize, 'a');
+    const auto oversizedDoc = BSON(hugeFieldName << 1);
+    buffer.add(oversizedDoc);
+
+    ASSERT_EQ(buffer.getCount(), 0);
+    ASSERT_EQ(buffer.getSize(), 0);
+}
+
+TEST(QueryAnalysisWriterBufferTest, TruncateBasic) {
+    // Fills a fresh buffer with 'oldCount' equally sized documents, truncates it down to
+    // 'newCount' documents and verifies the remaining count, size and contents.
+    auto runTruncateCase = [](int oldCount, int newCount) {
+        QueryAnalysisWriter::Buffer buffer{};
+
+        std::vector<BSONObj> allDocs;
+        for (int idx = 0; idx < oldCount; ++idx) {
+            allDocs.push_back(BSON("a" << idx));
+        }
+        // Every document holds a single int field, so they all have the same size.
+        const auto perDocSize = allDocs.back().objsize();
+
+        for (const auto& currentDoc : allDocs) {
+            buffer.add(currentDoc);
+        }
+        ASSERT_EQ(buffer.getCount(), oldCount);
+        ASSERT_EQ(buffer.getSize(), oldCount * perDocSize);
+
+        const auto numDocsRemoved = oldCount - newCount;
+        buffer.truncate(newCount, numDocsRemoved * perDocSize);
+        ASSERT_EQ(buffer.getCount(), newCount);
+        ASSERT_EQ(buffer.getSize(), newCount * perDocSize);
+        for (int idx = 0; idx < newCount; ++idx) {
+            ASSERT_BSONOBJ_EQ(buffer.at(idx), allDocs[idx]);
+        }
+    };
+
+    runTruncateCase(10 /* oldCount */, 6 /* newCount */);
+    runTruncateCase(10 /* oldCount */, 0 /* newCount */);  // Truncate all.
+    runTruncateCase(10 /* oldCount */, 9 /* newCount */);  // Truncate one.
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidIndex_Negative, "invariant") {
+    QueryAnalysisWriter::Buffer buffer{};
+
+    const auto onlyDoc = BSON("a" << 0);
+    buffer.add(onlyDoc);
+    ASSERT_EQ(buffer.getCount(), 1);
+    ASSERT_EQ(buffer.getSize(), onlyDoc.objsize());
+
+    // A negative index is invalid; the truncate call is expected to hit an invariant.
+    buffer.truncate(-1, onlyDoc.objsize());
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidIndex_Positive, "invariant") {
+    QueryAnalysisWriter::Buffer buffer{};
+
+    const auto onlyDoc = BSON("a" << 0);
+    buffer.add(onlyDoc);
+    ASSERT_EQ(buffer.getCount(), 1);
+    ASSERT_EQ(buffer.getSize(), onlyDoc.objsize());
+
+    // The buffer holds a single document, so index 2 is out of range; the truncate call is
+    // expected to hit an invariant.
+    buffer.truncate(2, onlyDoc.objsize());
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Negative, "invariant") {
+    QueryAnalysisWriter::Buffer buffer{};
+
+    const auto onlyDoc = BSON("a" << 0);
+    buffer.add(onlyDoc);
+    ASSERT_EQ(buffer.getCount(), 1);
+    ASSERT_EQ(buffer.getSize(), onlyDoc.objsize());
+
+    // A negative number of bytes is invalid; the truncate call is expected to hit an invariant.
+    buffer.truncate(0, -onlyDoc.objsize());
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Zero, "invariant") {
+    QueryAnalysisWriter::Buffer buffer{};
+
+    const auto onlyDoc = BSON("a" << 0);
+    buffer.add(onlyDoc);
+    ASSERT_EQ(buffer.getCount(), 1);
+    ASSERT_EQ(buffer.getSize(), onlyDoc.objsize());
+
+    // Truncating with zero bytes is invalid; the truncate call is expected to hit an invariant.
+    buffer.truncate(0, 0);
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Positive, "invariant") {
+    QueryAnalysisWriter::Buffer buffer{};
+
+    const auto onlyDoc = BSON("a" << 0);
+    buffer.add(onlyDoc);
+    ASSERT_EQ(buffer.getCount(), 1);
+    ASSERT_EQ(buffer.getSize(), onlyDoc.objsize());
+
+    // Twice the size of the only buffered document is more than the buffer holds; the truncate
+    // call is expected to hit an invariant.
+    buffer.truncate(0, onlyDoc.objsize() * 2);
+}
+
+/**
+ * Asserts that 'lhs' and 'rhs' compare equal under UnorderedFieldsBSONObjComparator.
+ */
+void assertBsonObjEqualUnordered(const BSONObj& lhs, const BSONObj& rhs) {
+    UnorderedFieldsBSONObjComparator unorderedComparator;
+    ASSERT_EQ(unorderedComparator.compare(lhs, rhs), 0);
+}
+
+/**
+ * Fixture for the QueryAnalysisWriter tests. Starts the writer on top of a ShardServerTestFixture,
+ * creates the two test collections and provides helpers for building sampled commands and for
+ * inspecting the documents written to the config collections.
+ */
+struct QueryAnalysisWriterTest : public ShardServerTestFixture {
+public:
+    void setUp() override {
+        ShardServerTestFixture::setUp();
+        QueryAnalysisWriter::get(operationContext()).onStartup();
+
+        DBDirectClient client(operationContext());
+        client.createCollection(nss0.toString());
+        client.createCollection(nss1.toString());
+    }
+
+    void tearDown() override {
+        QueryAnalysisWriter::get(operationContext()).onShutdown();
+        ShardServerTestFixture::tearDown();
+    }
+
+protected:
+    /**
+     * Returns the UUID of the collection 'nss'. Fails the test if the collection catalog has no
+     * UUID for it instead of dereferencing an empty optional.
+     */
+    UUID getCollectionUUID(const NamespaceString& nss) const {
+        auto collectionCatalog = CollectionCatalog::get(operationContext());
+        auto uuid = collectionCatalog->lookupUUIDByNSS(operationContext(), nss);
+        ASSERT(uuid);
+        return *uuid;
+    }
+
+    /**
+     * Returns a filter on '_id' with a freshly generated UUID.
+     */
+    BSONObj makeNonEmptyFilter() {
+        return BSON("_id" << UUID::gen());
+    }
+
+    /**
+     * Returns an "en_US" collation with a random strength between 1 and 5 (inclusive).
+     */
+    BSONObj makeNonEmptyCollation() {
+        int strength = rand() % 5 + 1;
+        return BSON("locale"
+                    << "en_US"
+                    << "strength" << strength);
+    }
+
+    /**
+     * Makes an UpdateCommandRequest for the collection 'nss' such that the command contains
+     * 'numOps' updates and the ones whose indices are in 'markForSampling' are marked for sampling.
+     * Returns the UpdateCommandRequest and the map storing the expected sampled
+     * UpdateCommandRequests by sample id, if any.
+     */
+    std::pair<write_ops::UpdateCommandRequest, std::map<UUID, write_ops::UpdateCommandRequest>>
+    makeUpdateCommandRequest(const NamespaceString& nss,
+                             int numOps,
+                             std::set<int> markForSampling,
+                             std::string filterFieldName = "a") {
+        write_ops::UpdateCommandRequest originalCmd(nss);
+        std::vector<write_ops::UpdateOpEntry> updateOps;  // populated and set below.
+        originalCmd.setLet(let);
+        originalCmd.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation);
+
+        std::map<UUID, write_ops::UpdateCommandRequest> expectedSampledCmds;
+
+        for (auto i = 0; i < numOps; i++) {
+            auto updateOp = write_ops::UpdateOpEntry(
+                BSON(filterFieldName << i),
+                write_ops::UpdateModification(BSON("$set" << BSON("b.$[element]" << i))));
+            updateOp.setC(BSON("x" << i));
+            updateOp.setArrayFilters(std::vector<BSONObj>{BSON("element" << BSON("$gt" << i))});
+            updateOp.setMulti(_getRandomBool());
+            updateOp.setUpsert(_getRandomBool());
+            updateOp.setUpsertSupplied(_getRandomBool());
+            updateOp.setCollation(makeNonEmptyCollation());
+
+            if (markForSampling.count(i) > 0) {
+                auto sampleId = UUID::gen();
+                updateOp.setSampleId(sampleId);
+
+                // The expected sampled command contains only this update and has the encryption
+                // information unset.
+                write_ops::UpdateCommandRequest expectedSampledCmd = originalCmd;
+                expectedSampledCmd.setUpdates({updateOp});
+                expectedSampledCmd.getWriteCommandRequestBase().setEncryptionInformation(
+                    boost::none);
+                expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd));
+            }
+            updateOps.push_back(updateOp);
+        }
+        originalCmd.setUpdates(updateOps);
+
+        return {originalCmd, expectedSampledCmds};
+    }
+
+    /**
+     * Makes a DeleteCommandRequest for the collection 'nss' such that the command contains
+     * 'numOps' deletes and the ones whose indices are in 'markForSampling' are marked for sampling.
+     * Returns the DeleteCommandRequest and the map storing the expected sampled
+     * DeleteCommandRequests by sample id, if any.
+     */
+    std::pair<write_ops::DeleteCommandRequest, std::map<UUID, write_ops::DeleteCommandRequest>>
+    makeDeleteCommandRequest(const NamespaceString& nss,
+                             int numOps,
+                             std::set<int> markForSampling,
+                             std::string filterFieldName = "a") {
+        write_ops::DeleteCommandRequest originalCmd(nss);
+        std::vector<write_ops::DeleteOpEntry> deleteOps;  // populated and set below.
+        originalCmd.setLet(let);
+        originalCmd.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation);
+
+        std::map<UUID, write_ops::DeleteCommandRequest> expectedSampledCmds;
+
+        for (auto i = 0; i < numOps; i++) {
+            auto deleteOp =
+                write_ops::DeleteOpEntry(BSON(filterFieldName << i), _getRandomBool() /* multi */);
+            deleteOp.setCollation(makeNonEmptyCollation());
+
+            if (markForSampling.count(i) > 0) {
+                auto sampleId = UUID::gen();
+                deleteOp.setSampleId(sampleId);
+
+                // The expected sampled command contains only this delete and has the encryption
+                // information unset.
+                write_ops::DeleteCommandRequest expectedSampledCmd = originalCmd;
+                expectedSampledCmd.setDeletes({deleteOp});
+                expectedSampledCmd.getWriteCommandRequestBase().setEncryptionInformation(
+                    boost::none);
+                expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd));
+            }
+            deleteOps.push_back(deleteOp);
+        }
+        originalCmd.setDeletes(deleteOps);
+
+        return {originalCmd, expectedSampledCmds};
+    }
+
+    /**
+     * Makes a FindAndModifyCommandRequest for the collection 'nss'. If 'isUpdate' is true, the
+     * 'upsert' and 'new' fields are additionally set to random values. If 'markForSampling' is
+     * true, it is marked for sampling. Returns the FindAndModifyCommandRequest and the map storing
+     * the expected sampled FindAndModifyCommandRequests by sample id, if any.
+     *
+     * NOTE(review): the request always carries an update modification and array filters, and no
+     * 'remove' field is set here even when 'isUpdate' is false -- confirm whether that is intended
+     * for the "remove" test cases.
+     */
+    std::pair<write_ops::FindAndModifyCommandRequest,
+              std::map<UUID, write_ops::FindAndModifyCommandRequest>>
+    makeFindAndModifyCommandRequest(const NamespaceString& nss,
+                                    bool isUpdate,
+                                    bool markForSampling,
+                                    std::string filterFieldName = "a") {
+        write_ops::FindAndModifyCommandRequest originalCmd(nss);
+        originalCmd.setQuery(BSON(filterFieldName << 0));
+        originalCmd.setUpdate(
+            write_ops::UpdateModification(BSON("$set" << BSON("b.$[element]" << 0))));
+        originalCmd.setArrayFilters(std::vector<BSONObj>{BSON("element" << BSON("$gt" << 10))});
+        originalCmd.setSort(BSON("_id" << 1));
+        if (isUpdate) {
+            originalCmd.setUpsert(_getRandomBool());
+            originalCmd.setNew(_getRandomBool());
+        }
+        originalCmd.setCollation(makeNonEmptyCollation());
+        originalCmd.setLet(let);
+        originalCmd.setEncryptionInformation(encryptionInformation);
+
+        std::map<UUID, write_ops::FindAndModifyCommandRequest> expectedSampledCmds;
+        if (markForSampling) {
+            auto sampleId = UUID::gen();
+            originalCmd.setSampleId(sampleId);
+
+            // The expected sampled command is the original one with the encryption information
+            // unset.
+            auto expectedSampledCmd = originalCmd;
+            expectedSampledCmd.setEncryptionInformation(boost::none);
+            expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd));
+        }
+
+        return {originalCmd, expectedSampledCmds};
+    }
+
+    /**
+     * Removes all the documents in the config.sampledQueries collection.
+     */
+    void deleteSampledQueryDocuments() const {
+        DBDirectClient client(operationContext());
+        client.remove(NamespaceString::kConfigSampledQueriesNamespace.toString(), BSONObj());
+    }
+
+    /**
+     * Returns the number of the documents for the collection 'nss' in the config.sampledQueries
+     * collection.
+     */
+    int getSampledQueryDocumentsCount(const NamespaceString& nss) const {
+        return _getConfigDocumentsCount(NamespaceString::kConfigSampledQueriesNamespace, nss);
+    }
+
+    /**
+     * Asserts that there is a sampled read query document with the given sample id and that it has
+     * the given fields.
+     */
+    void assertSampledReadQueryDocument(const UUID& sampleId,
+                                        const NamespaceString& nss,
+                                        SampledReadCommandNameEnum cmdName,
+                                        const BSONObj& filter,
+                                        const BSONObj& collation) const {
+        auto doc = _getConfigDocument(NamespaceString::kConfigSampledQueriesNamespace, sampleId);
+        auto parsedQueryDoc =
+            SampledReadQueryDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc);
+
+        ASSERT_EQ(parsedQueryDoc.getNs(), nss);
+        ASSERT_EQ(parsedQueryDoc.getCollectionUuid(), getCollectionUUID(nss));
+        ASSERT_EQ(parsedQueryDoc.getSampleId(), sampleId);
+        ASSERT(parsedQueryDoc.getCmdName() == cmdName);
+        auto parsedCmd = SampledReadCommand::parse(IDLParserContext("QueryAnalysisWriterTest"),
+                                                   parsedQueryDoc.getCmd());
+        ASSERT_BSONOBJ_EQ(parsedCmd.getFilter(), filter);
+        ASSERT_BSONOBJ_EQ(parsedCmd.getCollation(), collation);
+    }
+
+    /**
+     * Asserts that there is a sampled write query document with the given sample id and that it
+     * has the given fields.
+     */
+    template <typename CommandRequestType>
+    void assertSampledWriteQueryDocument(const UUID& sampleId,
+                                         const NamespaceString& nss,
+                                         SampledWriteCommandNameEnum cmdName,
+                                         const CommandRequestType& expectedCmd) const {
+        auto doc = _getConfigDocument(NamespaceString::kConfigSampledQueriesNamespace, sampleId);
+        auto parsedQueryDoc =
+            SampledWriteQueryDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc);
+
+        ASSERT_EQ(parsedQueryDoc.getNs(), nss);
+        ASSERT_EQ(parsedQueryDoc.getCollectionUuid(), getCollectionUUID(nss));
+        ASSERT_EQ(parsedQueryDoc.getSampleId(), sampleId);
+        ASSERT(parsedQueryDoc.getCmdName() == cmdName);
+        auto parsedCmd = CommandRequestType::parse(IDLParserContext("QueryAnalysisWriterTest"),
+                                                   parsedQueryDoc.getCmd());
+        ASSERT_BSONOBJ_EQ(parsedCmd.toBSON({}), expectedCmd.toBSON({}));
+    }
+
+    /**
+     * Returns the number of the documents for the collection 'nss' in the
+     * config.sampledQueriesDiff collection.
+     */
+    int getDiffDocumentsCount(const NamespaceString& nss) const {
+        return _getConfigDocumentsCount(NamespaceString::kConfigSampledQueriesDiffNamespace, nss);
+    }
+
+    /**
+     * Asserts that there is a sampled diff document with the given sample id and that it has
+     * the given fields. The diff is compared ignoring field order.
+     */
+    void assertDiffDocument(const UUID& sampleId,
+                            const NamespaceString& nss,
+                            const BSONObj& expectedDiff) const {
+        auto doc =
+            _getConfigDocument(NamespaceString::kConfigSampledQueriesDiffNamespace, sampleId);
+        auto parsedDiffDoc =
+            SampledQueryDiffDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc);
+
+        ASSERT_EQ(parsedDiffDoc.getNs(), nss);
+        ASSERT_EQ(parsedDiffDoc.getCollectionUuid(), getCollectionUUID(nss));
+        ASSERT_EQ(parsedDiffDoc.getSampleId(), sampleId);
+        assertBsonObjEqualUnordered(parsedDiffDoc.getDiff(), expectedDiff);
+    }
+
+    const NamespaceString nss0{"testDb", "testColl0"};
+    const NamespaceString nss1{"testDb", "testColl1"};
+
+    // Test with both empty and non-empty filter and collation to verify that the
+    // QueryAnalysisWriter doesn't require filter or collation to be non-empty.
+    const BSONObj emptyFilter{};
+    const BSONObj emptyCollation{};
+
+    const BSONObj let = BSON("x" << 1);
+    // Test with EncryptionInformation to verify that QueryAnalysisWriter does not persist the
+    // WriteCommandRequestBase fields, especially this sensitive field.
+    const EncryptionInformation encryptionInformation{BSON("foo"
+                                                           << "bar")};
+
+private:
+    bool _getRandomBool() {
+        return rand() % 2 == 0;
+    }
+
+    /**
+     * Returns the number of the documents for the collection 'collNss' in the config collection
+     * 'configNss'.
+     */
+    int _getConfigDocumentsCount(const NamespaceString& configNss,
+                                 const NamespaceString& collNss) const {
+        DBDirectClient client(operationContext());
+        return client.count(configNss, BSON("ns" << collNss.toString()));
+    }
+
+    /**
+     * Returns the document with the given _id in the config collection 'configNss'. Fails the test
+     * if there is no such document.
+     */
+    BSONObj _getConfigDocument(const NamespaceString& configNss, const UUID& id) const {
+        DBDirectClient client(operationContext());
+
+        FindCommandRequest findRequest{configNss};
+        findRequest.setFilter(BSON("_id" << id));
+        auto cursor = client.find(std::move(findRequest));
+        ASSERT(cursor->more());
+        return cursor->next();
+    }
+
+    RAIIServerParameterControllerForTest _featureFlagController{"featureFlagAnalyzeShardKey", true};
+    FailPointEnableBlock _fp{"disableQueryAnalysisWriter"};
+};
+
+DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetIfFeatureFlagNotEnabled, "invariant") {
+    // Turn the feature flag off for the scope of this test; getting the writer is then expected to
+    // hit an invariant.
+    RAIIServerParameterControllerForTest disabledFeatureFlag{"featureFlagAnalyzeShardKey", false};
+    QueryAnalysisWriter::get(operationContext());
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetOnConfigServer, "invariant") {
+    // Getting the writer while the node's cluster role is ConfigServer is expected to hit an
+    // invariant.
+    serverGlobalParams.clusterRole = ClusterRole::ConfigServer;
+    QueryAnalysisWriter::get(operationContext());
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetOnNonShardServer, "invariant") {
+    // Getting the writer while the node has no cluster role is expected to hit an invariant.
+    serverGlobalParams.clusterRole = ClusterRole::None;
+    QueryAnalysisWriter::get(operationContext());
+}
+
+TEST_F(QueryAnalysisWriterTest, NoQueries) {
+    // Flushing without any buffered queries must not throw.
+    auto* opCtx = operationContext();
+    QueryAnalysisWriter::get(opCtx).flushQueriesForTest(opCtx);
+}
+
+TEST_F(QueryAnalysisWriterTest, FindQuery) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    // Buffers one sampled find query, flushes it, verifies the persisted document and then clears
+    // the sampled query documents so that the next case starts from an empty collection.
+    auto runFindCase = [&](const BSONObj& filter, const BSONObj& collation) {
+        const auto findSampleId = UUID::gen();
+
+        writer.addFindQuery(findSampleId, nss0, filter, collation).get();
+        ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+        writer.flushQueriesForTest(operationContext());
+        ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+        ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+        assertSampledReadQueryDocument(
+            findSampleId, nss0, SampledReadCommandNameEnum::kFind, filter, collation);
+
+        deleteSampledQueryDocuments();
+    };
+
+    runFindCase(makeNonEmptyFilter(), makeNonEmptyCollation());
+    runFindCase(makeNonEmptyFilter(), emptyCollation);
+    runFindCase(emptyFilter, makeNonEmptyCollation());
+    runFindCase(emptyFilter, emptyCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, CountQuery) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    // Buffers one sampled count query, flushes it, verifies the persisted document and then clears
+    // the sampled query documents so that the next case starts from an empty collection.
+    auto runCountCase = [&](const BSONObj& filter, const BSONObj& collation) {
+        const auto countSampleId = UUID::gen();
+
+        writer.addCountQuery(countSampleId, nss0, filter, collation).get();
+        ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+        writer.flushQueriesForTest(operationContext());
+        ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+        ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+        assertSampledReadQueryDocument(
+            countSampleId, nss0, SampledReadCommandNameEnum::kCount, filter, collation);
+
+        deleteSampledQueryDocuments();
+    };
+
+    runCountCase(makeNonEmptyFilter(), makeNonEmptyCollation());
+    runCountCase(makeNonEmptyFilter(), emptyCollation);
+    runCountCase(emptyFilter, makeNonEmptyCollation());
+    runCountCase(emptyFilter, emptyCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, DistinctQuery) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    // Buffers one sampled distinct query, flushes it, verifies the persisted document and then
+    // clears the sampled query documents so that the next case starts from an empty collection.
+    auto runDistinctCase = [&](const BSONObj& filter, const BSONObj& collation) {
+        const auto distinctSampleId = UUID::gen();
+
+        writer.addDistinctQuery(distinctSampleId, nss0, filter, collation).get();
+        ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+        writer.flushQueriesForTest(operationContext());
+        ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+        ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+        assertSampledReadQueryDocument(
+            distinctSampleId, nss0, SampledReadCommandNameEnum::kDistinct, filter, collation);
+
+        deleteSampledQueryDocuments();
+    };
+
+    runDistinctCase(makeNonEmptyFilter(), makeNonEmptyCollation());
+    runDistinctCase(makeNonEmptyFilter(), emptyCollation);
+    runDistinctCase(emptyFilter, makeNonEmptyCollation());
+    runDistinctCase(emptyFilter, emptyCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, AggregateQuery) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    // Buffers one sampled aggregate query, flushes it, verifies the persisted document and then
+    // clears the sampled query documents so that the next case starts from an empty collection.
+    auto runAggregateCase = [&](const BSONObj& filter, const BSONObj& collation) {
+        const auto aggSampleId = UUID::gen();
+
+        writer.addAggregateQuery(aggSampleId, nss0, filter, collation).get();
+        ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+        writer.flushQueriesForTest(operationContext());
+        ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+        ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+        assertSampledReadQueryDocument(
+            aggSampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation);
+
+        deleteSampledQueryDocuments();
+    };
+
+    runAggregateCase(makeNonEmptyFilter(), makeNonEmptyCollation());
+    runAggregateCase(makeNonEmptyFilter(), emptyCollation);
+    runAggregateCase(emptyFilter, makeNonEmptyCollation());
+    runAggregateCase(emptyFilter, emptyCollation);
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, UpdateQueryNotMarkedForSampling, "invariant") {
+    // None of the updates carries a sample id, so adding one is expected to hit an invariant.
+    auto [unsampledCmd, _] = makeUpdateCommandRequest(nss0, 1, {} /* markForSampling */);
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+    writer.addUpdateQuery(unsampledCmd, 0).get();
+}
+
+TEST_F(QueryAnalysisWriterTest, UpdateQueriesMarkedForSampling) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    // Three updates, of which the ones at index 0 and 2 are marked for sampling.
+    auto [updateCmd, sampledCmdsById] =
+        makeUpdateCommandRequest(nss0, 3, {0, 2} /* markForSampling */);
+    ASSERT_EQ(sampledCmdsById.size(), 2U);
+
+    writer.addUpdateQuery(updateCmd, 0).get();
+    writer.addUpdateQuery(updateCmd, 2).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 2);
+    writer.flushQueriesForTest(operationContext());
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    // Each sampled update must have been persisted as its own document.
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+    for (const auto& [id, sampledCmd] : sampledCmdsById) {
+        assertSampledWriteQueryDocument(
+            id, sampledCmd.getNamespace(), SampledWriteCommandNameEnum::kUpdate, sampledCmd);
+    }
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, DeleteQueryNotMarkedForSampling, "invariant") {
+    // None of the deletes carries a sample id, so adding one is expected to hit an invariant.
+    auto [unsampledCmd, _] = makeDeleteCommandRequest(nss0, 1, {} /* markForSampling */);
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+    writer.addDeleteQuery(unsampledCmd, 0).get();
+}
+
+TEST_F(QueryAnalysisWriterTest, DeleteQueriesMarkedForSampling) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    // Three deletes, of which the ones at index 1 and 2 are marked for sampling.
+    auto [deleteCmd, sampledCmdsById] =
+        makeDeleteCommandRequest(nss0, 3, {1, 2} /* markForSampling */);
+    ASSERT_EQ(sampledCmdsById.size(), 2U);
+
+    writer.addDeleteQuery(deleteCmd, 1).get();
+    writer.addDeleteQuery(deleteCmd, 2).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 2);
+    writer.flushQueriesForTest(operationContext());
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    // Each sampled delete must have been persisted as its own document.
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+    for (const auto& [id, sampledCmd] : sampledCmdsById) {
+        assertSampledWriteQueryDocument(
+            id, sampledCmd.getNamespace(), SampledWriteCommandNameEnum::kDelete, sampledCmd);
+    }
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryNotMarkedForSampling, "invariant") {
+    // The command carries no sample id, so adding it is expected to hit an invariant.
+    auto [unsampledCmd, _] =
+        makeFindAndModifyCommandRequest(nss0, true /* isUpdate */, false /* markForSampling */);
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+    writer.addFindAndModifyQuery(unsampledCmd).get();
+}
+
+TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryUpdateMarkedForSampling) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto [findAndModifyCmd, sampledCmdsById] =
+        makeFindAndModifyCommandRequest(nss0, true /* isUpdate */, true /* markForSampling */);
+    ASSERT_EQ(sampledCmdsById.size(), 1U);
+    const auto& [id, sampledCmd] = *sampledCmdsById.begin();
+
+    // Buffer the sampled command and flush it to disk.
+    writer.addFindAndModifyQuery(findAndModifyCmd).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+    writer.flushQueriesForTest(operationContext());
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+    assertSampledWriteQueryDocument(
+        id, sampledCmd.getNamespace(), SampledWriteCommandNameEnum::kFindAndModify, sampledCmd);
+}
+
+TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryRemoveMarkedForSampling) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto [findAndModifyCmd, sampledCmdsById] =
+        makeFindAndModifyCommandRequest(nss0, false /* isUpdate */, true /* markForSampling */);
+    ASSERT_EQ(sampledCmdsById.size(), 1U);
+    const auto& [id, sampledCmd] = *sampledCmdsById.begin();
+
+    // Buffer the sampled command and flush it to disk.
+    writer.addFindAndModifyQuery(findAndModifyCmd).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+    writer.flushQueriesForTest(operationContext());
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+    assertSampledWriteQueryDocument(
+        id, sampledCmd.getNamespace(), SampledWriteCommandNameEnum::kFindAndModify, sampledCmd);
+}
+
+TEST_F(QueryAnalysisWriterTest, MultipleQueriesAndCollections) {
+    // Verifies that a single flush persists the sampled queries for more than one collection and
+    // that each document is attributed to the collection of the query that produced it.
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    // Make nss0 have one query (the delete).
+    auto [originalDeleteCmd, expectedSampledDeleteCmds] =
+        makeDeleteCommandRequest(nss0, 3, {1} /* markForSampling */);
+    ASSERT_EQ(expectedSampledDeleteCmds.size(), 1U);
+    auto [deleteSampleId, expectedSampledDeleteCmd] = *expectedSampledDeleteCmds.begin();
+
+    // Make nss1 have two queries (the update and the count).
+    auto [originalUpdateCmd, expectedSampledUpdateCmds] =
+        makeUpdateCommandRequest(nss1, 1, {0} /* markForSampling */);
+    ASSERT_EQ(expectedSampledUpdateCmds.size(), 1U);
+    auto [updateSampleId, expectedSampledUpdateCmd] = *expectedSampledUpdateCmds.begin();
+
+    auto countSampleId = UUID::gen();
+    auto originalCountFilter = makeNonEmptyFilter();
+    auto originalCountCollation = makeNonEmptyCollation();
+
+    writer.addDeleteQuery(originalDeleteCmd, 1).get();
+    writer.addUpdateQuery(originalUpdateCmd, 0).get();
+    writer.addCountQuery(countSampleId, nss1, originalCountFilter, originalCountCollation).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 3);
+    writer.flushQueriesForTest(operationContext());
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+    assertSampledWriteQueryDocument(deleteSampleId,
+                                    expectedSampledDeleteCmd.getNamespace(),
+                                    SampledWriteCommandNameEnum::kDelete,
+                                    expectedSampledDeleteCmd);
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 2);
+    assertSampledWriteQueryDocument(updateSampleId,
+                                    expectedSampledUpdateCmd.getNamespace(),
+                                    SampledWriteCommandNameEnum::kUpdate,
+                                    expectedSampledUpdateCmd);
+    assertSampledReadQueryDocument(countSampleId,
+                                   nss1,
+                                   SampledReadCommandNameEnum::kCount,
+                                   originalCountFilter,
+                                   originalCountCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, DuplicateQueries) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    const auto findId = UUID::gen();
+    const auto findFilter = makeNonEmptyFilter();
+    const auto findCollation = makeNonEmptyCollation();
+
+    auto [updateCmd, sampledUpdateCmdsById] =
+        makeUpdateCommandRequest(nss0, 1, {0} /* markForSampling */);
+    ASSERT_EQ(sampledUpdateCmdsById.size(), 1U);
+    const auto& [updateId, sampledUpdateCmd] = *sampledUpdateCmdsById.begin();
+
+    const auto countId = UUID::gen();
+    const auto countFilter = makeNonEmptyFilter();
+    const auto countCollation = makeNonEmptyCollation();
+
+    // First round: persist the find query on its own.
+    writer.addFindQuery(findId, nss0, findFilter, findCollation).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+    writer.flushQueriesForTest(operationContext());
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+    assertSampledReadQueryDocument(
+        findId, nss0, SampledReadCommandNameEnum::kFind, findFilter, findCollation);
+
+    // Second round: the same find query is buffered again between two new queries. The flush
+    // must still empty the buffer, and only the two new queries add documents.
+    writer.addUpdateQuery(updateCmd, 0).get();
+    writer.addFindQuery(findId, nss0, findFilter, findCollation).get();  // This is a duplicate.
+    writer.addCountQuery(countId, nss0, countFilter, countCollation).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 3);
+    writer.flushQueriesForTest(operationContext());
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 3);
+    assertSampledWriteQueryDocument(updateId,
+                                    sampledUpdateCmd.getNamespace(),
+                                    SampledWriteCommandNameEnum::kUpdate,
+                                    sampledUpdateCmd);
+    assertSampledReadQueryDocument(
+        findId, nss0, SampledReadCommandNameEnum::kFind, findFilter, findCollation);
+    assertSampledReadQueryDocument(
+        countId, nss0, SampledReadCommandNameEnum::kCount, countFilter, countCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, QueriesMultipleBatches_MaxBatchSize) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    // Five queries are buffered while 'queryAnalysisWriterMaxBatchSize' is set to 2.
+    RAIIServerParameterControllerForTest batchSizeOverride{"queryAnalysisWriterMaxBatchSize", 2};
+    const int kNumQueries = 5;
+
+    std::vector<std::tuple<UUID, BSONObj, BSONObj>> bufferedQueries;
+    for (int idx = 0; idx < kNumQueries; ++idx) {
+        const auto id = UUID::gen();
+        const auto queryFilter = makeNonEmptyFilter();
+        const auto queryCollation = makeNonEmptyCollation();
+        writer.addAggregateQuery(id, nss0, queryFilter, queryCollation).get();
+        bufferedQueries.emplace_back(id, queryFilter, queryCollation);
+    }
+    ASSERT_EQ(writer.getQueriesCountForTest(), kNumQueries);
+    writer.flushQueriesForTest(operationContext());
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    // All the queries must have been persisted.
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), kNumQueries);
+    for (const auto& [id, queryFilter, queryCollation] : bufferedQueries) {
+        assertSampledReadQueryDocument(
+            id, nss0, SampledReadCommandNameEnum::kAggregate, queryFilter, queryCollation);
+    }
+}
+
+TEST_F(QueryAnalysisWriterTest, QueriesMultipleBatches_MaxBSONObjSize) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    // Each filter has a field name that is half of BSONObjMaxUserSize long, so the three queries
+    // together exceed BSONObjMaxUserSize.
+    const int kNumQueries = 3;
+    std::vector<std::tuple<UUID, BSONObj, BSONObj>> bufferedQueries;
+    for (int idx = 0; idx < kNumQueries; ++idx) {
+        const auto id = UUID::gen();
+        const auto largeFilter = BSON(std::string(BSONObjMaxUserSize / 2, 'a') << 1);
+        const auto queryCollation = makeNonEmptyCollation();
+        writer.addAggregateQuery(id, nss0, largeFilter, queryCollation).get();
+        bufferedQueries.emplace_back(id, largeFilter, queryCollation);
+    }
+    ASSERT_EQ(writer.getQueriesCountForTest(), kNumQueries);
+    writer.flushQueriesForTest(operationContext());
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    // All the queries must have been persisted.
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), kNumQueries);
+    for (const auto& [id, largeFilter, queryCollation] : bufferedQueries) {
+        assertSampledReadQueryDocument(
+            id, nss0, SampledReadCommandNameEnum::kAggregate, largeFilter, queryCollation);
+    }
+}
+
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddReadIfExceedsSizeLimit) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    const int kMemoryLimitBytes = 1024;
+    RAIIServerParameterControllerForTest memoryLimitOverride{
+        "queryAnalysisWriterMaxMemoryUsageBytes", kMemoryLimitBytes};
+
+    // Each filter's field name takes up half of the memory limit.
+    const auto firstId = UUID::gen();
+    const auto firstFilter = BSON(std::string(kMemoryLimitBytes / 2, 'a') << 1);
+    const auto firstCollation = makeNonEmptyCollation();
+
+    const auto secondId = UUID::gen();
+    const auto secondFilter = BSON(std::string(kMemoryLimitBytes / 2, 'b') << 1);
+    const auto secondCollation = makeNonEmptyCollation();
+
+    writer.addFindQuery(firstId, nss0, firstFilter, firstCollation).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+    // Adding the next query causes the size to exceed the limit, so no explicit flush is issued
+    // here and the buffer is expected to be empty afterwards.
+    writer.addAggregateQuery(secondId, nss1, secondFilter, secondCollation).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+    assertSampledReadQueryDocument(
+        firstId, nss0, SampledReadCommandNameEnum::kFind, firstFilter, firstCollation);
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 1);
+    assertSampledReadQueryDocument(
+        secondId, nss1, SampledReadCommandNameEnum::kAggregate, secondFilter, secondCollation);
+}
+
+// Adds two sampled update statements (indexes 0 and 2 of a three-statement command) whose filter
+// field name is half the size of 'queryAnalysisWriterMaxMemoryUsageBytes', and checks that the
+// buffer is empty and both sampled documents exist after the second add.
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddUpdateIfExceedsSizeLimit) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto memoryLimitBytes = 1024;
+    RAIIServerParameterControllerForTest memoryLimitController{
+        "queryAnalysisWriterMaxMemoryUsageBytes", memoryLimitBytes};
+    auto [updateCmd, sampledCmdsById] =
+        makeUpdateCommandRequest(nss0,
+                                 3,
+                                 {0, 2} /* markForSampling */,
+                                 std::string(memoryLimitBytes / 2, 'a') /* filterFieldName */);
+    ASSERT_EQ(sampledCmdsById.size(), 2U);
+
+    writer.addUpdateQuery(updateCmd, 0).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+    // Adding the next query causes the size to exceed the limit.
+    writer.addUpdateQuery(updateCmd, 2).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+    for (const auto& [id, sampledCmd] : sampledCmdsById) {
+        assertSampledWriteQueryDocument(
+            id, sampledCmd.getNamespace(), SampledWriteCommandNameEnum::kUpdate, sampledCmd);
+    }
+}
+
+// Adds two sampled delete statements (indexes 0 and 1 of a three-statement command) whose filter
+// field name is half the size of 'queryAnalysisWriterMaxMemoryUsageBytes', and checks that the
+// buffer is empty and both sampled documents exist after the second add.
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddDeleteIfExceedsSizeLimit) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto memoryLimitBytes = 1024;
+    RAIIServerParameterControllerForTest memoryLimitController{
+        "queryAnalysisWriterMaxMemoryUsageBytes", memoryLimitBytes};
+    auto [deleteCmd, sampledCmdsById] =
+        makeDeleteCommandRequest(nss0,
+                                 3,
+                                 {0, 1} /* markForSampling */,
+                                 std::string(memoryLimitBytes / 2, 'a') /* filterFieldName */);
+    ASSERT_EQ(sampledCmdsById.size(), 2U);
+
+    writer.addDeleteQuery(deleteCmd, 0).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+    // Adding the next query causes the size to exceed the limit.
+    writer.addDeleteQuery(deleteCmd, 1).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+    for (const auto& [id, sampledCmd] : sampledCmdsById) {
+        assertSampledWriteQueryDocument(
+            id, sampledCmd.getNamespace(), SampledWriteCommandNameEnum::kDelete, sampledCmd);
+    }
+}
+
+// Adds two sampled findAndModify commands (an update on nss0 and a non-update on nss1) whose
+// filter field names are each half the size of 'queryAnalysisWriterMaxMemoryUsageBytes', and
+// checks that the buffer is empty and both sampled documents exist after the second add.
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddFindAndModifyIfExceedsSizeLimit) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto maxMemoryUsageBytes = 1024;
+    RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+                                                        maxMemoryUsageBytes};
+
+    auto [originalCmd0, expectedSampledCmds0] = makeFindAndModifyCommandRequest(
+        nss0,
+        true /* isUpdate */,
+        true /* markForSampling */,
+        std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */);
+    ASSERT_EQ(expectedSampledCmds0.size(), 1U);
+    auto [sampleId0, expectedSampledCmd0] = *expectedSampledCmds0.begin();
+
+    auto [originalCmd1, expectedSampledCmds1] = makeFindAndModifyCommandRequest(
+        nss1,
+        false /* isUpdate */,
+        true /* markForSampling */,
+        std::string(maxMemoryUsageBytes / 2, 'b') /* filterFieldName */);
+    // Check the second command's expectations before dereferencing begin() below.
+    ASSERT_EQ(expectedSampledCmds1.size(), 1U);
+    auto [sampleId1, expectedSampledCmd1] = *expectedSampledCmds1.begin();
+
+    writer.addFindAndModifyQuery(originalCmd0).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+    // Adding the next query causes the size to exceed the limit.
+    writer.addFindAndModifyQuery(originalCmd1).get();
+    ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+    assertSampledWriteQueryDocument(sampleId0,
+                                    expectedSampledCmd0.getNamespace(),
+                                    SampledWriteCommandNameEnum::kFindAndModify,
+                                    expectedSampledCmd0);
+    ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 1);
+    assertSampledWriteQueryDocument(sampleId1,
+                                    expectedSampledCmd1.getNamespace(),
+                                    SampledWriteCommandNameEnum::kFindAndModify,
+                                    expectedSampledCmd1);
+}
+
+// Buffers eight find queries, flushes them from a separate thread in batches of three while the
+// "hangAfterCollectionInserts" and "failCollectionInserts" fail points inject three insert
+// failures, and checks the buffer/collection counts after the failed flush and after a
+// follow-up flush.
+TEST_F(QueryAnalysisWriterTest, AddQueriesBackAfterWriteError) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto originalFilter = makeNonEmptyFilter();
+ auto originalCollation = makeNonEmptyCollation();
+ auto numQueries = 8;
+
+ std::vector<UUID> sampleIds0;
+ for (auto i = 0; i < numQueries; i++) {
+ sampleIds0.push_back(UUID::gen());
+ writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get();
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries);
+
+ // Force the documents to get inserted in three batches of size 3, 3 and 2, respectively.
+ RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 3};
+
+ // Hang after inserting the documents in the first batch.
+ auto hangFp = globalFailPointRegistry().find("hangAfterCollectionInserts");
+ auto hangTimesEntered = hangFp->setMode(FailPoint::alwaysOn, 0);
+
+ // Run the flush on its own thread (with its own client and operation context) so that this
+ // thread can reconfigure the fail points while the flush is hanging.
+ auto future = stdx::async(stdx::launch::async, [&] {
+ ThreadClient tc(getServiceContext());
+ auto opCtx = makeOperationContext();
+ writer.flushQueriesForTest(opCtx.get());
+ });
+
+ hangFp->waitForTimesEntered(hangTimesEntered + 1);
+ // Force the second batch to fail so that it falls back to inserting one document at a time in
+ // order, and then force the first and second document in the batch to fail.
+ auto failFp = globalFailPointRegistry().find("failCollectionInserts");
+ failFp->setMode(FailPoint::nTimes, 3);
+ hangFp->setMode(FailPoint::off, 0);
+
+ future.get();
+ // Verify that all the documents other than the ones in the first batch got added back to the
+ // buffer after the error. That is, the error caused the last document in the second batch to
+ // get added to buffer also although it was successfully inserted since the writer did not have
+ // a way to tell if the error caused the entire command to fail early.
+ ASSERT_EQ(writer.getQueriesCountForTest(), 5);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 4);
+
+ // Flush the remaining documents. If the documents were not added back correctly, some
+ // documents would be missing and the checks below would fail.
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries);
+ for (const auto& sampleId : sampleIds0) {
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation);
+ }
+}
+
+// Flushes three find queries for nss0, then buffers five new queries for nss1 interleaved with
+// re-adds of the three already-flushed nss0 sample ids, injects insert failures through the
+// "failCollectionInserts" and "hangAfterCollectionInserts" fail points, and checks that only the
+// nss1 queries remain buffered after the failed flush.
+TEST_F(QueryAnalysisWriterTest, RemoveDuplicatesFromBufferAfterWriteError) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto originalFilter = makeNonEmptyFilter();
+ auto originalCollation = makeNonEmptyCollation();
+
+ auto numQueries0 = 3;
+
+ std::vector<UUID> sampleIds0;
+ for (auto i = 0; i < numQueries0; i++) {
+ sampleIds0.push_back(UUID::gen());
+ writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get();
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries0);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries0);
+ for (const auto& sampleId : sampleIds0) {
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation);
+ }
+
+ auto numQueries1 = 5;
+
+ std::vector<UUID> sampleIds1;
+ for (auto i = 0; i < numQueries1; i++) {
+ sampleIds1.push_back(UUID::gen());
+ writer.addFindQuery(sampleIds1[i], nss1, originalFilter, originalCollation).get();
+ // This is a duplicate.
+ if (i < numQueries0) {
+ writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get();
+ }
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries0 + numQueries1);
+
+ // Force the batch to fail so that it falls back to inserting one document at a time in order.
+ auto failFp = globalFailPointRegistry().find("failCollectionInserts");
+ failFp->setMode(FailPoint::nTimes, 1);
+
+ // Hang after inserting the first non-duplicate document.
+ auto hangFp = globalFailPointRegistry().find("hangAfterCollectionInserts");
+ auto hangTimesEntered = hangFp->setMode(FailPoint::alwaysOn, 0);
+
+ // Run the flush on its own thread (with its own client and operation context) so that this
+ // thread can reconfigure the fail points while the flush is hanging.
+ auto future = stdx::async(stdx::launch::async, [&] {
+ ThreadClient tc(getServiceContext());
+ auto opCtx = makeOperationContext();
+ writer.flushQueriesForTest(opCtx.get());
+ });
+
+ hangFp->waitForTimesEntered(hangTimesEntered + 1);
+ // Force the next non-duplicate document to fail to insert.
+ failFp->setMode(FailPoint::nTimes, 1);
+ hangFp->setMode(FailPoint::off, 0);
+
+ future.get();
+ // Verify that the duplicate documents did not get added back to the buffer after the error.
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries1);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), numQueries1 - 1);
+
+ // Flush the remaining documents. If the documents were not added back correctly, the document
+ // that previously failed to insert would be missing and the checks below would fail.
+ failFp->setMode(FailPoint::off, 0);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), numQueries1);
+ for (const auto& sampleId : sampleIds1) {
+ assertSampledReadQueryDocument(
+ sampleId, nss1, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation);
+ }
+}
+
+// Flushes the diffs buffer when no diff has been added. The test passes as long as the flush
+// returns without throwing.
+// NOTE(review): assumes flushing an empty diffs buffer is meant to be a no-op -- confirm.
+TEST_F(QueryAnalysisWriterTest, NoDiffs) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+    // Flush the diffs buffer (not the queries buffer) so that this test covers the diff path.
+    writer.flushDiffsForTest(operationContext());
+}
+
+// Adds a single diff for nss0, flushes it, and checks that the persisted diff document matches
+// the inline diff computed from the same pre- and post-images.
+TEST_F(QueryAnalysisWriterTest, DiffsBasic) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto collectionUuid = getCollectionUUID(nss0);
+    auto id = UUID::gen();
+    auto before = BSON("a" << 0);
+    auto after = BSON("a" << 1);
+
+    writer.addDiff(id, nss0, collectionUuid, before, after).get();
+    ASSERT_EQ(writer.getDiffsCountForTest(), 1);
+    writer.flushDiffsForTest(operationContext());
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+    ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+    auto expectedDiff = *doc_diff::computeInlineDiff(before, after);
+    assertDiffDocument(id, nss0, expectedDiff);
+}
+
+// Adds one diff for nss0 and two diffs for nss1, flushes them together, and checks the persisted
+// diff documents for each namespace.
+TEST_F(QueryAnalysisWriterTest, DiffsMultipleQueriesAndCollections) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    // nss0 gets a diff for a single query.
+    auto uuidForNss0 = getCollectionUUID(nss0);
+
+    auto idA = UUID::gen();
+    auto beforeA = BSON("a" << 0 << "b" << 0 << "c" << 0);
+    auto afterA = BSON("a" << 0 << "b" << 1 << "d" << 1);
+
+    // nss1 gets diffs for two queries.
+    auto uuidForNss1 = getCollectionUUID(nss1);
+
+    auto idB = UUID::gen();
+    auto beforeB = BSON("a" << 1 << "b" << BSON_ARRAY(1) << "d" << BSON("e" << 1));
+    auto afterB = BSON("a" << 1 << "b" << BSON_ARRAY(1 << 2) << "d" << BSON("e" << 2));
+
+    auto idC = UUID::gen();
+    auto beforeC = BSON("a" << BSONObj());
+    auto afterC = BSON("a" << BSON("b" << 2));
+
+    writer.addDiff(idA, nss0, uuidForNss0, beforeA, afterA).get();
+    writer.addDiff(idB, nss1, uuidForNss1, beforeB, afterB).get();
+    writer.addDiff(idC, nss1, uuidForNss1, beforeC, afterC).get();
+    ASSERT_EQ(writer.getDiffsCountForTest(), 3);
+    writer.flushDiffsForTest(operationContext());
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+    ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+    assertDiffDocument(idA, nss0, *doc_diff::computeInlineDiff(beforeA, afterA));
+
+    ASSERT_EQ(getDiffDocumentsCount(nss1), 2);
+    assertDiffDocument(idB, nss1, *doc_diff::computeInlineDiff(beforeB, afterB));
+    assertDiffDocument(idC, nss1, *doc_diff::computeInlineDiff(beforeC, afterC));
+}
+
+// Flushes one diff, then adds two new diffs plus a re-add of the already-flushed sample id, and
+// checks that the collection ends up with exactly three diff documents.
+TEST_F(QueryAnalysisWriterTest, DuplicateDiffs) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto collectionUuid = getCollectionUUID(nss0);
+
+    auto idA = UUID::gen();
+    auto beforeA = BSON("a" << 0);
+    auto afterA = BSON("a" << 1);
+
+    auto idB = UUID::gen();
+    auto beforeB = BSON("a" << 1 << "b" << BSON_ARRAY(1));
+    auto afterB = BSON("a" << 1 << "b" << BSON_ARRAY(1 << 2));
+
+    auto idC = UUID::gen();
+    auto beforeC = BSON("a" << BSONObj());
+    auto afterC = BSON("a" << BSON("b" << 2));
+
+    // First round: a single diff is added and flushed.
+    writer.addDiff(idA, nss0, collectionUuid, beforeA, afterA).get();
+    ASSERT_EQ(writer.getDiffsCountForTest(), 1);
+    writer.flushDiffsForTest(operationContext());
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+    ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+    assertDiffDocument(idA, nss0, *doc_diff::computeInlineDiff(beforeA, afterA));
+
+    // Second round: two new diffs with a duplicate of the first one in between.
+    writer.addDiff(idB, nss0, collectionUuid, beforeB, afterB).get();
+    writer.addDiff(idA, nss0, collectionUuid, beforeA, afterA).get();  // This is a duplicate.
+    writer.addDiff(idC, nss0, collectionUuid, beforeC, afterC).get();
+    ASSERT_EQ(writer.getDiffsCountForTest(), 3);
+    writer.flushDiffsForTest(operationContext());
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+    ASSERT_EQ(getDiffDocumentsCount(nss0), 3);
+    assertDiffDocument(idA, nss0, *doc_diff::computeInlineDiff(beforeA, afterA));
+    assertDiffDocument(idB, nss0, *doc_diff::computeInlineDiff(beforeB, afterB));
+    assertDiffDocument(idC, nss0, *doc_diff::computeInlineDiff(beforeC, afterC));
+}
+
+// Adds five diffs with 'queryAnalysisWriterMaxBatchSize' set to 2, flushes them, and checks that
+// every diff document was persisted.
+TEST_F(QueryAnalysisWriterTest, DiffsMultipleBatches_MaxBatchSize) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    RAIIServerParameterControllerForTest batchSizeController{"queryAnalysisWriterMaxBatchSize", 2};
+    auto diffCount = 5;
+    auto collectionUuid = getCollectionUUID(nss0);
+
+    std::vector<std::pair<UUID, BSONObj>> expectedDiffsById;
+    for (auto idx = 0; idx < diffCount; idx++) {
+        auto id = UUID::gen();
+        auto before = BSON("a" << 0);
+        // Give each post-image a distinct field name.
+        auto after = BSON(("a" + std::to_string(idx)) << 1);
+        writer.addDiff(id, nss0, collectionUuid, before, after).get();
+        expectedDiffsById.push_back({id, *doc_diff::computeInlineDiff(before, after)});
+    }
+    ASSERT_EQ(writer.getDiffsCountForTest(), diffCount);
+    writer.flushDiffsForTest(operationContext());
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+    ASSERT_EQ(getDiffDocumentsCount(nss0), diffCount);
+    for (const auto& [id, expectedDiff] : expectedDiffsById) {
+        assertDiffDocument(id, nss0, expectedDiff);
+    }
+}
+
+// Adds three diffs whose post-images each have a field name of BSONObjMaxUserSize / 2
+// characters, flushes them, and checks that every diff document was persisted.
+TEST_F(QueryAnalysisWriterTest, DiffsMultipleBatches_MaxBSONObjSize) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto diffCount = 3;
+    auto collectionUuid = getCollectionUUID(nss0);
+
+    std::vector<std::pair<UUID, BSONObj>> expectedDiffsById;
+    for (auto idx = 0; idx < diffCount; idx++) {
+        auto id = UUID::gen();
+        auto before = BSON("a" << 0);
+        auto after = BSON(std::string(BSONObjMaxUserSize / 2, 'a') << 1);
+        writer.addDiff(id, nss0, collectionUuid, before, after).get();
+        expectedDiffsById.push_back({id, *doc_diff::computeInlineDiff(before, after)});
+    }
+    ASSERT_EQ(writer.getDiffsCountForTest(), diffCount);
+    writer.flushDiffsForTest(operationContext());
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+    ASSERT_EQ(getDiffDocumentsCount(nss0), diffCount);
+    for (const auto& [id, expectedDiff] : expectedDiffsById) {
+        assertDiffDocument(id, nss0, expectedDiff);
+    }
+}
+
+// Adds two diffs whose post-image field names are each half the size of
+// 'queryAnalysisWriterMaxMemoryUsageBytes', and checks that the buffer is empty and both diff
+// documents exist after the second add.
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddDiffIfExceedsSizeLimit) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto memoryLimitBytes = 1024;
+    RAIIServerParameterControllerForTest memoryLimitController{
+        "queryAnalysisWriterMaxMemoryUsageBytes", memoryLimitBytes};
+
+    // Diff for nss0.
+    auto uuidForNss0 = getCollectionUUID(nss0);
+    auto idA = UUID::gen();
+    auto beforeA = BSON("a" << 0);
+    auto afterA = BSON(std::string(memoryLimitBytes / 2, 'a') << 1);
+
+    // Diff for nss1.
+    auto uuidForNss1 = getCollectionUUID(nss1);
+    auto idB = UUID::gen();
+    auto beforeB = BSON("a" << 0);
+    auto afterB = BSON(std::string(memoryLimitBytes / 2, 'b') << 1);
+
+    writer.addDiff(idA, nss0, uuidForNss0, beforeA, afterA).get();
+    ASSERT_EQ(writer.getDiffsCountForTest(), 1);
+    // Adding the next diff causes the size to exceed the limit.
+    writer.addDiff(idB, nss1, uuidForNss1, beforeB, afterB).get();
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+    ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+    assertDiffDocument(idA, nss0, *doc_diff::computeInlineDiff(beforeA, afterA));
+    ASSERT_EQ(getDiffDocumentsCount(nss1), 1);
+    assertDiffDocument(idB, nss1, *doc_diff::computeInlineDiff(beforeB, afterB));
+}
+
+// Adds a diff whose pre- and post-images are identical and checks that nothing is buffered or
+// persisted for it.
+TEST_F(QueryAnalysisWriterTest, DiffEmpty) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto collectionUuid = getCollectionUUID(nss0);
+    auto id = UUID::gen();
+    auto before = BSON("a" << 1);
+    auto after = before;
+
+    writer.addDiff(id, nss0, collectionUuid, before, after).get();
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+    writer.flushDiffsForTest(operationContext());
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+    ASSERT_EQ(getDiffDocumentsCount(nss0), 0);
+}
+
+// Adds a diff between a pre-image with a field name of BSONObjMaxUserSize characters and an
+// empty post-image, and checks that nothing is buffered or persisted for it.
+TEST_F(QueryAnalysisWriterTest, DiffExceedsSizeLimit) {
+    auto& writer = QueryAnalysisWriter::get(operationContext());
+
+    auto collectionUuid = getCollectionUUID(nss0);
+    auto id = UUID::gen();
+    auto before = BSON(std::string(BSONObjMaxUserSize, 'a') << 1);
+    auto after = BSONObj();
+
+    writer.addDiff(id, nss0, collectionUuid, before, after).get();
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+    writer.flushDiffsForTest(operationContext());
+    ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+    ASSERT_EQ(getDiffDocumentsCount(nss0), 0);
+}
+
+} // namespace
+} // namespace analyze_shard_key
+} // namespace mongo
diff --git a/src/mongo/db/s/range_deleter_service.cpp b/src/mongo/db/s/range_deleter_service.cpp
index acabcc40bea..ca80b24aa3e 100644
--- a/src/mongo/db/s/range_deleter_service.cpp
+++ b/src/mongo/db/s/range_deleter_service.cpp
@@ -86,6 +86,50 @@ RangeDeleterService* RangeDeleterService::get(OperationContext* opCtx) {
return get(opCtx->getServiceContext());
}
+// Spawns the internal thread that runs _runRangeDeletions().
+//
+// The thread is started from the constructor body rather than from the member-initializer list:
+// `_thread` is declared before `_state` in the class, so starting the thread while members are
+// still being initialized would let _runRangeDeletions() read `_state` before it is initialized.
+// The `opCtx` parameter is not used by this constructor.
+RangeDeleterService::ReadyRangeDeletionsProcessor::ReadyRangeDeletionsProcessor(
+    OperationContext* opCtx) {
+    _thread = stdx::thread([this] { _runRangeDeletions(); });
+}
+
+// Requests shutdown, joins the internal thread, and then asserts that the thread's operation
+// context holder has been reset.
+RangeDeleterService::ReadyRangeDeletionsProcessor::~ReadyRangeDeletionsProcessor() {
+ // Move to the stopped state (and kill the thread's operation context, if any) before joining.
+ shutdown();
+ invariant(_thread.joinable());
+ _thread.join();
+ invariant(!_threadOpCtxHolder,
+ "Thread operation context is still alive after joining main thread");
+}
+
+// Moves the processor to the kStopped state and, if the thread's operation context exists, marks
+// it as killed with ErrorCodes::Interrupted. Does nothing if the processor is already stopped.
+void RangeDeleterService::ReadyRangeDeletionsProcessor::shutdown() {
+ stdx::lock_guard<Latch> lock(_mutex);
+ if (_state == kStopped)
+ return;
+
+ _state = kStopped;
+
+ if (_threadOpCtxHolder) {
+ // The client lock is taken before marking the operation context as killed.
+ stdx::lock_guard<Client> scopedClientLock(*_threadOpCtxHolder->getClient());
+ _threadOpCtxHolder->markKilled(ErrorCodes::Interrupted);
+ }
+}
+
+// Reports, under the processor mutex, whether the processor is in the kStopped state.
+bool RangeDeleterService::ReadyRangeDeletionsProcessor::_stopRequested() const {
+    stdx::lock_guard<Latch> lk(_mutex);
+    const bool stopped = (_state == kStopped);
+    return stopped;
+}
+
+// Pushes a copy of `rdt` onto the queue and notifies all waiters on the condition variable. The
+// processor must be in the kRunning state.
+void RangeDeleterService::ReadyRangeDeletionsProcessor::emplaceRangeDeletion(
+    const RangeDeletionTask& rdt) {
+    stdx::lock_guard<Latch> lk(_mutex);
+    invariant(_state == kRunning);
+    _queue.push(rdt);
+    _condVar.notify_all();
+}
+
+// Pops the element at the head of the queue, which is expected to be non-empty.
+void RangeDeleterService::ReadyRangeDeletionsProcessor::_completedRangeDeletion() {
+    stdx::lock_guard<Latch> lk(_mutex);
+    dassert(!_queue.empty());
+    _queue.pop();
+}
+
void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() {
Client::initThread(kRangeDeletionThreadName);
{
@@ -93,21 +137,22 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() {
cc().setSystemOperationKillableByStepdown(lk);
}
- auto opCtx = [&] {
- stdx::unique_lock<Latch> lock(_mutex);
+ {
+ stdx::lock_guard<Latch> lock(_mutex);
+ if (_state != kRunning) {
+ return;
+ }
_threadOpCtxHolder = cc().makeOperationContext();
- _condVar.notify_all();
- return (*_threadOpCtxHolder).get();
- }();
+ }
- opCtx->setAlwaysInterruptAtStepDownOrUp_UNSAFE();
+ auto opCtx = _threadOpCtxHolder.get();
ON_BLOCK_EXIT([this]() {
- stdx::unique_lock<Latch> lock(_mutex);
+ stdx::lock_guard<Latch> lock(_mutex);
_threadOpCtxHolder.reset();
});
- while (opCtx->checkForInterruptNoAssert().isOK()) {
+ while (!_stopRequested()) {
{
stdx::unique_lock<Latch> lock(_mutex);
try {
@@ -223,7 +268,7 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() {
// Release the thread only in case the operation context has been interrupted, as
// interruption only happens on shutdown/stepdown (this is fine because range
// deletions will be resumed on the next step up)
- if (!opCtx->checkForInterruptNoAssert().isOK()) {
+ if (_stopRequested()) {
break;
}
@@ -237,6 +282,17 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() {
}
}
+// Registers a RangeDeleterServiceOpObserver on the service context's op observer registry,
+// unless the resumable range deleter is disabled or the range deleter service feature flag is
+// off.
+void RangeDeleterService::onStartup(OperationContext* opCtx) {
+    const bool serviceEnabled = !disableResumableRangeDeleter.load() &&
+        feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV();
+    if (!serviceEnabled) {
+        return;
+    }
+
+    auto* registry =
+        checked_cast<OpObserverRegistry*>(opCtx->getServiceContext()->getOpObserver());
+    registry->addObserver(std::make_unique<RangeDeleterServiceOpObserver>());
+}
+
void RangeDeleterService::onStepUpComplete(OperationContext* opCtx, long long term) {
if (!feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) {
return;
@@ -249,22 +305,14 @@ void RangeDeleterService::onStepUpComplete(OperationContext* opCtx, long long te
return;
}
+ // Wait until all tasks and thread from previous term drain
+ _joinAndResetState();
+
auto lock = _acquireMutexUnconditionally();
dassert(_state == kDown, "Service expected to be down before stepping up");
_state = kInitializing;
- if (_executor) {
- // Join previously shutted down executor before reinstantiating it
- _executor->join();
- _executor.reset();
- } else {
- // Initializing the op observer, only executed once at the first step-up
- auto opObserverRegistry =
- checked_cast<OpObserverRegistry*>(opCtx->getServiceContext()->getOpObserver());
- opObserverRegistry->addObserver(std::make_unique<RangeDeleterServiceOpObserver>());
- }
-
const std::string kExecName("RangeDeleterServiceExecutor");
auto net = executor::makeNetworkInterface(kExecName);
auto pool = std::make_unique<executor::NetworkInterfaceThreadPool>(net.get());
@@ -303,9 +351,14 @@ void RangeDeleterService::_recoverRangeDeletionsOnStepUp(OperationContext* opCtx
LOGV2(6834800, "Resubmitting range deletion tasks");
- ScopedRangeDeleterLock rangeDeleterLock(opCtx);
- DBDirectClient client(opCtx);
+ // The Scoped lock is needed to serialize with concurrent range deletions
+ ScopedRangeDeleterLock rangeDeleterLock(opCtx, MODE_S);
+ // The collection lock is needed to serialize with donors trying to
+ // schedule local range deletions by updating the 'pending' field
+ AutoGetCollection rangeDeletionLock(
+ opCtx, NamespaceString::kRangeDeletionNamespace, MODE_S);
+ DBDirectClient client(opCtx);
int nRescheduledTasks = 0;
// (1) register range deletion tasks marked as "processing"
@@ -370,47 +423,55 @@ void RangeDeleterService::_recoverRangeDeletionsOnStepUp(OperationContext* opCtx
.semi();
}
-void RangeDeleterService::_stopService(bool joinExecutor) {
- if (!feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) {
- return;
- }
+void RangeDeleterService::_joinAndResetState() {
+ invariant(_state == kDown);
+ // Join the thread spawned on step-up to resume range deletions
+ _stepUpCompletedFuture.getNoThrow().ignore();
- {
- auto lock = _acquireMutexUnconditionally();
- _state = kDown;
- if (_initOpCtxHolder) {
- stdx::lock_guard<Client> lk(*_initOpCtxHolder->getClient());
- _initOpCtxHolder->markKilled(ErrorCodes::Interrupted);
- }
+ // Join and destruct the executor
+ if (_executor) {
+ _executor->join();
+ _executor.reset();
}
- // Join the thread spawned on step-up to resume range deletions
- _stepUpCompletedFuture.getNoThrow().ignore();
+ // Join and destruct the processor
+ _readyRangeDeletionsProcessorPtr.reset();
+
+ // Clear range deletions potentially created during recovery
+ _rangeDeletionTasks.clear();
+}
+void RangeDeleterService::_stopService() {
auto lock = _acquireMutexUnconditionally();
+ if (_state == kDown)
+ return;
+
+ _state = kDown;
+ if (_initOpCtxHolder) {
+ stdx::lock_guard<Client> lk(*_initOpCtxHolder->getClient());
+ _initOpCtxHolder->markKilled(ErrorCodes::Interrupted);
+ }
- // It may happen for the `onStepDown` hook to be invoked on a SECONDARY node transitioning
- // to ROLLBACK, hence the executor may have never been initialized
if (_executor) {
_executor->shutdown();
- if (joinExecutor) {
- _executor->join();
- }
}
- // Destroy the range deletion processor in order to stop range deletions
- _readyRangeDeletionsProcessorPtr.reset();
+ // Shutdown the range deletion processor to interrupt range deletions
+ if (_readyRangeDeletionsProcessorPtr) {
+ _readyRangeDeletionsProcessorPtr->shutdown();
+ }
// Clear range deletion tasks map in order to notify potential waiters on completion futures
_rangeDeletionTasks.clear();
}
void RangeDeleterService::onStepDown() {
- _stopService(false /* joinExecutor */);
+ _stopService();
}
void RangeDeleterService::onShutdown() {
- _stopService(true /* joinExecutor */);
+ _stopService();
+ _joinAndResetState();
}
BSONObj RangeDeleterService::dumpState() {
@@ -472,10 +533,9 @@ SharedSemiFuture<void> RangeDeleterService::registerTask(
.then([this, rdt = rdt]() {
// Step 3: schedule the actual range deletion task
auto lock = _acquireMutexUnconditionally();
- invariant(
- _readyRangeDeletionsProcessorPtr || _state == kDown,
- "The range deletions processor must be instantiated if the state != kDown");
if (_state != kDown) {
+ invariant(_readyRangeDeletionsProcessorPtr,
+ "The range deletions processor is not initialized");
_readyRangeDeletionsProcessorPtr->emplaceRangeDeletion(rdt);
}
});
diff --git a/src/mongo/db/s/range_deleter_service.h b/src/mongo/db/s/range_deleter_service.h
index 68cca97454c..c816c8b9db2 100644
--- a/src/mongo/db/s/range_deleter_service.h
+++ b/src/mongo/db/s/range_deleter_service.h
@@ -110,58 +110,37 @@ private:
*/
class ReadyRangeDeletionsProcessor {
public:
- ReadyRangeDeletionsProcessor(OperationContext* opCtx) {
- _thread = stdx::thread([this] { _runRangeDeletions(); });
- stdx::unique_lock<Latch> lock(_mutex);
- opCtx->waitForConditionOrInterrupt(
- _condVar, lock, [&] { return _threadOpCtxHolder.is_initialized(); });
- }
-
- ~ReadyRangeDeletionsProcessor() {
- {
- stdx::unique_lock<Latch> lock(_mutex);
- // The `_threadOpCtxHolder` may have been already reset/interrupted in case the
- // thread got interrupted due to stepdown
- if (_threadOpCtxHolder) {
- stdx::lock_guard<Client> scopedClientLock(*(*_threadOpCtxHolder)->getClient());
- if ((*_threadOpCtxHolder)->checkForInterruptNoAssert().isOK()) {
- (*_threadOpCtxHolder)->markKilled(ErrorCodes::Interrupted);
- }
- }
- _condVar.notify_all();
- }
+ ReadyRangeDeletionsProcessor(OperationContext* opCtx);
+ ~ReadyRangeDeletionsProcessor();
- if (_thread.joinable()) {
- _thread.join();
- }
- }
+ /*
+ * Interrupt ongoing range deletions
+ */
+ void shutdown();
/*
* Schedule a range deletion at the end of the queue
*/
- void emplaceRangeDeletion(const RangeDeletionTask& rdt) {
- stdx::unique_lock<Latch> lock(_mutex);
- _queue.push(rdt);
- _condVar.notify_all();
- }
+ void emplaceRangeDeletion(const RangeDeletionTask& rdt);
private:
/*
+ * Return true if this processor has been shut down
+ */
+ bool _stopRequested() const;
+
+ /*
* Remove a range deletion from the head of the queue. Supposed to be called only once a
* range deletion successfully finishes.
*/
- void _completedRangeDeletion() {
- stdx::unique_lock<Latch> lock(_mutex);
- dassert(!_queue.empty());
- _queue.pop();
- }
+ void _completedRangeDeletion();
/*
* Code executed by the internal thread
*/
void _runRangeDeletions();
- Mutex _mutex = MONGO_MAKE_LATCH("ReadyRangeDeletionsProcessor");
+ mutable Mutex _mutex = MONGO_MAKE_LATCH("ReadyRangeDeletionsProcessor");
/*
* Condition variable notified when:
@@ -175,10 +154,13 @@ private:
std::queue<RangeDeletionTask> _queue;
/* Pointer to the (one and only) operation context used by the thread */
- boost::optional<ServiceContext::UniqueOperationContext> _threadOpCtxHolder;
+ ServiceContext::UniqueOperationContext _threadOpCtxHolder;
/* Thread consuming the range deletions queue */
stdx::thread _thread;
+
+ enum State { kRunning, kStopped };
+ State _state{kRunning};
};
// Keeping track of per-collection registered range deletion tasks
@@ -253,6 +235,7 @@ public:
const ChunkRange& range);
/* ReplicaSetAwareServiceShardSvr implemented methods */
+ void onStartup(OperationContext* opCtx) override;
void onStepUpComplete(OperationContext* opCtx, long long term) override;
void onStepDown() override;
void onShutdown() override;
@@ -276,17 +259,21 @@ public:
std::unique_ptr<ReadyRangeDeletionsProcessor> _readyRangeDeletionsProcessorPtr;
private:
+ /* Join all threads and the executor, and reset the in-memory state of the service.
+ * Used on onStepUpComplete and on onShutdown.
+ */
+ void _joinAndResetState();
+
/* Asynchronously register range deletions on the service. To be called on on step-up */
void _recoverRangeDeletionsOnStepUp(OperationContext* opCtx);
- /* Called by shutdown/stepdown hooks to reset the service */
- void _stopService(bool joinExecutor);
+ /* Called by shutdown/stepdown hooks to interrupt the service */
+ void _stopService();
/* ReplicaSetAwareServiceShardSvr "empty implemented" methods */
- void onStartup(OperationContext* opCtx) override final{};
void onInitialDataAvailable(OperationContext* opCtx,
bool isMajorityDataAvailable) override final {}
- void onStepUpBegin(OperationContext* opCtx, long long term) override final {}
+ void onStepUpBegin(OperationContext* opCtx, long long term) override final{};
void onBecomeArbiter() override final {}
};
diff --git a/src/mongo/db/s/range_deleter_service_op_observer.cpp b/src/mongo/db/s/range_deleter_service_op_observer.cpp
index 9c1e51f9a6f..10bbbb6a924 100644
--- a/src/mongo/db/s/range_deleter_service_op_observer.cpp
+++ b/src/mongo/db/s/range_deleter_service_op_observer.cpp
@@ -35,6 +35,9 @@
#include "mongo/db/s/range_deleter_service.h"
#include "mongo/db/s/range_deletion_task_gen.h"
#include "mongo/db/update/update_oplog_entry_serialization.h"
+#include "mongo/logv2/log.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kShardingRangeDeleter
namespace mongo {
namespace {
@@ -54,10 +57,15 @@ void registerTaskWithOngoingQueriesOnOpLogEntryCommit(OperationContext* opCtx,
(void)RangeDeleterService::get(opCtx)->registerTask(
rdt, std::move(waitForActiveQueriesToComplete));
} catch (const DBException& ex) {
- dassert(ex.code() == ErrorCodes::NotYetInitialized,
- str::stream() << "No error different from `NotYetInitialized` is expected "
- "to be propagated to the range deleter observer. Got error: "
- << ex.toStatus());
+ if (ex.code() != ErrorCodes::NotYetInitialized &&
+ !ErrorCodes::isA<ErrorCategory::NotPrimaryError>(ex.code())) {
+ LOGV2_WARNING(7092800,
+ "No error different from `NotYetInitialized` or `NotPrimaryError` "
+ "category is expected to be propagated to the range deleter "
+ "observer. Range deletion task not registered.",
+ "error"_attr = redact(ex),
+ "task"_attr = rdt);
+ }
}
});
}
diff --git a/src/mongo/db/s/range_deleter_service_test.cpp b/src/mongo/db/s/range_deleter_service_test.cpp
index d2406120483..0807867da2d 100644
--- a/src/mongo/db/s/range_deleter_service_test.cpp
+++ b/src/mongo/db/s/range_deleter_service_test.cpp
@@ -45,6 +45,7 @@ void RangeDeleterServiceTest::setUp() {
ShardServerTestFixture::setUp();
WaitForMajorityService::get(getServiceContext()).startup(getServiceContext());
opCtx = operationContext();
+ RangeDeleterService::get(opCtx)->onStartup(opCtx);
RangeDeleterService::get(opCtx)->onStepUpComplete(opCtx, 0L);
RangeDeleterService::get(opCtx)->_waitForRangeDeleterServiceUp_FOR_TESTING();
@@ -275,12 +276,6 @@ TEST_F(RangeDeleterServiceTest, ScheduledTaskInvalidatedOnStepDown) {
// Manually trigger disabling of the service
rds->onStepDown();
- ON_BLOCK_EXIT([&] {
- // Re-enable the service for clean teardown
- rds->onStepUpComplete(opCtx, 0L);
- rds->_waitForRangeDeleterServiceUp_FOR_TESTING();
- });
-
try {
completionFuture.get(opCtx);
} catch (const ExceptionForCat<ErrorCategory::Interruption>&) {
@@ -293,12 +288,6 @@ TEST_F(RangeDeleterServiceTest, NoActionPossibleIfServiceIsDown) {
// Manually trigger disabling of the service
rds->onStepDown();
- ON_BLOCK_EXIT([&] {
- // Re-enable the service for clean teardown
- rds->onStepUpComplete(opCtx, 0L);
- rds->_waitForRangeDeleterServiceUp_FOR_TESTING();
- });
-
auto taskWithOngoingQueries = createRangeDeletionTaskWithOngoingQueries(
uuidCollA, BSON(kShardKey << 0), BSON(kShardKey << 10), CleanWhenEnum::kDelayed);
@@ -886,10 +875,6 @@ TEST_F(RangeDeleterServiceTest, WaitForOngoingQueriesInvalidatedOnStepDown) {
// Manually trigger disabling of the service
rds->onStepDown();
- ON_BLOCK_EXIT([&] {
- rds->onStepUpComplete(opCtx, 0L); // Re-enable the service
- });
-
try {
completionFuture.get(opCtx);
} catch (const ExceptionForCat<ErrorCategory::Interruption>&) {
diff --git a/src/mongo/db/s/range_deletion_util.cpp b/src/mongo/db/s/range_deletion_util.cpp
index 3a8b512ddba..707516db64f 100644
--- a/src/mongo/db/s/range_deletion_util.cpp
+++ b/src/mongo/db/s/range_deletion_util.cpp
@@ -322,7 +322,7 @@ Status deleteRangeInBatches(OperationContext* opCtx,
const ChunkRange& range) {
suspendRangeDeletion.pauseWhileSet(opCtx);
- SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow);
+ SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow);
bool allDocsRemoved = false;
// Delete all batches in this range unless a stepdown error occurs. Do not yield the
@@ -584,7 +584,7 @@ void persistUpdatedNumOrphans(OperationContext* opCtx,
const auto query = getQueryFilterForRangeDeletionTask(collectionUuid, range);
try {
PersistentTaskStore<RangeDeletionTask> store(NamespaceString::kRangeDeletionNamespace);
- ScopedRangeDeleterLock rangeDeleterLock(opCtx, collectionUuid);
+ ScopedRangeDeleterLock rangeDeleterLock(opCtx, LockMode::MODE_IX);
// The DBDirectClient will not retry WriteConflictExceptions internally while holding an X
// mode lock, so we need to retry at this level.
writeConflictRetry(
diff --git a/src/mongo/db/s/shard_server_op_observer.cpp b/src/mongo/db/s/shard_server_op_observer.cpp
index 744c79a1eca..ae059d19512 100644
--- a/src/mongo/db/s/shard_server_op_observer.cpp
+++ b/src/mongo/db/s/shard_server_op_observer.cpp
@@ -509,20 +509,24 @@ void ShardServerOpObserver::onModifyShardedCollectionGlobalIndexCatalogEntry(
IDLParserContext("onModifyShardedCollectionGlobalIndexCatalogEntry"),
indexDoc["entry"].Obj());
auto indexVersion = indexDoc["entry"][IndexCatalogType::kLastmodFieldName].timestamp();
- opCtx->recoveryUnit()->onCommit([opCtx, nss, indexVersion, indexEntry](auto _) {
+ auto uuid = uassertStatusOK(
+ UUID::parse(indexDoc["entry"][IndexCatalogType::kCollectionUUIDFieldName]));
+ opCtx->recoveryUnit()->onCommit([opCtx, nss, indexVersion, indexEntry, uuid](auto _) {
AutoGetCollection autoColl(opCtx, nss, MODE_IX);
CollectionShardingRuntime::assertCollectionLockedAndAcquire(
opCtx, nss, CSRAcquisitionMode::kExclusive)
- ->addIndex(opCtx, indexEntry, indexVersion);
+ ->addIndex(opCtx, indexEntry, {uuid, indexVersion});
});
} else {
auto indexName = indexDoc["entry"][IndexCatalogType::kNameFieldName].str();
auto indexVersion = indexDoc["entry"][IndexCatalogType::kLastmodFieldName].timestamp();
- opCtx->recoveryUnit()->onCommit([opCtx, nss, indexName, indexVersion](auto _) {
+ auto uuid = uassertStatusOK(
+ UUID::parse(indexDoc["entry"][IndexCatalogType::kCollectionUUIDFieldName]));
+ opCtx->recoveryUnit()->onCommit([opCtx, nss, indexName, indexVersion, uuid](auto _) {
AutoGetCollection autoColl(opCtx, nss, MODE_IX);
CollectionShardingRuntime::assertCollectionLockedAndAcquire(
opCtx, nss, CSRAcquisitionMode::kExclusive)
- ->removeIndex(opCtx, indexName, indexVersion);
+ ->removeIndex(opCtx, indexName, {uuid, indexVersion});
});
}
}
diff --git a/src/mongo/db/s/sharding_recovery_service.cpp b/src/mongo/db/s/sharding_recovery_service.cpp
index d47e4056534..02dab8f3773 100644
--- a/src/mongo/db/s/sharding_recovery_service.cpp
+++ b/src/mongo/db/s/sharding_recovery_service.cpp
@@ -532,7 +532,7 @@ void ShardingRecoveryService::recoverIndexesCatalog(OperationContext* opCtx) {
AutoGetCollection collLock(opCtx, nss, MODE_X);
CollectionShardingRuntime::assertCollectionLockedAndAcquire(
opCtx, collLock->ns(), CSRAcquisitionMode::kExclusive)
- ->addIndex(opCtx, indexEntry, indexVersion);
+ ->addIndex(opCtx, indexEntry, {indexEntry.getCollectionUUID(), indexVersion});
}
}
LOGV2_DEBUG(6686502, 2, "Recovered all index versions");
diff --git a/src/mongo/db/s/sharding_write_router_bm.cpp b/src/mongo/db/s/sharding_write_router_bm.cpp
index e92d8292052..89fa3ded04d 100644
--- a/src/mongo/db/s/sharding_write_router_bm.cpp
+++ b/src/mongo/db/s/sharding_write_router_bm.cpp
@@ -38,6 +38,7 @@
#include "mongo/db/s/collection_metadata.h"
#include "mongo/db/s/collection_sharding_runtime.h"
#include "mongo/db/s/collection_sharding_state_factory_shard.h"
+#include "mongo/db/s/collection_sharding_state_factory_standalone.h"
#include "mongo/db/s/operation_sharding_state.h"
#include "mongo/db/s/sharding_state.h"
#include "mongo/db/s/sharding_write_router.h"
@@ -214,6 +215,10 @@ void BM_UnshardedDestinedRecipient(benchmark::State& state) {
serviceContext->registerClientObserver(std::make_unique<LockerNoopClientObserver>());
const auto opCtx = client->makeOperationContext();
+ CollectionShardingStateFactory::set(
+ opCtx->getServiceContext(),
+ std::make_unique<CollectionShardingStateFactoryStandalone>(opCtx->getServiceContext()));
+
const auto catalogCache = CatalogCacheMock::make();
for (auto keepRunning : state) {
diff --git a/src/mongo/db/s/transaction_coordinator_util.cpp b/src/mongo/db/s/transaction_coordinator_util.cpp
index 736e5fb1cf4..9bd3114a84f 100644
--- a/src/mongo/db/s/transaction_coordinator_util.cpp
+++ b/src/mongo/db/s/transaction_coordinator_util.cpp
@@ -445,7 +445,7 @@ Future<repl::OpTime> persistDecision(txn::AsyncWorkScheduler& scheduler,
// Do not acquire a storage ticket in order to avoid unnecessary serialization
// with other prepared transactions that are holding a storage ticket
// themselves; see SERVER-60682.
- SetTicketAquisitionPriorityForLock setTicketAquisition(
+ SetAdmissionPriorityForLock setTicketAquisition(
opCtx, AdmissionContext::Priority::kImmediate);
getTransactionCoordinatorWorkerCurOpRepository()->set(
opCtx, lsid, txnNumberAndRetryCounter, CoordinatorAction::kWritingDecision);