Diffstat (limited to 'src/mongo/db')
-rw-r--r--  src/mongo/db/SConscript | 1
-rw-r--r--  src/mongo/db/catalog/SConscript | 1
-rw-r--r--  src/mongo/db/catalog/collection.h | 3
-rw-r--r--  src/mongo/db/catalog/collection_catalog.cpp | 110
-rw-r--r--  src/mongo/db/catalog/collection_catalog.h | 9
-rw-r--r--  src/mongo/db/catalog/collection_catalog_test.cpp | 164
-rw-r--r--  src/mongo/db/catalog/collection_impl.cpp | 1
-rw-r--r--  src/mongo/db/catalog/create_collection.cpp | 2
-rw-r--r--  src/mongo/db/catalog/create_collection.h | 2
-rw-r--r--  src/mongo/db/catalog/create_collection_test.cpp | 11
-rw-r--r--  src/mongo/db/catalog/database_impl.cpp | 16
-rw-r--r--  src/mongo/db/catalog/virtual_collection_impl.cpp | 18
-rw-r--r--  src/mongo/db/catalog/virtual_collection_impl.h | 78
-rw-r--r--  src/mongo/db/catalog/virtual_collection_options.h | 27
-rw-r--r--  src/mongo/db/change_stream_pre_images_collection_manager.cpp | 154
-rw-r--r--  src/mongo/db/clientcursor.cpp | 8
-rw-r--r--  src/mongo/db/clientcursor.h | 2
-rw-r--r--  src/mongo/db/commands/SConscript | 4
-rw-r--r--  src/mongo/db/commands/count_cmd.cpp | 10
-rw-r--r--  src/mongo/db/commands/dbcheck.cpp | 195
-rw-r--r--  src/mongo/db/commands/dbcommands.cpp | 8
-rw-r--r--  src/mongo/db/commands/dbcommands_d.cpp | 4
-rw-r--r--  src/mongo/db/commands/distinct.cpp | 11
-rw-r--r--  src/mongo/db/commands/external_data_source_commands_test.cpp | 953
-rw-r--r--  src/mongo/db/commands/external_data_source_scope_guard.cpp | 87
-rw-r--r--  src/mongo/db/commands/external_data_source_scope_guard.h | 84
-rw-r--r--  src/mongo/db/commands/fail_point_cmd.cpp | 4
-rw-r--r--  src/mongo/db/commands/find_and_modify.cpp | 13
-rw-r--r--  src/mongo/db/commands/find_cmd.cpp | 32
-rw-r--r--  src/mongo/db/commands/generic.cpp | 4
-rw-r--r--  src/mongo/db/commands/getmore_cmd.cpp | 13
-rw-r--r--  src/mongo/db/commands/parameters.cpp | 5
-rw-r--r--  src/mongo/db/commands/pipeline_command.cpp | 24
-rw-r--r--  src/mongo/db/commands/profile_common.cpp | 1
-rw-r--r--  src/mongo/db/commands/run_aggregate.cpp | 22
-rw-r--r--  src/mongo/db/commands/run_aggregate.h | 4
-rw-r--r--  src/mongo/db/commands/server_status_command.cpp | 4
-rw-r--r--  src/mongo/db/commands/tenant_migration_recipient_cmds.idl | 2
-rw-r--r--  src/mongo/db/concurrency/lock_state.h | 15
-rw-r--r--  src/mongo/db/concurrency/lock_state_test.cpp | 9
-rw-r--r--  src/mongo/db/curop.cpp | 2
-rw-r--r--  src/mongo/db/db_raii.cpp | 23
-rw-r--r--  src/mongo/db/db_raii.h | 135
-rw-r--r--  src/mongo/db/dbdirectclient.h | 1
-rw-r--r--  src/mongo/db/exec/batched_delete_stage.cpp | 3
-rw-r--r--  src/mongo/db/exec/bucket_unpacker.cpp | 344
-rw-r--r--  src/mongo/db/exec/bucket_unpacker.h | 36
-rw-r--r--  src/mongo/db/exec/collection_scan.cpp | 2
-rw-r--r--  src/mongo/db/exec/exclusion_projection_executor.cpp | 5
-rw-r--r--  src/mongo/db/exec/inclusion_projection_executor.cpp | 5
-rw-r--r--  src/mongo/db/exec/projection_node.cpp | 7
-rw-r--r--  src/mongo/db/exec/projection_node.h | 11
-rw-r--r--  src/mongo/db/exec/sbe/SConscript | 1
-rw-r--r--  src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp | 21
-rw-r--r--  src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp | 4
-rw-r--r--  src/mongo/db/exec/update_stage.cpp | 1
-rw-r--r--  src/mongo/db/index_builds_coordinator.cpp | 4
-rw-r--r--  src/mongo/db/mirror_maestro.cpp | 29
-rw-r--r--  src/mongo/db/mongod_main.cpp | 12
-rw-r--r--  src/mongo/db/namespace_string.cpp | 10
-rw-r--r--  src/mongo/db/namespace_string.h | 6
-rw-r--r--  src/mongo/db/operation_context.cpp | 2
-rw-r--r--  src/mongo/db/ops/SConscript | 1
-rw-r--r--  src/mongo/db/ops/update_request.h | 13
-rw-r--r--  src/mongo/db/ops/write_ops.idl | 15
-rw-r--r--  src/mongo/db/ops/write_ops_exec.cpp | 30
-rw-r--r--  src/mongo/db/pipeline/SConscript | 16
-rw-r--r--  src/mongo/db/pipeline/abt/abt_translate_bm_fixture.cpp | 211
-rw-r--r--  src/mongo/db/pipeline/abt/abt_translate_bm_fixture.h | 137
-rw-r--r--  src/mongo/db/pipeline/abt/abt_translate_cq_bm.cpp | 88
-rw-r--r--  src/mongo/db/pipeline/abt/abt_translate_pipeline_bm.cpp | 90
-rw-r--r--  src/mongo/db/pipeline/abt/expr_algebrizer_context.h | 34
-rw-r--r--  src/mongo/db/pipeline/abt/match_expression_visitor.cpp | 80
-rw-r--r--  src/mongo/db/pipeline/abt/utils.cpp | 35
-rw-r--r--  src/mongo/db/pipeline/abt/utils.h | 10
-rw-r--r--  src/mongo/db/pipeline/aggregate_command.idl | 7
-rw-r--r--  src/mongo/db/pipeline/change_stream_expired_pre_image_remover.cpp | 1
-rw-r--r--  src/mongo/db/pipeline/document_source_check_resume_token_test.cpp | 4
-rw-r--r--  src/mongo/db/pipeline/document_source_internal_unpack_bucket.cpp | 169
-rw-r--r--  src/mongo/db/pipeline/document_source_internal_unpack_bucket.h | 27
-rw-r--r--  src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/create_predicates_on_bucket_level_field_test.cpp | 132
-rw-r--r--  src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/optimize_pipeline_test.cpp | 138
-rw-r--r--  src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/split_match_on_meta_and_rename_test.cpp | 34
-rw-r--r--  src/mongo/db/pipeline/expression.cpp | 2
-rw-r--r--  src/mongo/db/pipeline/field_path.cpp | 14
-rw-r--r--  src/mongo/db/pipeline/field_path.h | 23
-rw-r--r--  src/mongo/db/pipeline/pipeline_d.cpp | 39
-rw-r--r--  src/mongo/db/pipeline/process_interface/common_mongod_process_interface.cpp | 5
-rw-r--r--  src/mongo/db/prepare_conflict_tracker.cpp | 2
-rw-r--r--  src/mongo/db/query/canonical_query.h | 18
-rw-r--r--  src/mongo/db/query/ce/ce_heuristic_test.cpp | 1
-rw-r--r--  src/mongo/db/query/ce/ce_test_utils.cpp | 4
-rw-r--r--  src/mongo/db/query/count_command.idl | 5
-rw-r--r--  src/mongo/db/query/cqf_command_utils.cpp | 2
-rw-r--r--  src/mongo/db/query/cqf_get_executor.cpp | 6
-rw-r--r--  src/mongo/db/query/distinct_command.idl | 6
-rw-r--r--  src/mongo/db/query/find_command.idl | 5
-rw-r--r--  src/mongo/db/query/fle/encrypted_predicate.h | 67
-rw-r--r--  src/mongo/db/query/fle/query_rewriter.cpp | 29
-rw-r--r--  src/mongo/db/query/fle/query_rewriter_test.cpp | 192
-rw-r--r--  src/mongo/db/query/fle/range_predicate.cpp | 128
-rw-r--r--  src/mongo/db/query/fle/range_predicate.h | 6
-rw-r--r--  src/mongo/db/query/fle/range_predicate_test.cpp | 154
-rw-r--r--  src/mongo/db/query/get_executor.cpp | 280
-rw-r--r--  src/mongo/db/query/get_executor.h | 16
-rw-r--r--  src/mongo/db/query/optimizer/SConscript | 13
-rw-r--r--  src/mongo/db/query/optimizer/cascades/cost_derivation.cpp | 446
-rw-r--r--  src/mongo/db/query/optimizer/cascades/implementers.cpp | 34
-rw-r--r--  src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp | 5
-rw-r--r--  src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp | 42
-rw-r--r--  src/mongo/db/query/optimizer/cascades/memo_defs.h | 4
-rw-r--r--  src/mongo/db/query/optimizer/explain.cpp | 28
-rw-r--r--  src/mongo/db/query/optimizer/explain.h | 5
-rw-r--r--  src/mongo/db/query/optimizer/interval_intersection_test.cpp | 12
-rw-r--r--  src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp | 120
-rw-r--r--  src/mongo/db/query/optimizer/node.cpp | 22
-rw-r--r--  src/mongo/db/query/optimizer/node.h | 13
-rw-r--r--  src/mongo/db/query/optimizer/opt_phase_manager.cpp | 17
-rw-r--r--  src/mongo/db/query/optimizer/opt_phase_manager.h | 5
-rw-r--r--  src/mongo/db/query/optimizer/optimizer_failure_test.cpp | 48
-rw-r--r--  src/mongo/db/query/optimizer/optimizer_test.cpp | 44
-rw-r--r--  src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp | 245
-rw-r--r--  src/mongo/db/query/optimizer/props.cpp | 15
-rw-r--r--  src/mongo/db/query/optimizer/props.h | 7
-rw-r--r--  src/mongo/db/query/optimizer/utils/abt_hash.cpp | 2
-rw-r--r--  src/mongo/db/query/optimizer/utils/unit_test_pipeline_utils.cpp | 12
-rw-r--r--  src/mongo/db/query/optimizer/utils/unit_test_utils.cpp | 206
-rw-r--r--  src/mongo/db/query/optimizer/utils/unit_test_utils.h | 64
-rw-r--r--  src/mongo/db/query/optimizer/utils/utils.cpp | 11
-rw-r--r--  src/mongo/db/query/optimizer/utils/utils.h | 1
-rw-r--r--  src/mongo/db/query/parsed_distinct.cpp | 3
-rw-r--r--  src/mongo/db/query/parsed_distinct.h | 15
-rw-r--r--  src/mongo/db/query/plan_cache_key_factory.cpp | 2
-rw-r--r--  src/mongo/db/query/query_knobs.idl | 9
-rw-r--r--  src/mongo/db/query/query_planner.cpp | 39
-rw-r--r--  src/mongo/db/query/query_solution.h | 18
-rw-r--r--  src/mongo/db/query/sbe_shard_filter_test.cpp | 2
-rw-r--r--  src/mongo/db/query/sbe_stage_builder.cpp | 1334
-rw-r--r--  src/mongo/db/query/sbe_stage_builder.h | 276
-rw-r--r--  src/mongo/db/query/sbe_stage_builder_coll_scan.cpp | 75
-rw-r--r--  src/mongo/db/query/sbe_stage_builder_coll_scan.h | 6
-rw-r--r--  src/mongo/db/query/sbe_stage_builder_expression.cpp | 121
-rw-r--r--  src/mongo/db/query/sbe_stage_builder_expression.h | 5
-rw-r--r--  src/mongo/db/query/sbe_stage_builder_helpers.cpp | 206
-rw-r--r--  src/mongo/db/query/sbe_stage_builder_helpers.h | 87
-rw-r--r--  src/mongo/db/query/sbe_stage_builder_index_scan.cpp | 39
-rw-r--r--  src/mongo/db/query/sbe_stage_builder_index_scan.h | 50
-rw-r--r--  src/mongo/db/query/sbe_stage_builder_lookup.cpp | 26
-rw-r--r--  src/mongo/db/repl/SConscript | 2
-rw-r--r--  src/mongo/db/repl/dbcheck.cpp | 31
-rw-r--r--  src/mongo/db/repl/dbcheck.idl | 36
-rw-r--r--  src/mongo/db/repl/noop_writer.cpp | 6
-rw-r--r--  src/mongo/db/repl/oplog_applier_impl.cpp | 21
-rw-r--r--  src/mongo/db/repl/repl_set_request_votes.cpp | 6
-rw-r--r--  src/mongo/db/repl/replication_consistency_markers_impl.cpp | 11
-rw-r--r--  src/mongo/db/repl/replication_coordinator_external_state_impl.cpp | 2
-rw-r--r--  src/mongo/db/repl/replication_coordinator_impl_elect_v1.cpp | 3
-rw-r--r--  src/mongo/db/repl/replication_coordinator_test_fixture.cpp | 2
-rw-r--r--  src/mongo/db/repl/replication_info.cpp | 4
-rw-r--r--  src/mongo/db/repl/storage_interface_impl.cpp | 8
-rw-r--r--  src/mongo/db/repl/tenant_migration_donor_service.cpp | 166
-rw-r--r--  src/mongo/db/repl/tenant_migration_donor_service.h | 9
-rw-r--r--  src/mongo/db/repl/tenant_migration_recipient_service.cpp | 117
-rw-r--r--  src/mongo/db/repl/tenant_migration_recipient_service.h | 17
-rw-r--r--  src/mongo/db/repl/tenant_migration_recipient_service_shard_merge_test.cpp | 593
-rw-r--r--  src/mongo/db/s/SConscript | 37
-rw-r--r--  src/mongo/db/s/balancer/balancer.cpp | 11
-rw-r--r--  src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp | 19
-rw-r--r--  src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h | 7
-rw-r--r--  src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp | 6
-rw-r--r--  src/mongo/db/s/balancer/balancer_policy.cpp | 31
-rw-r--r--  src/mongo/db/s/balancer/balancer_policy.h | 5
-rw-r--r--  src/mongo/db/s/balancer_stats_registry.cpp | 28
-rw-r--r--  src/mongo/db/s/balancer_stats_registry.h | 15
-rw-r--r--  src/mongo/db/s/cluster_count_cmd_d.cpp | 5
-rw-r--r--  src/mongo/db/s/cluster_find_cmd_d.cpp | 5
-rw-r--r--  src/mongo/db/s/cluster_pipeline_cmd_d.cpp | 5
-rw-r--r--  src/mongo/db/s/cluster_write_cmd_d.cpp | 14
-rw-r--r--  src/mongo/db/s/collection_sharding_runtime.cpp | 16
-rw-r--r--  src/mongo/db/s/collection_sharding_runtime.h | 8
-rw-r--r--  src/mongo/db/s/collection_sharding_runtime_test.cpp | 9
-rw-r--r--  src/mongo/db/s/config/sharding_catalog_manager.h | 3
-rw-r--r--  src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp | 28
-rw-r--r--  src/mongo/db/s/query_analysis_op_observer.cpp | 12
-rw-r--r--  src/mongo/db/s/query_analysis_writer.cpp | 695
-rw-r--r--  src/mongo/db/s/query_analysis_writer.h | 214
-rw-r--r--  src/mongo/db/s/query_analysis_writer_test.cpp | 1281
-rw-r--r--  src/mongo/db/s/range_deleter_service.cpp | 154
-rw-r--r--  src/mongo/db/s/range_deleter_service.h | 67
-rw-r--r--  src/mongo/db/s/range_deleter_service_op_observer.cpp | 16
-rw-r--r--  src/mongo/db/s/range_deleter_service_test.cpp | 17
-rw-r--r--  src/mongo/db/s/range_deletion_util.cpp | 4
-rw-r--r--  src/mongo/db/s/shard_server_op_observer.cpp | 12
-rw-r--r--  src/mongo/db/s/sharding_recovery_service.cpp | 2
-rw-r--r--  src/mongo/db/s/sharding_write_router_bm.cpp | 5
-rw-r--r--  src/mongo/db/s/transaction_coordinator_util.cpp | 2
-rw-r--r--  src/mongo/db/server_feature_flags.idl | 5
-rw-r--r--  src/mongo/db/service_context.cpp | 2
-rw-r--r--  src/mongo/db/service_entry_point_common.cpp | 39
-rw-r--r--  src/mongo/db/stats/SConscript | 1
-rw-r--r--  src/mongo/db/storage/external_record_store.cpp | 5
-rw-r--r--  src/mongo/db/storage/external_record_store.h | 8
-rw-r--r--  src/mongo/db/storage/external_record_store_test.cpp | 122
-rw-r--r--  src/mongo/db/storage/input_object.h | 2
-rw-r--r--  src/mongo/db/storage/input_stream.h | 8
-rw-r--r--  src/mongo/db/storage/io_error_message.h | 43
-rw-r--r--  src/mongo/db/storage/multi_bson_stream_cursor.cpp | 27
-rw-r--r--  src/mongo/db/storage/multi_bson_stream_cursor.h | 7
-rw-r--r--  src/mongo/db/storage/named_pipe.h | 22
-rw-r--r--  src/mongo/db/storage/named_pipe_posix.cpp | 23
-rw-r--r--  src/mongo/db/storage/named_pipe_windows.cpp | 32
-rw-r--r--  src/mongo/db/storage/oplog_cap_maintainer_thread.cpp | 3
-rw-r--r--  src/mongo/db/storage/recovery_unit.h | 32
-rw-r--r--  src/mongo/db/storage/storage_engine_impl.cpp | 10
-rw-r--r--  src/mongo/db/storage/storage_stats.h (renamed from src/mongo/db/query/optimizer/cascades/cost_derivation.h) | 31
-rw-r--r--  src/mongo/db/storage/wiredtiger/SConscript | 2
-rw-r--r--  src/mongo/db/storage/wiredtiger/wiredtiger_column_store.cpp | 17
-rw-r--r--  src/mongo/db/storage/wiredtiger/wiredtiger_index.cpp | 28
-rw-r--r--  src/mongo/db/storage/wiredtiger/wiredtiger_index_util.cpp | 36
-rw-r--r--  src/mongo/db/storage/wiredtiger/wiredtiger_index_util.h | 5
-rw-r--r--  src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.cpp | 139
-rw-r--r--  src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h | 55
-rw-r--r--  src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats_test.cpp | 215
-rw-r--r--  src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.cpp | 115
-rw-r--r--  src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.h | 35
-rw-r--r--  src/mongo/db/timeseries/bucket_catalog.cpp | 6
-rw-r--r--  src/mongo/db/timeseries/bucket_catalog_test.cpp | 67
-rw-r--r--  src/mongo/db/transaction/transaction_api.cpp | 21
-rw-r--r--  src/mongo/db/transaction/transaction_metrics_observer.cpp | 2
-rw-r--r--  src/mongo/db/ttl.cpp | 2
230 files changed, 10213 insertions, 3510 deletions
diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript
index 06f719c69ac..0f59f24014c 100644
--- a/src/mongo/db/SConscript
+++ b/src/mongo/db/SConscript
@@ -2337,6 +2337,7 @@ env.Library(
'$BUILD_DIR/mongo/db/change_streams_cluster_parameter',
'$BUILD_DIR/mongo/db/pipeline/change_stream_expired_pre_image_remover',
'$BUILD_DIR/mongo/db/query/ce/query_ce_histogram',
+ '$BUILD_DIR/mongo/db/s/query_analysis_writer',
'$BUILD_DIR/mongo/db/set_change_stream_state_coordinator',
'$BUILD_DIR/mongo/idl/cluster_server_parameter',
'$BUILD_DIR/mongo/idl/cluster_server_parameter_op_observer',
diff --git a/src/mongo/db/catalog/SConscript b/src/mongo/db/catalog/SConscript
index 7d495903c0f..302d3782c72 100644
--- a/src/mongo/db/catalog/SConscript
+++ b/src/mongo/db/catalog/SConscript
@@ -508,6 +508,7 @@ env.Library(
env.Library(
target='catalog_helpers',
source=[
+ "$BUILD_DIR/mongo/db/commands/external_data_source_scope_guard.cpp",
'capped_utils.cpp',
'coll_mod_index.cpp',
'coll_mod.cpp',
diff --git a/src/mongo/db/catalog/collection.h b/src/mongo/db/catalog/collection.h
index 3fac44cff7a..76905a8fe6f 100644
--- a/src/mongo/db/catalog/collection.h
+++ b/src/mongo/db/catalog/collection.h
@@ -67,6 +67,9 @@ struct CollectionUpdateArgs {
std::vector<StmtId> stmtIds = {kUninitializedStmtId};
+ // The unique sample id for this update if it has been chosen for sampling.
+ boost::optional<UUID> sampleId;
+
// The document before modifiers were applied.
boost::optional<BSONObj> preImageDoc;
diff --git a/src/mongo/db/catalog/collection_catalog.cpp b/src/mongo/db/catalog/collection_catalog.cpp
index 1d644b93e69..735a4b897b0 100644
--- a/src/mongo/db/catalog/collection_catalog.cpp
+++ b/src/mongo/db/catalog/collection_catalog.cpp
@@ -52,6 +52,13 @@
namespace mongo {
namespace {
+// Sentinel id for marking a catalogId mapping range as unknown. Must use an invalid RecordId.
+static RecordId kUnknownRangeMarkerId = RecordId::minLong();
+// Maximum number of entries in catalogId mapping when inserting catalogId missing at timestamp.
+// Used to avoid quadratic behavior when inserting entries at the beginning. When threshold is
+// reached we will fall back to more durable catalog scans.
+static constexpr int kMaxCatalogIdMappingLengthForMissingInsert = 1000;
+
struct LatestCollectionCatalog {
std::shared_ptr<CollectionCatalog> catalog = std::make_shared<CollectionCatalog>();
};
@@ -1288,7 +1295,11 @@ CollectionCatalog::CatalogIdLookup CollectionCatalog::lookupCatalogIdByNSS(
// iterator to get the last entry where the time is less or equal.
auto catalogId = (--rangeIt)->id;
if (catalogId) {
- return {*catalogId, CatalogIdLookup::NamespaceExistence::kExists};
+ if (*catalogId != kUnknownRangeMarkerId) {
+ return {*catalogId, CatalogIdLookup::NamespaceExistence::kExists};
+ } else {
+ return {RecordId{}, CatalogIdLookup::NamespaceExistence::kUnknown};
+ }
}
return {RecordId{}, CatalogIdLookup::NamespaceExistence::kNotExists};
}
@@ -1615,12 +1626,15 @@ std::shared_ptr<Collection> CollectionCatalog::deregisterCollection(
// TODO SERVER-68674: Remove feature flag check.
if (feature_flags::gPointInTimeCatalogLookups.isEnabledAndIgnoreFCV() && isDropPending) {
- auto ident = coll->getSharedIdent()->getIdent();
- LOGV2_DEBUG(6825300, 1, "Registering drop pending collection ident", "ident"_attr = ident);
+ if (auto sharedIdent = coll->getSharedIdent(); sharedIdent) {
+ auto ident = sharedIdent->getIdent();
+ LOGV2_DEBUG(
+ 6825300, 1, "Registering drop pending collection ident", "ident"_attr = ident);
- auto it = _dropPendingCollection.find(ident);
- invariant(it == _dropPendingCollection.end());
- _dropPendingCollection[ident] = coll;
+ auto it = _dropPendingCollection.find(ident);
+ invariant(it == _dropPendingCollection.end());
+ _dropPendingCollection[ident] = coll;
+ }
}
_orderedCollections.erase(dbIdPair);
@@ -1735,7 +1749,8 @@ void CollectionCatalog::_pushCatalogIdForNSS(const NamespaceString& nss,
return;
}
- // Re-write latest entry if timestamp match (multiple changes occured in this transaction)
+ // An entry could exist already if concurrent writes are performed, keep the latest change in
+ // that case.
if (!ids.empty() && ids.back().ts == *ts) {
ids.back().id = catalogId;
return;
@@ -1771,8 +1786,8 @@ void CollectionCatalog::_pushCatalogIdForRename(const NamespaceString& from,
auto& fromIds = _catalogIds.at(from);
invariant(!fromIds.empty());
- // Re-write latest entry if timestamp match (multiple changes occured in this transaction),
- // otherwise push at end
+ // An entry could exist already if concurrent writes are performed, keep the latest change in
+ // that case.
if (!toIds.empty() && toIds.back().ts == *ts) {
toIds.back().id = fromIds.back().id;
} else {
@@ -1792,6 +1807,83 @@ void CollectionCatalog::_pushCatalogIdForRename(const NamespaceString& from,
}
}
+void CollectionCatalog::_insertCatalogIdForNSSAfterScan(const NamespaceString& nss,
+ boost::optional<RecordId> catalogId,
+ Timestamp ts) {
+ // TODO SERVER-68674: Remove feature flag check.
+ if (!feature_flags::gPointInTimeCatalogLookups.isEnabledAndIgnoreFCV()) {
+ // No-op.
+ return;
+ }
+
+ auto& ids = _catalogIds[nss];
+
+ // Binary search to the entry with the same or larger timestamp
+ auto it =
+ std::lower_bound(ids.begin(), ids.end(), ts, [](const auto& entry, const Timestamp& ts) {
+ return entry.ts < ts;
+ });
+
+ // The logic of what we need to do differs whether we are inserting a valid catalogId or not.
+ if (catalogId) {
+ if (it != ids.end()) {
+ // An entry could exist already if concurrent writes are performed, keep the latest
+ // change in that case.
+ if (it->ts == ts) {
+ it->id = catalogId;
+ return;
+ }
+
+ // If next element has same catalogId, we can adjust its timestamp to cover a longer
+ // range
+ if (it->id == catalogId) {
+ it->ts = ts;
+ _markNamespaceForCatalogIdCleanupIfNeeded(nss, ids);
+ return;
+ }
+ }
+
+ // Otherwise insert new entry at timestamp
+ ids.insert(it, TimestampedCatalogId{catalogId, ts});
+ _markNamespaceForCatalogIdCleanupIfNeeded(nss, ids);
+ return;
+ }
+
+ // Avoid inserting missing mapping when the list has grown past the threshold. Will cause the
+ // system to fall back to scanning the durable catalog.
+ if (ids.size() >= kMaxCatalogIdMappingLengthForMissingInsert) {
+ return;
+ }
+
+ if (it != ids.end() && it->ts == ts) {
+ // An entry could exist already if concurrent writes are performed, keep the latest change
+ // in that case.
+ it->id = boost::none;
+ } else {
+ // Otherwise insert new entry
+ it = ids.insert(it, TimestampedCatalogId{boost::none, ts});
+ }
+
+ // The iterator is positioned on the added/modified element above, reposition it to the next
+ // entry
+ ++it;
+
+ // We don't want to assume that the namespace remains not existing until the next entry, as
+ // there can be times where the namespace actually does exist. To make sure we trigger the
+ // scanning of the durable catalog in this range we will insert a bogus entry using an invalid
+ // RecordId at the next timestamp. This will treat the range forward as unknown.
+ auto nextTs = ts + 1;
+
+ // If the next entry is on the next timestamp already, we can skip adding the bogus entry. If
+ // this function is called for a previously unknown namespace, we may not have any future valid
+ // entries and the iterator would be positioned at end() at this point.
+ if (it == ids.end() || it->ts != nextTs) {
+ ids.insert(it, TimestampedCatalogId{kUnknownRangeMarkerId, nextTs});
+ }
+
+ _markNamespaceForCatalogIdCleanupIfNeeded(nss, ids);
+}
+
void CollectionCatalog::_markNamespaceForCatalogIdCleanupIfNeeded(
const NamespaceString& nss, const std::vector<TimestampedCatalogId>& ids) {
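A minimal standalone sketch of the lookup behavior introduced in this file, using simplified stand-in types (not the real CollectionCatalog, RecordId, or Timestamp classes): the last mapping entry at or before the requested timestamp decides the answer, a boost::none id means "did not exist", and the kUnknownRangeMarkerId sentinel forces callers back to a durable catalog scan.

// Stand-alone illustration only; simplified types, not MongoDB's.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Timestamp = std::uint64_t;
using RecordId = std::int64_t;

constexpr RecordId kUnknownRangeMarkerId = -1;  // stand-in for RecordId::minLong()

struct TimestampedCatalogId {
    std::optional<RecordId> id;  // nullopt means the namespace did not exist
    Timestamp ts;
};

enum class Existence { kExists, kNotExists, kUnknown };

// Mirrors lookupCatalogIdByNSS: find the last entry with ts <= the query timestamp.
Existence lookup(const std::vector<TimestampedCatalogId>& ids, Timestamp ts, RecordId* out) {
    auto it = std::upper_bound(
        ids.begin(), ids.end(), ts, [](Timestamp t, const TimestampedCatalogId& e) {
            return t < e.ts;
        });
    if (it == ids.begin())
        return Existence::kUnknown;  // nothing recorded at or before 'ts'
    const auto& entry = *--it;
    if (!entry.id)
        return Existence::kNotExists;
    if (*entry.id == kUnknownRangeMarkerId)
        return Existence::kUnknown;  // sentinel: fall back to scanning the durable catalog
    *out = *entry.id;
    return Existence::kExists;
}

int main() {
    std::vector<TimestampedCatalogId> ids{
        {RecordId{7}, 10}, {std::nullopt, 20}, {kUnknownRangeMarkerId, 21}, {RecordId{9}, 30}};
    RecordId rid = 0;
    std::cout << (lookup(ids, 15, &rid) == Existence::kExists) << ' ' << rid << '\n';  // 1 7
    std::cout << (lookup(ids, 25, &rid) == Existence::kUnknown) << '\n';               // 1
    std::cout << (lookup(ids, 35, &rid) == Existence::kExists) << ' ' << rid << '\n';  // 1 9
}

This is also why _insertCatalogIdForNSSAfterScan plants the sentinel one timestamp after a missing entry: everything past that point stays "unknown" until a real entry is recorded.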
diff --git a/src/mongo/db/catalog/collection_catalog.h b/src/mongo/db/catalog/collection_catalog.h
index d3d3446f1be..0b2561f0996 100644
--- a/src/mongo/db/catalog/collection_catalog.h
+++ b/src/mongo/db/catalog/collection_catalog.h
@@ -736,6 +736,15 @@ private:
const NamespaceString& to,
boost::optional<Timestamp> ts);
+ // TODO SERVER-70150: Make private again
+public:
+ // Inserts a catalogId for namespace at given Timestamp. Used after scanning the durable catalog
+ // for a correct mapping at the given timestamp.
+ void _insertCatalogIdForNSSAfterScan(const NamespaceString& nss,
+ boost::optional<RecordId> catalogId,
+ Timestamp ts);
+
+private:
// Helper to calculate if a namespace needs to be marked for cleanup for a set of timestamped
// catalogIds
void _markNamespaceForCatalogIdCleanupIfNeeded(const NamespaceString& nss,
diff --git a/src/mongo/db/catalog/collection_catalog_test.cpp b/src/mongo/db/catalog/collection_catalog_test.cpp
index 6a554f8fe59..8710d807f73 100644
--- a/src/mongo/db/catalog/collection_catalog_test.cpp
+++ b/src/mongo/db/catalog/collection_catalog_test.cpp
@@ -2114,6 +2114,170 @@ TEST_F(CollectionCatalogTimestampTest, CatalogIdMappingRollback) {
CollectionCatalog::CatalogIdLookup::NamespaceExistence::kNotExists);
}
+TEST_F(CollectionCatalogTimestampTest, CatalogIdMappingInsert) {
+ RAIIServerParameterControllerForTest featureFlagController(
+ "featureFlagPointInTimeCatalogLookups", true);
+
+ NamespaceString nss("a.b");
+
+ // Create a collection on the namespace
+ createCollection(opCtx.get(), nss, Timestamp(1, 10));
+ dropCollection(opCtx.get(), nss, Timestamp(1, 20));
+ createCollection(opCtx.get(), nss, Timestamp(1, 30));
+
+ auto rid1 = catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 10)).id;
+ auto rid2 = catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).id;
+
+ // Simulate startup where we have a range [oldest, stable] by creating and dropping collections
+ // and then advancing the oldest timestamp and then reading behind it.
+ CollectionCatalog::write(opCtx.get(), [](CollectionCatalog& catalog) {
+ catalog.cleanupForOldestTimestampAdvanced(Timestamp(1, 40));
+ });
+
+ // Confirm that the mappings have been cleaned up
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 15)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown);
+
+ // TODO SERVER-70150: Use openCollection
+ CollectionCatalog::write(opCtx.get(), [&](CollectionCatalog& catalog) {
+ catalog._insertCatalogIdForNSSAfterScan(nss, rid1, Timestamp(1, 17));
+ });
+
+ // Lookups before the inserted timestamp is still unknown
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 11)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown);
+
+ // Lookups at or after the inserted timestamp is found, even if they don't match with WT
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).id, rid1);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).id, rid1);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).id, rid1);
+ // The entry at Timestamp(1, 30) is unaffected
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).id, rid2);
+
+ // TODO SERVER-70150: Use openCollection
+ CollectionCatalog::write(opCtx.get(), [&](CollectionCatalog& catalog) {
+ catalog._insertCatalogIdForNSSAfterScan(nss, rid1, Timestamp(1, 12));
+ });
+
+ // We should now have extended the range from Timestamp(1, 17) to Timestamp(1, 12)
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 12)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 12)).id, rid1);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 16)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 16)).id, rid1);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).id, rid1);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).id, rid1);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).id, rid1);
+ // The entry at Timestamp(1, 30) is unaffected
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).id, rid2);
+
+ // TODO SERVER-70150: Use openCollection
+ CollectionCatalog::write(opCtx.get(), [&](CollectionCatalog& catalog) {
+ catalog._insertCatalogIdForNSSAfterScan(nss, boost::none, Timestamp(1, 25));
+ });
+
+ // Check the entries, most didn't change
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).id, rid1);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).id, rid1);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 22)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 22)).id, rid1);
+ // At Timestamp(1, 25) we now return kNotExists
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kNotExists);
+ // But next timestamp returns unknown
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 26)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown);
+ // The entry at Timestamp(1, 30) is unaffected
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).id, rid2);
+
+ // TODO SERVER-70150: Use openCollection
+ CollectionCatalog::write(opCtx.get(), [&](CollectionCatalog& catalog) {
+ catalog._insertCatalogIdForNSSAfterScan(nss, boost::none, Timestamp(1, 26));
+ });
+
+ // We should not have re-written the existing entry at Timestamp(1, 26)
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).id, rid1);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).id, rid1);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 22)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 22)).id, rid1);
+ // At Timestamp(1, 25) we now return kNotExists
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kNotExists);
+ // But next timestamp returns unknown
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 26)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kNotExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 27)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown);
+ // The entry at Timestamp(1, 30) is unaffected
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists);
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).id, rid2);
+
+ // Clean up, check so we are back to the original state
+ CollectionCatalog::write(opCtx.get(), [](CollectionCatalog& catalog) {
+ catalog.cleanupForOldestTimestampAdvanced(Timestamp(1, 41));
+ });
+
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 15)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown);
+}
+
+TEST_F(CollectionCatalogTimestampTest, CatalogIdMappingInsertUnknown) {
+ RAIIServerParameterControllerForTest featureFlagController(
+ "featureFlagPointInTimeCatalogLookups", true);
+
+ NamespaceString nss("a.b");
+
+ // Simulate startup where we have a range [oldest, stable] by advancing the oldest timestamp and
+ // then reading behind it.
+ CollectionCatalog::write(opCtx.get(), [](CollectionCatalog& catalog) {
+ catalog.cleanupForOldestTimestampAdvanced(Timestamp(1, 40));
+ });
+
+ // Reading before the oldest is unknown
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 15)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown);
+
+ // Try to instantiate a non existing collection at this timestamp.
+ // TODO SERVER-70150: Use openCollection
+ CollectionCatalog::write(opCtx.get(), [&](CollectionCatalog& catalog) {
+ catalog._insertCatalogIdForNSSAfterScan(nss, boost::none, Timestamp(1, 15));
+ });
+
+ // Lookup should now be not existing
+ ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 15)).result,
+ CollectionCatalog::CatalogIdLookup::NamespaceExistence::kNotExists);
+}
+
TEST_F(CollectionCatalogTimestampTest, CollectionLifetimeTiedToStorageTransactionLifetime) {
RAIIServerParameterControllerForTest featureFlagController(
"featureFlagPointInTimeCatalogLookups", true);
diff --git a/src/mongo/db/catalog/collection_impl.cpp b/src/mongo/db/catalog/collection_impl.cpp
index b4f3078a3c9..c028252f3b1 100644
--- a/src/mongo/db/catalog/collection_impl.cpp
+++ b/src/mongo/db/catalog/collection_impl.cpp
@@ -462,6 +462,7 @@ Status CollectionImpl::initFromExisting(OperationContext* opCtx,
// Update the idents for the newly initialized indexes.
for (const auto& sharedIdent : sharedIdents) {
auto desc = getIndexCatalog()->findIndexByName(opCtx, sharedIdent.first);
+ invariant(desc);
auto entry = getIndexCatalog()->getEntryShared(desc);
entry->setIdent(sharedIdent.second);
}
diff --git a/src/mongo/db/catalog/create_collection.cpp b/src/mongo/db/catalog/create_collection.cpp
index 79fecf84cd7..69d71d25468 100644
--- a/src/mongo/db/catalog/create_collection.cpp
+++ b/src/mongo/db/catalog/create_collection.cpp
@@ -409,7 +409,7 @@ Status _createTimeseries(OperationContext* opCtx,
AutoGetCollection::Options{}.viewMode(auto_get_collection::ViewMode::kViewsPermitted));
Lock::CollectionLock systemDotViewsLock(
opCtx,
- NamespaceString(ns.db(), NamespaceString::kSystemDotViewsCollectionName),
+ NamespaceString(ns.dbName(), NamespaceString::kSystemDotViewsCollectionName),
MODE_X);
auto db = autoColl.ensureDbExists(opCtx);
diff --git a/src/mongo/db/catalog/create_collection.h b/src/mongo/db/catalog/create_collection.h
index 08754d89daa..655644c4e36 100644
--- a/src/mongo/db/catalog/create_collection.h
+++ b/src/mongo/db/catalog/create_collection.h
@@ -27,6 +27,8 @@
* it in the license file.
*/
+#pragma once
+
#include <boost/optional.hpp>
#include <string>
diff --git a/src/mongo/db/catalog/create_collection_test.cpp b/src/mongo/db/catalog/create_collection_test.cpp
index 73c724585f3..97d1384f40f 100644
--- a/src/mongo/db/catalog/create_collection_test.cpp
+++ b/src/mongo/db/catalog/create_collection_test.cpp
@@ -324,10 +324,8 @@ TEST_F(CreateCollectionTest, ValidationDisabledForTemporaryReshardingCollection)
ASSERT_OK(status);
}
-static const std::string kValidUrl1 =
- ExternalDataSourceMetadata::kDefaultFileUrlPrefix + "named_pipe1";
-static const std::string kValidUrl2 =
- ExternalDataSourceMetadata::kDefaultFileUrlPrefix + "named_pipe2";
+const auto kValidUrl1 = ExternalDataSourceMetadata::kUrlProtocolFile + "named_pipe1"s;
+const auto kValidUrl2 = ExternalDataSourceMetadata::kUrlProtocolFile + "named_pipe2"s;
TEST_F(CreateVirtualCollectionTest, VirtualCollectionOptionsWithOneSource) {
NamespaceString vcollNss("myDb", "vcoll.name");
@@ -344,8 +342,7 @@ TEST_F(CreateVirtualCollectionTest, VirtualCollectionOptionsWithOneSource) {
auto vcollOpts = getVirtualCollectionOptions(opCtx.get(), vcollNss);
ASSERT_EQ(vcollOpts.dataSources.size(), 1);
- ASSERT_EQ(ExternalDataSourceMetadata::kUrlProtocolFile + vcollOpts.dataSources[0].url,
- kValidUrl1);
+ ASSERT_EQ(vcollOpts.dataSources[0].url, kValidUrl1);
ASSERT_EQ(stdx::to_underlying(vcollOpts.dataSources[0].storageType),
stdx::to_underlying(StorageTypeEnum::pipe));
ASSERT_EQ(stdx::to_underlying(vcollOpts.dataSources[0].fileType),
@@ -389,7 +386,7 @@ TEST_F(CreateVirtualCollectionTest, InvalidVirtualCollectionOptions) {
{
bool exceptionOccurred = false;
VirtualCollectionOptions reqVcollOpts;
- constexpr auto kInvalidUrl = "file:///abc/named_pipe";
+ constexpr auto kInvalidUrl = "fff://abc/named_pipe"_sd;
try {
reqVcollOpts.dataSources.emplace_back(
kInvalidUrl, StorageTypeEnum::pipe, FileTypeEnum::bson);
diff --git a/src/mongo/db/catalog/database_impl.cpp b/src/mongo/db/catalog/database_impl.cpp
index 618ecedc65a..0a3c9754e16 100644
--- a/src/mongo/db/catalog/database_impl.cpp
+++ b/src/mongo/db/catalog/database_impl.cpp
@@ -573,20 +573,22 @@ Status DatabaseImpl::_finishDropCollection(OperationContext* opCtx,
"namespace"_attr = nss,
"uuid"_attr = uuid);
- auto status = catalog::dropCollection(
- opCtx, collection->ns(), collection->getCatalogId(), collection->getSharedIdent());
- if (!status.isOK())
- return status;
+ // A virtual collection does not have a durable catalog entry.
+ if (auto sharedIdent = collection->getSharedIdent()) {
+ auto status = catalog::dropCollection(
+ opCtx, collection->ns(), collection->getCatalogId(), sharedIdent);
+ if (!status.isOK())
+ return status;
- opCtx->recoveryUnit()->onCommit(
- [opCtx, nss, uuid, ident = collection->getSharedIdent()->getIdent()](
- boost::optional<Timestamp> commitTime) {
+ opCtx->recoveryUnit()->onCommit([opCtx, nss, uuid, ident = sharedIdent->getIdent()](
+ boost::optional<Timestamp> commitTime) {
if (!commitTime) {
return;
}
HistoricalIdentTracker::get(opCtx).recordDrop(ident, nss, uuid, commitTime.value());
});
+ }
CollectionCatalog::get(opCtx)->dropCollection(
opCtx, collection, opCtx->getServiceContext()->getStorageEngine()->supportsPendingDrops());
diff --git a/src/mongo/db/catalog/virtual_collection_impl.cpp b/src/mongo/db/catalog/virtual_collection_impl.cpp
index 1c7e2f4e341..bd5c80d50c5 100644
--- a/src/mongo/db/catalog/virtual_collection_impl.cpp
+++ b/src/mongo/db/catalog/virtual_collection_impl.cpp
@@ -31,6 +31,7 @@
#include "mongo/db/catalog/collection_impl.h"
#include "mongo/db/catalog/collection_options.h"
+#include "mongo/db/catalog/index_catalog_impl.h"
#include "mongo/db/operation_context.h"
#include "mongo/db/storage/external_record_store.h"
@@ -43,8 +44,9 @@ VirtualCollectionImpl::VirtualCollectionImpl(OperationContext* opCtx,
std::unique_ptr<ExternalRecordStore> recordStore)
: _nss(nss),
_options(options),
- _recordStore(std::move(recordStore)),
- _collator(CollectionImpl::parseCollation(opCtx, nss, options.collation)) {
+ _shared(std::make_shared<SharedState>(
+ std::move(recordStore), CollectionImpl::parseCollation(opCtx, nss, options.collation))),
+ _indexCatalog(std::make_unique<IndexCatalogImpl>()) {
tassert(6968503,
"Cannot create _id index for a virtual collection",
options.autoIndexId == CollectionOptions::NO && options.idIndex.isEmpty());
@@ -57,16 +59,4 @@ std::shared_ptr<Collection> VirtualCollectionImpl::make(OperationContext* opCtx,
return std::make_shared<VirtualCollectionImpl>(
opCtx, nss, options, std::make_unique<ExternalRecordStore>(nss.ns(), vopts));
}
-
-std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> VirtualCollectionImpl::makePlanExecutor(
- OperationContext* opCtx,
- const CollectionPtr& yieldableCollection,
- PlanYieldPolicy::YieldPolicy yieldPolicy,
- ScanDirection scanDirection,
- const boost::optional<RecordId>& resumeAfterRecordId) const {
- // TODO SERVER-69683 Implement this function when implementing MultiBsonStreamCursor since
- // we can scan when it's done. We don't support scan yet.
- unimplementedTasserted();
- return nullptr;
-}
} // namespace mongo
diff --git a/src/mongo/db/catalog/virtual_collection_impl.h b/src/mongo/db/catalog/virtual_collection_impl.h
index a490f4f9ee1..499e6add895 100644
--- a/src/mongo/db/catalog/virtual_collection_impl.h
+++ b/src/mongo/db/catalog/virtual_collection_impl.h
@@ -50,19 +50,20 @@ public:
const CollectionOptions& options,
std::unique_ptr<ExternalRecordStore> recordStore);
+ VirtualCollectionImpl(const VirtualCollectionImpl&) = default;
+
~VirtualCollectionImpl() = default;
const VirtualCollectionOptions& getVirtualCollectionOptions() const {
- return _recordStore->getOptions();
+ return _shared->_recordStore->getOptions();
}
std::shared_ptr<Collection> clone() const final {
- unimplementedTasserted();
- return nullptr;
+ return std::make_shared<VirtualCollectionImpl>(*this);
}
SharedCollectionDecorations* getSharedDecorations() const final {
- return nullptr;
+ return &_shared->_sharedDecorations;
}
Status initFromExisting(OperationContext* opCtx,
@@ -96,19 +97,21 @@ public:
}
const IndexCatalog* getIndexCatalog() const final {
- return nullptr;
+ return _indexCatalog.get();
}
IndexCatalog* getIndexCatalog() final {
- return nullptr;
+ return _indexCatalog.get();
}
RecordStore* getRecordStore() const final {
- return _recordStore.get();
+ return _shared->_recordStore.get();
}
+ // A virtual collection can't have an 'ident' because an 'ident' is an identifier for a WT table,
+ // which a virtual collection does not have. So returns nullptr.
std::shared_ptr<Ident> getSharedIdent() const final {
- return _recordStore->getSharedIdent();
+ return nullptr;
}
void setIdent(std::shared_ptr<Ident> newIdent) final {
@@ -149,7 +152,7 @@ public:
std::unique_ptr<SeekableRecordCursor> getCursor(OperationContext* opCtx,
bool forward = true) const final {
- return _recordStore->getCursor(opCtx, forward);
+ return _shared->_recordStore->getCursor(opCtx, forward);
}
void deleteDocument(
@@ -368,35 +371,26 @@ public:
}
int getTotalIndexCount() const final {
- unimplementedTasserted();
return 0;
}
int getCompletedIndexCount() const final {
- unimplementedTasserted();
return 0;
}
BSONObj getIndexSpec(StringData indexName) const final {
- unimplementedTasserted();
return BSONObj();
}
- void getAllIndexes(std::vector<std::string>* names) const final {
- unimplementedTasserted();
- }
+ void getAllIndexes(std::vector<std::string>* names) const final {}
- void getReadyIndexes(std::vector<std::string>* names) const final {
- unimplementedTasserted();
- }
+ void getReadyIndexes(std::vector<std::string>* names) const final {}
bool isIndexPresent(StringData indexName) const final {
- unimplementedTasserted();
return false;
}
bool isIndexReady(StringData indexName) const final {
- unimplementedTasserted();
return false;
}
@@ -428,15 +422,15 @@ public:
}
long long numRecords(OperationContext* opCtx) const final {
- return _recordStore->numRecords(opCtx);
+ return _shared->_recordStore->numRecords(opCtx);
}
long long dataSize(OperationContext* opCtx) const final {
- return _recordStore->dataSize(opCtx);
+ return _shared->_recordStore->dataSize(opCtx);
}
bool isEmpty(OperationContext* opCtx) const final {
- return _recordStore->dataSize(opCtx) == 0LL;
+ return _shared->_recordStore->dataSize(opCtx) == 0LL;
}
inline int averageObjectSize(OperationContext* opCtx) const {
@@ -478,7 +472,7 @@ public:
* Collection is destroyed.
*/
const CollatorInterface* getDefaultCollator() const final {
- return _collator.get();
+ return _shared->_collator.get();
}
const CollectionOptions& getCollectionOptions() const final {
@@ -491,12 +485,17 @@ public:
return Status(ErrorCodes::UnknownError, "unknown");
}
+ // This method is used in the context of rollback and index builds, which are not supported for a
+ // virtual collection.
std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> makePlanExecutor(
OperationContext* opCtx,
const CollectionPtr& yieldableCollection,
PlanYieldPolicy::YieldPolicy yieldPolicy,
ScanDirection scanDirection,
- const boost::optional<RecordId>& resumeAfterRecordId) const final;
+ const boost::optional<RecordId>& resumeAfterRecordId) const final {
+ unimplementedTasserted();
+ return nullptr;
+ }
void indexBuildSuccess(OperationContext* opCtx, IndexCatalogEntry* index) final {
unimplementedTasserted();
@@ -509,13 +508,32 @@ private:
MONGO_UNIMPLEMENTED_TASSERT(6968504);
}
+ struct SharedState {
+ SharedState(std::unique_ptr<ExternalRecordStore> recordStore,
+ std::unique_ptr<CollatorInterface> collator)
+ : _recordStore(std::move(recordStore)), _collator(std::move(collator)) {}
+
+ ~SharedState() = default;
+
+ std::unique_ptr<ExternalRecordStore> _recordStore;
+
+ // This object is decorable and decorated with unversioned data related to the collection.
+ // Not associated with any particular Collection instance for the collection, but shared
+ // across all instances for the same collection. This is a vehicle for users of a collection
+ // to cache unversioned state for a collection that is accessible across all of the
+ // Collection instances.
+ SharedCollectionDecorations _sharedDecorations;
+
+ // The default collation which is applied to operations and indices which have no collation
+ // of their own. The collection's validator will respect this collation. If null, the
+ // default collation is simple binary compare.
+ std::unique_ptr<CollatorInterface> _collator;
+ };
+
NamespaceString _nss;
CollectionOptions _options;
- std::unique_ptr<ExternalRecordStore> _recordStore;
- // The default collation which is applied to operations and indices which have no collation
- // of their own. The collection's validator will respect this collation. If null, the
- // default collation is simple binary compare.
- std::unique_ptr<CollatorInterface> _collator;
+ std::shared_ptr<SharedState> _shared;
+ clonable_ptr<IndexCatalog> _indexCatalog;
};
} // namespace mongo
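A small self-contained sketch of the copy-shares-state idiom this header adopts for clone(); the names below are illustrative stand-ins, but the shape matches the change: copying the Collection handle is cheap because every copy points at one shared record store, decoration set, and collator.

// Illustrative only; not the real Collection/VirtualCollectionImpl interfaces.
#include <memory>
#include <string>
#include <utility>

class VirtualCollectionSketch {
public:
    explicit VirtualCollectionSketch(std::string recordData)
        : _shared(std::make_shared<SharedState>(std::move(recordData))) {}

    // What clone() does above: a plain copy of the handle that shares the state.
    VirtualCollectionSketch(const VirtualCollectionSketch&) = default;

    std::shared_ptr<VirtualCollectionSketch> clone() const {
        return std::make_shared<VirtualCollectionSketch>(*this);
    }

    const std::string& records() const {
        return _shared->recordData;
    }

private:
    // Stand-in for SharedState holding the ExternalRecordStore, the
    // SharedCollectionDecorations and the collator.
    struct SharedState {
        explicit SharedState(std::string d) : recordData(std::move(d)) {}
        std::string recordData;
    };
    std::shared_ptr<SharedState> _shared;
};

int main() {
    VirtualCollectionSketch original{"bson stream"};
    auto copy = original.clone();
    // Both handles observe the same shared state.
    return copy->records() == original.records() ? 0 : 1;
}

The real change additionally keeps the IndexCatalog behind a clonable_ptr, so clone() no longer needs to tassert.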
diff --git a/src/mongo/db/catalog/virtual_collection_options.h b/src/mongo/db/catalog/virtual_collection_options.h
index 433976ab3e5..cd8a3a01936 100644
--- a/src/mongo/db/catalog/virtual_collection_options.h
+++ b/src/mongo/db/catalog/virtual_collection_options.h
@@ -29,46 +29,33 @@
#pragma once
-#include <cstdint>
-#include <fmt/format.h>
#include <string>
#include <vector>
#include "mongo/base/string_data.h"
-#include "mongo/db/pipeline/aggregate_command_gen.h"
#include "mongo/db/pipeline/external_data_source_option_gen.h"
#include "mongo/util/assert_util.h"
namespace mongo {
-
/**
* Metadata for external data source.
*/
struct ExternalDataSourceMetadata {
static constexpr auto kUrlProtocolFile = "file://"_sd;
-#ifndef _WIN32
- static constexpr auto kDefaultFileUrlPrefix = "file:///tmp/"_sd;
-#else
- static constexpr auto kDefaultFileUrlPrefix = "file:////./pipe/"_sd;
-#endif
- ExternalDataSourceMetadata(const std::string& url,
- StorageTypeEnum storageType,
- FileTypeEnum fileType)
- : storageType(storageType), fileType(fileType) {
- using namespace fmt::literals;
+ ExternalDataSourceMetadata(StringData urlStr,
+ StorageTypeEnum storageTypeEnum,
+ FileTypeEnum fileTypeEnum)
+ : url(urlStr), storageType(storageTypeEnum), fileType(fileTypeEnum) {
uassert(6968500,
- "File url must start with {}"_format(kDefaultFileUrlPrefix),
- url.find(kDefaultFileUrlPrefix.toString()) == 0);
+ "File url must start with {}"_format(kUrlProtocolFile),
+ urlStr.startsWith(kUrlProtocolFile));
uassert(6968501, "Storage type must be 'pipe'", storageType == StorageTypeEnum::pipe);
uassert(6968502, "File type must be 'bson'", fileType == FileTypeEnum::bson);
-
- // Strip off the protocol prefix.
- this->url = url.substr(kUrlProtocolFile.size());
}
ExternalDataSourceMetadata(const ExternalDataSourceInfo& dataSourceInfo)
- : ExternalDataSourceMetadata(dataSourceInfo.getUrl().toString(),
+ : ExternalDataSourceMetadata(dataSourceInfo.getUrl(),
dataSourceInfo.getStorageType(),
dataSourceInfo.getFileType()) {}
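As a rough illustration of the relaxed check above (simplified, throwing instead of uassert-ing and using standard-library strings rather than StringData): only the "file://" protocol prefix is required now, and the URL is stored verbatim instead of having the prefix stripped.

#include <stdexcept>
#include <string>
#include <string_view>

constexpr std::string_view kUrlProtocolFile = "file://";

// Returns the URL unchanged after validating the protocol prefix.
std::string validateExternalDataSourceUrl(std::string_view url) {
    if (url.substr(0, kUrlProtocolFile.size()) != kUrlProtocolFile) {
        throw std::invalid_argument("File url must start with file://");
    }
    return std::string{url};  // kept as-is; no prefix stripping any more
}

int main() {
    return validateExternalDataSourceUrl("file://named_pipe1") == "file://named_pipe1" ? 0 : 1;
}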
diff --git a/src/mongo/db/change_stream_pre_images_collection_manager.cpp b/src/mongo/db/change_stream_pre_images_collection_manager.cpp
index 103ce9a2fb1..7d606635708 100644
--- a/src/mongo/db/change_stream_pre_images_collection_manager.cpp
+++ b/src/mongo/db/change_stream_pre_images_collection_manager.cpp
@@ -77,6 +77,7 @@ boost::optional<std::int64_t> getExpireAfterSecondsFromChangeStreamOptions(
// Returns pre-images expiry time in milliseconds since the epoch time if configured, boost::none
// otherwise.
boost::optional<Date_t> getPreImageExpirationTime(OperationContext* opCtx, Date_t currentTime) {
+ invariant(!change_stream_serverless_helpers::isChangeCollectionsModeActive());
boost::optional<std::int64_t> expireAfterSeconds = boost::none;
// Get the expiration time directly from the change stream manager.
@@ -140,7 +141,6 @@ void ChangeStreamPreImagesCollectionManager::insertPreImage(OperationContext* op
<< preImage.getId().getApplyOpsIndex(),
preImage.getId().getApplyOpsIndex() >= 0);
- // TODO SERVER-66642 Consider using internal test-tenant id if applicable.
const auto preImagesCollectionNamespace = NamespaceString::makePreImageCollectionNSS(
change_stream_serverless_helpers::resolveTenantId(tenantId));
@@ -231,49 +231,15 @@ boost::optional<UUID> findNextCollectionUUID(OperationContext* opCtx,
* | applyIndex: 0 | | applyIndex: 0 | | applyIndex: 0 | | applyIndex: 1 |
* +-------------------+ +-------------------+ +-------------------+ +-------------------+
*/
-size_t deleteExpiredChangeStreamPreImages(OperationContext* opCtx,
- Date_t currentTimeForTimeBasedExpiration) {
- // Acquire intent-exclusive lock on the pre-images collection. Early exit if the collection
- // doesn't exist.
- // TODO SERVER-66642 Account for multitenancy.
- AutoGetCollection autoColl(
- opCtx, NamespaceString::makePreImageCollectionNSS(boost::none), MODE_IX);
- const auto& preImagesColl = autoColl.getCollection();
- if (!preImagesColl) {
- return 0;
- }
-
- // Do not run the job on secondaries.
- if (!repl::ReplicationCoordinator::get(opCtx)->canAcceptWritesForDatabase(
- opCtx, NamespaceString::kConfigDb)) {
- return 0;
- }
-
- // Get the timestamp of the earliest oplog entry.
- const auto currentEarliestOplogEntryTs =
- repl::StorageInterface::get(opCtx->getServiceContext())->getEarliestOplogTimestamp(opCtx);
-
- const bool isBatchedRemoval = gBatchedExpiredChangeStreamPreImageRemoval.load();
+size_t _deleteExpiredChangeStreamPreImagesCommon(OperationContext* opCtx,
+ const CollectionPtr& preImageColl,
+ const MatchExpression* filterPtr,
+ Timestamp maxRecordIdTimestamp) {
size_t numberOfRemovals = 0;
- const auto preImageExpirationTime = change_stream_pre_image_helpers::getPreImageExpirationTime(
- opCtx, currentTimeForTimeBasedExpiration);
-
- // Configure the filter for the case when expiration parameter is set.
- OrMatchExpression filter;
- const MatchExpression* filterPtr = nullptr;
- if (preImageExpirationTime) {
- filter.add(
- std::make_unique<LTMatchExpression>("_id.ts"_sd, Value(currentEarliestOplogEntryTs)));
- filter.add(std::make_unique<LTEMatchExpression>("operationTime"_sd,
- Value(*preImageExpirationTime)));
- filterPtr = &filter;
- }
- const bool shouldReturnEofOnFilterMismatch = preImageExpirationTime.has_value();
-
- // TODO SERVER-66642 Account for multitenancy.
+ const bool isBatchedRemoval = gBatchedExpiredChangeStreamPreImageRemoval.load();
boost::optional<UUID> currentCollectionUUID = boost::none;
while ((currentCollectionUUID =
- findNextCollectionUUID(opCtx, &preImagesColl, currentCollectionUUID))) {
+ findNextCollectionUUID(opCtx, &preImageColl, currentCollectionUUID))) {
writeConflictRetry(
opCtx,
"ChangeStreamExpiredPreImagesRemover",
@@ -288,22 +254,14 @@ size_t deleteExpiredChangeStreamPreImages(OperationContext* opCtx,
}
RecordIdBound minRecordId(
toRecordId(ChangeStreamPreImageId(*currentCollectionUUID, Timestamp(), 0)));
-
- // If the expiration parameter is set, the 'maxRecord' is set to the maximum
- // RecordId for this collection. Whether the pre-image has to be deleted will be
- // determined by the filtering MatchExpression.
- //
- // If the expiration parameter is not set, then the last expired pre-image timestamp
- // equals to one increment before the 'currentEarliestOplogEntryTs'.
- RecordIdBound maxRecordId = RecordIdBound(toRecordId(ChangeStreamPreImageId(
- *currentCollectionUUID,
- preImageExpirationTime ? Timestamp::max()
- : Timestamp(currentEarliestOplogEntryTs.asULL() - 1),
- std::numeric_limits<int64_t>::max())));
+ RecordIdBound maxRecordId = RecordIdBound(
+ toRecordId(ChangeStreamPreImageId(*currentCollectionUUID,
+ maxRecordIdTimestamp,
+ std::numeric_limits<int64_t>::max())));
auto exec = InternalPlanner::deleteWithCollectionScan(
opCtx,
- &preImagesColl,
+ &preImageColl,
std::move(params),
PlanYieldPolicy::YieldPolicy::YIELD_AUTO,
InternalPlanner::Direction::FORWARD,
@@ -312,12 +270,83 @@ size_t deleteExpiredChangeStreamPreImages(OperationContext* opCtx,
CollectionScanParams::ScanBoundInclusion::kIncludeBothStartAndEndRecords,
std::move(batchedDeleteParams),
filterPtr,
- shouldReturnEofOnFilterMismatch);
+ filterPtr != nullptr);
numberOfRemovals += exec->executeDelete();
});
}
return numberOfRemovals;
}
+
+size_t deleteExpiredChangeStreamPreImages(OperationContext* opCtx,
+ Date_t currentTimeForTimeBasedExpiration) {
+ // Acquire intent-exclusive lock on the change collection.
+ AutoGetCollection preImageColl(
+ opCtx, NamespaceString::makePreImageCollectionNSS(boost::none), MODE_IX);
+
+ // Early exit if the collection doesn't exist or running on a secondary.
+ if (!preImageColl ||
+ !repl::ReplicationCoordinator::get(opCtx)->canAcceptWritesForDatabase(
+ opCtx, NamespaceString::kConfigDb)) {
+ return 0;
+ }
+
+ // Get the timestamp of the earliest oplog entry.
+ const auto currentEarliestOplogEntryTs =
+ repl::StorageInterface::get(opCtx->getServiceContext())->getEarliestOplogTimestamp(opCtx);
+
+ const auto preImageExpirationTime = change_stream_pre_image_helpers::getPreImageExpirationTime(
+ opCtx, currentTimeForTimeBasedExpiration);
+
+ // Configure the filter for the case when expiration parameter is set.
+ if (preImageExpirationTime) {
+ OrMatchExpression filter;
+ filter.add(
+ std::make_unique<LTMatchExpression>("_id.ts"_sd, Value(currentEarliestOplogEntryTs)));
+ filter.add(std::make_unique<LTEMatchExpression>("operationTime"_sd,
+ Value(*preImageExpirationTime)));
+ // If 'preImageExpirationTime' is set, 'maxRecordIdTimestamp' is set to the maximum
+ // RecordId for this collection. Whether the pre-image has to be deleted will be determined
+ // by the 'filter' parameter.
+ return _deleteExpiredChangeStreamPreImagesCommon(
+ opCtx, *preImageColl, &filter, Timestamp::max() /* maxRecordIdTimestamp */);
+ }
+
+ // 'preImageExpirationTime' is not set, so the last expired pre-image timestamp is less than
+ // 'currentEarliestOplogEntryTs'.
+ return _deleteExpiredChangeStreamPreImagesCommon(
+ opCtx,
+ *preImageColl,
+ nullptr /* filterPtr */,
+ Timestamp(currentEarliestOplogEntryTs.asULL() - 1) /* maxRecordIdTimestamp */);
+}
+
+size_t deleteExpiredChangeStreamPreImagesForTenants(OperationContext* opCtx,
+ const TenantId& tenantId,
+ Date_t currentTimeForTimeBasedExpiration) {
+
+ // Acquire intent-exclusive lock on the change collection.
+ AutoGetCollection preImageColl(opCtx,
+ NamespaceString::makePreImageCollectionNSS(
+ change_stream_serverless_helpers::resolveTenantId(tenantId)),
+ MODE_IX);
+
+ // Early exit if the collection doesn't exist or running on a secondary.
+ if (!preImageColl ||
+ !repl::ReplicationCoordinator::get(opCtx)->canAcceptWritesForDatabase(
+ opCtx, NamespaceString::kConfigDb)) {
+ return 0;
+ }
+
+ auto expiredAfterSeconds = change_stream_serverless_helpers::getExpireAfterSeconds(tenantId);
+ LTEMatchExpression filter{
+ "operationTime"_sd,
+ Value(currentTimeForTimeBasedExpiration - Seconds(expiredAfterSeconds))};
+
+ // Set the 'maxRecordIdTimestamp' parameter (upper scan boundary) to maximum possible. Whether
+ // the pre-image has to be deleted will be determined by the 'filter' parameter.
+ return _deleteExpiredChangeStreamPreImagesCommon(
+ opCtx, *preImageColl, &filter, Timestamp::max() /* maxRecordIdTimestamp */);
+}
} // namespace
void ChangeStreamPreImagesCollectionManager::performExpiredChangeStreamPreImagesRemovalPass(
@@ -342,9 +371,20 @@ void ChangeStreamPreImagesCollectionManager::performExpiredChangeStreamPreImages
ServiceContext::UniqueOperationContext opCtx;
try {
opCtx = client->makeOperationContext();
+ size_t numberOfRemovals = 0;
+
+ if (change_stream_serverless_helpers::isChangeCollectionsModeActive()) {
+ const auto tenantIds =
+ change_stream_serverless_helpers::getConfigDbTenants(opCtx.get());
+ for (const auto& tenantId : tenantIds) {
+ numberOfRemovals += deleteExpiredChangeStreamPreImagesForTenants(
+ opCtx.get(), tenantId, currentTimeForTimeBasedExpiration);
+ }
+ } else {
+ numberOfRemovals =
+ deleteExpiredChangeStreamPreImages(opCtx.get(), currentTimeForTimeBasedExpiration);
+ }
- auto numberOfRemovals =
- deleteExpiredChangeStreamPreImages(opCtx.get(), currentTimeForTimeBasedExpiration);
if (numberOfRemovals > 0) {
LOGV2_DEBUG(5869104,
3,
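A hedged sketch of how the two call paths above parameterize the shared deletion helper; the types and helper names are invented for illustration, but the branching mirrors the hunk: a filter plus an unbounded scan when an expiration time applies, otherwise an unfiltered scan bounded just below the earliest oplog timestamp, and the serverless per-tenant path always filtering on operationTime.

// Illustrative only; simplified types, not the server's planner or match expressions.
#include <cstdint>
#include <limits>
#include <optional>

using Timestamp = std::uint64_t;

struct DeletionPlan {
    bool hasFilter;           // drive deletion with a match filter?
    Timestamp maxRecordIdTs;  // upper bound of the pre-image collection scan
};

// Non-serverless path: filter only if an expire-after setting is configured.
DeletionPlan planDefault(std::optional<Timestamp> preImageExpirationTime,
                         Timestamp earliestOplogEntryTs) {
    if (preImageExpirationTime) {
        // Scan the whole collection; the filter ("_id.ts" < earliest oplog ts OR
        // "operationTime" <= expiration time) decides what gets deleted.
        return {true, std::numeric_limits<Timestamp>::max()};
    }
    // No expiration configured: everything strictly below the earliest oplog entry
    // timestamp is expired, so bound the scan one increment before it.
    return {false, earliestOplogEntryTs - 1};
}

// Serverless per-tenant path: always filter on "operationTime" with the tenant's
// expire-after-seconds, scanning up to the maximum record id timestamp.
DeletionPlan planForTenant() {
    return {true, std::numeric_limits<Timestamp>::max()};
}

int main() {
    auto plan = planDefault(std::nullopt, Timestamp{100});
    return plan.maxRecordIdTs == Timestamp{99} ? 0 : 1;
}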
diff --git a/src/mongo/db/clientcursor.cpp b/src/mongo/db/clientcursor.cpp
index 2a471eb1288..c9994c1ea7f 100644
--- a/src/mongo/db/clientcursor.cpp
+++ b/src/mongo/db/clientcursor.cpp
@@ -40,6 +40,7 @@
#include "mongo/db/auth/privilege.h"
#include "mongo/db/client.h"
#include "mongo/db/commands.h"
+#include "mongo/db/commands/external_data_source_scope_guard.h"
#include "mongo/db/commands/server_status.h"
#include "mongo/db/commands/server_status_metric.h"
#include "mongo/db/curop.h"
@@ -65,6 +66,10 @@ static CounterMetric cursorStatsTimedOut{"cursor.timedOut"};
static CounterMetric cursorStatsTotalOpened{"cursor.totalOpened"};
static CounterMetric cursorStatsMoreThanOneBatch{"cursor.moreThanOneBatch"};
+const ClientCursor::Decoration<std::shared_ptr<ExternalDataSourceScopeGuard>>
+ ExternalDataSourceScopeGuard::get =
+ ClientCursor::declareDecoration<std::shared_ptr<ExternalDataSourceScopeGuard>>();
+
ClientCursor::ClientCursor(ClientCursorParams params,
CursorId cursorId,
OperationContext* operationUsingCursor,
@@ -133,6 +138,9 @@ void ClientCursor::dispose(OperationContext* opCtx) {
}
_exec->dispose(opCtx);
+ // Update opCtx of the decorated ExternalDataSourceScopeGuard object so that it can drop virtual
+ // collections in the new 'opCtx'.
+ ExternalDataSourceScopeGuard::updateOperationContext(this, opCtx);
_disposed = true;
}
diff --git a/src/mongo/db/clientcursor.h b/src/mongo/db/clientcursor.h
index 89239e6230d..4e92e26d066 100644
--- a/src/mongo/db/clientcursor.h
+++ b/src/mongo/db/clientcursor.h
@@ -117,7 +117,7 @@ struct ClientCursorParams {
* caller as "no timeout", it will be automatically destroyed by its cursor manager after a period
* of inactivity.
*/
-class ClientCursor {
+class ClientCursor : public Decorable<ClientCursor> {
ClientCursor(const ClientCursor&) = delete;
ClientCursor& operator=(const ClientCursor&) = delete;
diff --git a/src/mongo/db/commands/SConscript b/src/mongo/db/commands/SConscript
index 0074dc859c6..19de67d69b6 100644
--- a/src/mongo/db/commands/SConscript
+++ b/src/mongo/db/commands/SConscript
@@ -367,6 +367,7 @@ env.Library(
'$BUILD_DIR/mongo/db/repl/replica_set_messages',
'$BUILD_DIR/mongo/db/repl/tenant_migration_access_blocker',
'$BUILD_DIR/mongo/db/rw_concern_d',
+ '$BUILD_DIR/mongo/db/s/query_analysis_writer',
'$BUILD_DIR/mongo/db/server_base',
'$BUILD_DIR/mongo/db/server_feature_flags',
'$BUILD_DIR/mongo/db/session/session_catalog_mongod',
@@ -789,6 +790,7 @@ env.CppUnitTest(
target="db_commands_test",
source=[
"create_indexes_test.cpp",
+ "external_data_source_commands_test.cpp",
"index_filter_commands_test.cpp",
"fle_compact_test.cpp",
"list_collections_filter_test.cpp",
@@ -819,7 +821,9 @@ env.CppUnitTest(
"$BUILD_DIR/mongo/db/repl/replmocks",
"$BUILD_DIR/mongo/db/repl/storage_interface_impl",
"$BUILD_DIR/mongo/db/service_context_d_test_fixture",
+ "$BUILD_DIR/mongo/db/storage/record_store_base",
'$BUILD_DIR/mongo/idl/idl_parser',
+ '$BUILD_DIR/mongo/util/version_impl',
"cluster_server_parameter_commands_invocation",
"core",
"create_command",
diff --git a/src/mongo/db/commands/count_cmd.cpp b/src/mongo/db/commands/count_cmd.cpp
index 21ec11df16f..9a961c1ed0b 100644
--- a/src/mongo/db/commands/count_cmd.cpp
+++ b/src/mongo/db/commands/count_cmd.cpp
@@ -44,6 +44,7 @@
#include "mongo/db/query/plan_summary_stats.h"
#include "mongo/db/query/view_response_formatter.h"
#include "mongo/db/s/collection_sharding_state.h"
+#include "mongo/db/s/query_analysis_writer.h"
#include "mongo/logv2/log.h"
#include "mongo/util/database_name_util.h"
@@ -249,6 +250,15 @@ public:
invocation->markMirrored();
}
+ if (analyze_shard_key::supportsPersistingSampledQueries() && request.getSampleId()) {
+ analyze_shard_key::QueryAnalysisWriter::get(opCtx)
+ .addCountQuery(*request.getSampleId(),
+ nss,
+ request.getQuery(),
+ request.getCollation().value_or(BSONObj()))
+ .getAsync([](auto) {});
+ }
+
if (ctx->getView()) {
auto viewAggregation = countCommandAsAggregationCommand(request, nss);
diff --git a/src/mongo/db/commands/dbcheck.cpp b/src/mongo/db/commands/dbcheck.cpp
index 1b1110abb29..f8aafcae484 100644
--- a/src/mongo/db/commands/dbcheck.cpp
+++ b/src/mongo/db/commands/dbcheck.cpp
@@ -65,9 +65,8 @@ repl::OpTime _logOp(OperationContext* opCtx,
repl::MutableOplogEntry oplogEntry;
oplogEntry.setOpType(repl::OpTypeEnum::kCommand);
oplogEntry.setNss(nss);
- if (uuid) {
- oplogEntry.setUuid(*uuid);
- }
+ oplogEntry.setTid(nss.tenantId());
+ oplogEntry.setUuid(uuid);
oplogEntry.setObject(obj);
AutoGetOplog oplogWrite(opCtx, OplogAccessMode::kWrite);
return writeConflictRetry(
@@ -144,7 +143,6 @@ struct DbCheckCollectionInfo {
int64_t maxDocsPerBatch;
int64_t maxBytesPerBatch;
int64_t maxBatchTimeMillis;
- bool snapshotRead;
WriteConcernOptions writeConcern;
};
@@ -167,6 +165,8 @@ std::unique_ptr<DbCheckRun> singleCollectionRun(OperationContext* opCtx,
"Cannot run dbCheck on " + nss.toString() + " because it is not replicated",
nss.isReplicated());
+ uassert(6769500, "dbCheck no longer supports snapshotRead:false", invocation.getSnapshotRead());
+
const auto start = invocation.getMinKey();
const auto end = invocation.getMaxKey();
const auto maxCount = invocation.getMaxCount();
@@ -184,7 +184,6 @@ std::unique_ptr<DbCheckRun> singleCollectionRun(OperationContext* opCtx,
maxDocsPerBatch,
maxBytesPerBatch,
maxBatchTimeMillis,
- invocation.getSnapshotRead(),
invocation.getBatchWriteConcern()};
auto result = std::make_unique<DbCheckRun>();
result->push_back(info);
@@ -201,6 +200,8 @@ std::unique_ptr<DbCheckRun> fullDatabaseRun(OperationContext* opCtx,
AutoGetDb agd(opCtx, dbName, MODE_IS);
uassert(ErrorCodes::NamespaceNotFound, "Database " + dbName.db() + " not found", agd.getDb());
+ uassert(6769501, "dbCheck no longer supports snapshotRead:false", invocation.getSnapshotRead());
+
const int64_t max = std::numeric_limits<int64_t>::max();
const auto rate = invocation.getMaxCountPerSecond();
const auto maxDocsPerBatch = invocation.getMaxDocsPerBatch();
@@ -220,7 +221,6 @@ std::unique_ptr<DbCheckRun> fullDatabaseRun(OperationContext* opCtx,
maxDocsPerBatch,
maxBytesPerBatch,
maxBatchTimeMillis,
- invocation.getSnapshotRead(),
invocation.getBatchWriteConcern()};
result->push_back(info);
return true;
@@ -241,7 +241,8 @@ std::unique_ptr<DbCheckRun> getRun(OperationContext* opCtx,
// Get rid of generic command fields.
for (const auto& elem : obj) {
- if (!isGenericArgument(elem.fieldNameStringData())) {
+ const auto& fieldName = elem.fieldNameStringData();
+ if (!isGenericArgument(fieldName)) {
builder.append(elem);
}
}
@@ -251,11 +252,17 @@ std::unique_ptr<DbCheckRun> getRun(OperationContext* opCtx,
// If the dbCheck argument is a string, this is the per-collection form.
if (toParse["dbCheck"].type() == BSONType::String) {
return singleCollectionRun(
- opCtx, dbName, DbCheckSingleInvocation::parse(IDLParserContext(""), toParse));
+ opCtx,
+ dbName,
+ DbCheckSingleInvocation::parse(
+ IDLParserContext("", false /*apiStrict*/, dbName.tenantId()), toParse));
} else {
// Otherwise, it's the database-wide form.
return fullDatabaseRun(
- opCtx, dbName, DbCheckAllInvocation::parse(IDLParserContext(""), toParse));
+ opCtx,
+ dbName,
+ DbCheckAllInvocation::parse(
+ IDLParserContext("", false /*apiStrict*/, dbName.tenantId()), toParse));
}
}
@@ -486,122 +493,77 @@ private:
const BSONKey& first,
int64_t batchDocs,
int64_t batchBytes) {
- auto lockMode = MODE_S;
- if (info.snapshotRead) {
- // Each batch will read at the latest no-overlap point, which is the all_durable
- // timestamp on primaries. We assume that the history window on secondaries is always
- // longer than the time it takes between starting and replicating a batch on the
- // primary. Otherwise, the readTimestamp will not be available on a secondary by the
- // time it processes the oplog entry.
- lockMode = MODE_IS;
- opCtx->recoveryUnit()->setTimestampReadSource(RecoveryUnit::ReadSource::kNoOverlap);
+ // Each batch will read at the latest no-overlap point, which is the all_durable timestamp
+ // on primaries. We assume that the history window on secondaries is always longer than the
+ // time it takes between starting and replicating a batch on the primary. Otherwise, the
+ // readTimestamp will not be available on a secondary by the time it processes the oplog
+ // entry.
+ opCtx->recoveryUnit()->setTimestampReadSource(RecoveryUnit::ReadSource::kNoOverlap);
+
+ // dbCheck writes to the oplog, so we need to take an IX lock. We don't need to write to the
+ // collection, however, so we only take an intent lock on it.
+ Lock::GlobalLock glob(opCtx, MODE_IX);
+ AutoGetCollection collection(opCtx, info.nss, MODE_IS);
+
+ if (_stepdownHasOccurred(opCtx, info.nss)) {
+ _done = true;
+ return Status(ErrorCodes::PrimarySteppedDown, "dbCheck terminated due to stepdown");
}
- BatchStats result;
- auto timeoutMs = Milliseconds(gDbCheckCollectionTryLockTimeoutMillis.load());
- const auto initialBackoffMs =
- Milliseconds(gDbCheckCollectionTryLockMinBackoffMillis.load());
- auto backoffMs = initialBackoffMs;
- for (int attempt = 1;; attempt++) {
- try {
- // Try to acquire collection lock with increasing timeout and bounded exponential
- // backoff.
- auto const lockDeadline = Date_t::now() + timeoutMs;
- timeoutMs *= 2;
-
- AutoGetCollection agc(
- opCtx, info.nss, lockMode, AutoGetCollection::Options{}.deadline(lockDeadline));
-
- if (_stepdownHasOccurred(opCtx, info.nss)) {
- _done = true;
- return Status(ErrorCodes::PrimarySteppedDown,
- "dbCheck terminated due to stepdown");
- }
+ if (!collection) {
+ const auto msg = "Collection under dbCheck no longer exists";
+ return {ErrorCodes::NamespaceNotFound, msg};
+ }
- const auto& collection =
- CollectionCatalog::get(opCtx)->lookupCollectionByNamespace(opCtx, info.nss);
- if (!collection) {
- const auto msg = "Collection under dbCheck no longer exists";
- return {ErrorCodes::NamespaceNotFound, msg};
- }
+ auto readTimestamp = opCtx->recoveryUnit()->getPointInTimeReadTimestamp(opCtx);
+ uassert(ErrorCodes::SnapshotUnavailable,
+ "No snapshot available yet for dbCheck",
+ readTimestamp);
+ auto minVisible = collection->getMinimumVisibleSnapshot();
+ if (minVisible && *readTimestamp < *collection->getMinimumVisibleSnapshot()) {
+ return {ErrorCodes::SnapshotUnavailable,
+ str::stream() << "Unable to read from collection " << info.nss
+ << " due to pending catalog changes"};
+ }
- auto readTimestamp = opCtx->recoveryUnit()->getPointInTimeReadTimestamp(opCtx);
- auto minVisible = collection->getMinimumVisibleSnapshot();
- if (readTimestamp && minVisible &&
- *readTimestamp < *collection->getMinimumVisibleSnapshot()) {
- return {ErrorCodes::SnapshotUnavailable,
- str::stream() << "Unable to read from collection " << info.nss
- << " due to pending catalog changes"};
- }
+ boost::optional<DbCheckHasher> hasher;
+ try {
+ hasher.emplace(opCtx,
+ *collection,
+ first,
+ info.end,
+ std::min(batchDocs, info.maxCount),
+ std::min(batchBytes, info.maxSize));
+ } catch (const DBException& e) {
+ return e.toStatus();
+ }
- boost::optional<DbCheckHasher> hasher;
- try {
- hasher.emplace(opCtx,
- collection,
- first,
- info.end,
- std::min(batchDocs, info.maxCount),
- std::min(batchBytes, info.maxSize));
- } catch (const DBException& e) {
- return e.toStatus();
- }
+ const auto batchDeadline = Date_t::now() + Milliseconds(info.maxBatchTimeMillis);
+ Status status = hasher->hashAll(opCtx, batchDeadline);
- const auto batchDeadline = Date_t::now() + Milliseconds(info.maxBatchTimeMillis);
- Status status = hasher->hashAll(opCtx, batchDeadline);
+ if (!status.isOK()) {
+ return status;
+ }
- if (!status.isOK()) {
- return status;
- }
+ std::string md5 = hasher->total();
- std::string md5 = hasher->total();
-
- DbCheckOplogBatch batch;
- batch.setType(OplogEntriesEnum::Batch);
- batch.setNss(info.nss);
- batch.setMd5(md5);
- batch.setMinKey(first);
- batch.setMaxKey(BSONKey(hasher->lastKey()));
- batch.setReadTimestamp(readTimestamp);
-
- // Send information on this batch over the oplog.
- result.time = _logOp(opCtx, info.nss, collection->uuid(), batch.toBSON());
- result.readTimestamp = readTimestamp;
-
- result.nDocs = hasher->docsSeen();
- result.nBytes = hasher->bytesSeen();
- result.lastKey = hasher->lastKey();
- result.md5 = md5;
-
- break;
- } catch (const ExceptionFor<ErrorCodes::LockTimeout>& e) {
- if (attempt > gDbCheckCollectionTryLockMaxAttempts.load()) {
- return StatusWith<BatchStats>(e.code(),
- "Unable to acquire the collection lock");
- }
+ DbCheckOplogBatch batch;
+ batch.setType(OplogEntriesEnum::Batch);
+ batch.setNss(info.nss);
+ batch.setMd5(md5);
+ batch.setMinKey(first);
+ batch.setMaxKey(BSONKey(hasher->lastKey()));
+ batch.setReadTimestamp(readTimestamp);
- // Bounded exponential backoff between tryLocks.
- opCtx->sleepFor(backoffMs);
- const auto maxBackoffMillis =
- Milliseconds(gDbCheckCollectionTryLockMaxBackoffMillis.load());
- if (backoffMs < maxBackoffMillis) {
- auto backoff = durationCount<Milliseconds>(backoffMs);
- auto initialBackoff = durationCount<Milliseconds>(initialBackoffMs);
- backoff *= initialBackoff;
- backoffMs = Milliseconds(backoff);
- }
- if (backoffMs > maxBackoffMillis) {
- backoffMs = maxBackoffMillis;
- }
- LOGV2_DEBUG(6175700,
- 1,
- "Could not acquire collection lock, retrying",
- "ns"_attr = info.nss.ns(),
- "batchRangeMin"_attr = info.start.obj(),
- "batchRangeMax"_attr = info.end.obj(),
- "attempt"_attr = attempt,
- "backoff"_attr = backoffMs);
- }
- }
+ // Send information on this batch over the oplog.
+ BatchStats result;
+ result.time = _logOp(opCtx, info.nss, collection->uuid(), batch.toBSON());
+ result.readTimestamp = readTimestamp;
+
+ result.nDocs = hasher->docsSeen();
+ result.nBytes = hasher->bytesSeen();
+ result.lastKey = hasher->lastKey();
+ result.md5 = md5;
return result;
}
@@ -663,7 +625,6 @@ public:
" maxDocsPerBatch: <max number of docs/batch>\n"
" maxBytesPerBatch: <try to keep a batch within max bytes/batch>\n"
" maxBatchTimeMillis: <max time processing a batch in milliseconds>\n"
- " readTimestamp: <bool, read at a timestamp without strong locks> }\n"
"to check a collection.\n"
"Invoke with {dbCheck: 1} to check all collections in the database.";
}
diff --git a/src/mongo/db/commands/dbcommands.cpp b/src/mongo/db/commands/dbcommands.cpp
index 6374a093936..0b6f90d0203 100644
--- a/src/mongo/db/commands/dbcommands.cpp
+++ b/src/mongo/db/commands/dbcommands.cpp
@@ -485,6 +485,10 @@ public:
return Request::kCommandDescription.toString();
}
+ bool allowedWithSecurityToken() const final {
+ return true;
+ }
+
// Assume that appendCollectionStorageStats() gives us a valid response.
void validateResult(const BSONObj& resultObj) final {}
@@ -587,6 +591,10 @@ public:
CmdDbStats() : TypedCommand(Request::kCommandName, Request::kCommandAlias) {}
+ bool allowedWithSecurityToken() const final {
+ return true;
+ }
+
class Invocation final : public InvocationBase {
public:
using InvocationBase::InvocationBase;
diff --git a/src/mongo/db/commands/dbcommands_d.cpp b/src/mongo/db/commands/dbcommands_d.cpp
index 5f42cb45010..5ad0adf46c5 100644
--- a/src/mongo/db/commands/dbcommands_d.cpp
+++ b/src/mongo/db/commands/dbcommands_d.cpp
@@ -248,6 +248,10 @@ public:
return Status::OK();
}
+ bool allowedWithSecurityToken() const final {
+ return true;
+ }
+
bool run(OperationContext* opCtx,
const DatabaseName& dbName,
const BSONObj& jsobj,
diff --git a/src/mongo/db/commands/distinct.cpp b/src/mongo/db/commands/distinct.cpp
index 070da8d285c..eb599424001 100644
--- a/src/mongo/db/commands/distinct.cpp
+++ b/src/mongo/db/commands/distinct.cpp
@@ -57,6 +57,7 @@
#include "mongo/db/query/query_planner_common.h"
#include "mongo/db/query/view_response_formatter.h"
#include "mongo/db/s/collection_sharding_state.h"
+#include "mongo/db/s/query_analysis_writer.h"
#include "mongo/db/views/resolved_view.h"
#include "mongo/logv2/log.h"
#include "mongo/util/database_name_util.h"
@@ -244,6 +245,16 @@ public:
invocation->markMirrored();
}
+ if (analyze_shard_key::supportsPersistingSampledQueries() && parsedDistinct.getSampleId()) {
+ auto cq = parsedDistinct.getQuery();
+ analyze_shard_key::QueryAnalysisWriter::get(opCtx)
+ .addDistinctQuery(*parsedDistinct.getSampleId(),
+ nss,
+ cq->getQueryObj(),
+ cq->getFindCommandRequest().getCollation())
+ .getAsync([](auto) {});
+ }
+
if (ctx->getView()) {
// Relinquish locks. The aggregation command will re-acquire them.
ctx.reset();
diff --git a/src/mongo/db/commands/external_data_source_commands_test.cpp b/src/mongo/db/commands/external_data_source_commands_test.cpp
new file mode 100644
index 00000000000..113495efdc8
--- /dev/null
+++ b/src/mongo/db/commands/external_data_source_commands_test.cpp
@@ -0,0 +1,953 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <fmt/format.h>
+
+#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/client/dbclient_cursor.h"
+#include "mongo/db/dbdirectclient.h"
+#include "mongo/db/pipeline/aggregation_request_helper.h"
+#include "mongo/db/query/query_knobs_gen.h"
+#include "mongo/db/repl/replication_coordinator_mock.h"
+#include "mongo/db/service_context_d_test_fixture.h"
+#include "mongo/db/storage/named_pipe.h"
+#include "mongo/util/assert_util.h"
+#include "mongo/util/scopeguard.h"
+
+namespace mongo {
+namespace {
+using namespace fmt::literals;
+
+class PipeWaiter {
+public:
+ void notify() {
+ {
+ stdx::unique_lock lk(m);
+ pipeCreated = true;
+ }
+ cv.notify_one();
+ }
+
+ void wait() {
+ stdx::unique_lock lk(m);
+ cv.wait(lk, [&] { return pipeCreated; });
+ }
+
+private:
+ Mutex m;
+ stdx::condition_variable cv;
+ bool pipeCreated = false;
+};
+
+class ExternalDataSourceCommandsTest : public ServiceContextMongoDTest {
+protected:
+ void setUp() override {
+ ServiceContextMongoDTest::setUp();
+
+ std::srand(std::time(0));
+
+ const auto service = getServiceContext();
+ auto replCoord =
+ std::make_unique<repl::ReplicationCoordinatorMock>(service, repl::ReplSettings{});
+ ASSERT_OK(replCoord->setFollowerMode(repl::MemberState::RS_PRIMARY));
+ repl::ReplicationCoordinator::set(service, std::move(replCoord));
+ repl::createOplog(_opCtx);
+
+ computeModeEnabled = true;
+ }
+
+ void tearDown() override {
+ computeModeEnabled = false;
+ ServiceContextMongoDTest::tearDown();
+ }
+
+ std::vector<BSONObj> generateRandomSimpleDocs(int count) {
+ std::vector<BSONObj> docs;
+ for (int i = 0; i < count; ++i) {
+ docs.emplace_back(BSON("a" << std::rand() % 10));
+ }
+
+ return docs;
+ }
+
+ // Generates a large readable random string to aid debugging.
+ std::string getRandomReadableLargeString() {
+ int count = std::rand() % 100 + 2024;
+ std::string str(count, '\0');
+ for (int i = 0; i < count; ++i) {
+ str[i] = static_cast<char>(std::rand() % 26) + 'a';
+ }
+
+ return str;
+ }
+
+ std::vector<BSONObj> generateRandomLargeDocs(int count) {
+ std::vector<BSONObj> docs;
+ for (int i = 0; i < count; ++i) {
+ docs.emplace_back(BSON("a" << getRandomReadableLargeString()));
+ }
+
+ return docs;
+ }
+
+    // This verifies that a simple explain aggregate command works. Virtual collections are created
+    // even for an explain aggregate command.
+ void verifyExplainAggCommand(DBDirectClient& client, const BSONObj& explainAggCmdObj) {
+ // The first request.
+ BSONObj res;
+ ASSERT_TRUE(client.runCommand(kDatabaseName, explainAggCmdObj.getOwned(), res))
+ << "Expected to succeed but failed. result = {}"_format(res.toString());
+ // Sanity checks of result.
+ ASSERT_EQ(res["ok"].Number(), 1.0)
+ << "Expected to succeed but failed. result = {}"_format(res.toString());
+ }
+
+ ServiceContext::UniqueOperationContext _uniqueOpCtx{makeOperationContext()};
+ OperationContext* _opCtx{_uniqueOpCtx.get()};
+
+ BSONObj explainSingleNamedPipeAggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [],
+ explain: true,
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+ BSONObj explainMultipleNamedPipesAggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [],
+ explain: true,
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [
+ {url: "file://named_pipe1", storageType: "pipe", fileType: "bson"},
+ {url: "file://named_pipe2", storageType: "pipe", fileType: "bson"}
+ ]
+ }]
+}
+ )");
+
+ static constexpr auto kDatabaseName = "external_data_source";
+};
+
+TEST_F(ExternalDataSourceCommandsTest, SimpleScanAggRequest) {
+ const auto nDocs = std::rand() % 100 + 1;
+ std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs);
+ PipeWaiter pw;
+
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter("named_pipe1");
+ pw.notify();
+ pipeWriter.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+
+ // The first request.
+ BSONObj res;
+ ASSERT_TRUE(client.runCommand(kDatabaseName, aggCmdObj.getOwned(), res));
+
+ // Sanity checks of result.
+ ASSERT_EQ(res["ok"].Number(), 1.0);
+ ASSERT_TRUE(res.hasField("cursor") && res["cursor"].Obj().hasField("firstBatch"));
+
+ // The default batch size is 101 and so all data must be contained in the first batch. cursor.id
+ // == 0 means that no cursor is necessary.
+ ASSERT_TRUE(res["cursor"].Obj().hasField("id") && res["cursor"]["id"].Long() == 0);
+ auto resDocs = res["cursor"]["firstBatch"].Array();
+ ASSERT_EQ(resDocs.size(), nDocs);
+ for (int i = 0; i < nDocs; ++i) {
+ ASSERT_BSONOBJ_EQ(resDocs[i].Obj(), srcDocs[i]);
+ }
+
+ // The second request. This verifies that virtual collections are cleaned up after the
+ // aggregation request is done.
+ verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj);
+}
+
+TEST_F(ExternalDataSourceCommandsTest, SimpleScanOverMultipleNamedPipesAggRequest) {
+ // This data set fits into the first batch.
+ const auto nDocs = std::rand() % 50;
+ std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs);
+ PipeWaiter pw;
+
+ // Pushes data into multiple named pipes. We can't push data into multiple named pipes
+ // simultaneously because writers will be blocked until the reader consumes data. So, we push
+ // data into one named pipe after another.
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter1("named_pipe1");
+ NamedPipeOutput pipeWriter2("named_pipe2");
+ pw.notify();
+ pipeWriter1.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter1.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter1.close();
+
+ pipeWriter2.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter2.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter2.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [
+ {url: "file://named_pipe1", storageType: "pipe", fileType: "bson"},
+ {url: "file://named_pipe2", storageType: "pipe", fileType: "bson"}
+ ]
+ }]
+}
+ )");
+
+ // The first request.
+ BSONObj res;
+ ASSERT_TRUE(client.runCommand(kDatabaseName, aggCmdObj.getOwned(), res));
+
+ // Sanity checks of result.
+ ASSERT_EQ(res["ok"].Number(), 1.0);
+ ASSERT_TRUE(res.hasField("cursor") && res["cursor"].Obj().hasField("firstBatch"));
+
+ // The default batch size is 101 and so all data must be contained in the first batch. cursor.id
+ // == 0 means that no cursor is necessary.
+ ASSERT_TRUE(res["cursor"].Obj().hasField("id") && res["cursor"]["id"].Long() == 0);
+ auto resDocs = res["cursor"]["firstBatch"].Array();
+ ASSERT_EQ(resDocs.size(), nDocs * 2);
+ for (int i = 0; i < nDocs; ++i) {
+ ASSERT_BSONOBJ_EQ(resDocs[i].Obj(), srcDocs[i % nDocs]);
+ }
+
+ // The second request. This verifies that virtual collections are cleaned up after the
+ // aggregation request is done.
+ verifyExplainAggCommand(client, explainMultipleNamedPipesAggCmdObj);
+}
+
+TEST_F(ExternalDataSourceCommandsTest, SimpleScanOverLargeObjectsAggRequest) {
+    // MultiBsonStreamCursor's default buffer size is 8K, and 20 documents of at least 2K each are
+    // enough to exceed the initial read. This data set is highly likely to span multiple reads.
+ const auto nDocs = std::rand() % 80 + 20;
+ std::vector<BSONObj> srcDocs = generateRandomLargeDocs(nDocs);
+ PipeWaiter pw;
+
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter("named_pipe1");
+ pw.notify();
+ pipeWriter.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+
+ // The first request.
+ BSONObj res;
+ ASSERT_TRUE(client.runCommand(kDatabaseName, aggCmdObj.getOwned(), res));
+
+ // Sanity checks of result.
+ ASSERT_EQ(res["ok"].Number(), 1.0);
+ ASSERT_TRUE(res.hasField("cursor") && res["cursor"].Obj().hasField("firstBatch"));
+
+ // The default batch size is 101 and so all data must be contained in the first batch. cursor.id
+ // == 0 means that no cursor is necessary.
+ ASSERT_TRUE(res["cursor"].Obj().hasField("id") && res["cursor"]["id"].Long() == 0);
+ auto resDocs = res["cursor"]["firstBatch"].Array();
+ ASSERT_EQ(resDocs.size(), nDocs);
+ for (int i = 0; i < nDocs; ++i) {
+ ASSERT_BSONOBJ_EQ(resDocs[i].Obj(), srcDocs[i]);
+ }
+
+ // The second request. This verifies that virtual collections are cleaned up after the
+ // aggregation request is done.
+ verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj);
+}
+
+// Tests that the 'explain' flag works and that the same aggregation request works again with the
+// same $_externalDataSources, verifying that no virtual collections are left behind after the
+// aggregation request is done.
+TEST_F(ExternalDataSourceCommandsTest, ExplainAggRequest) {
+ DBDirectClient client(_opCtx);
+ // The first request.
+ verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj);
+
+ // The second request.
+ verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj);
+}
+
+TEST_F(ExternalDataSourceCommandsTest, SimpleScanMultiBatchAggRequest) {
+ // This 'nDocs' causes a cursor to be created for a simple scan aggregate command.
+ const auto nDocs = std::rand() % 100 + 102;
+ std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs);
+ PipeWaiter pw;
+
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter("named_pipe1");
+ pw.notify();
+ pipeWriter.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+
+ auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj);
+ ASSERT_OK(swAggReq.getStatus());
+ auto swCursor = DBClientCursor::fromAggregationRequest(
+ &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false);
+ ASSERT_OK(swCursor.getStatus());
+
+ auto cursor = std::move(swCursor.getValue());
+ int resCnt = 0;
+ // While iterating over the cursor, getMore() request(s) will be sent and the server-side cursor
+ // will be destroyed after all data is exhausted.
+ while (cursor->more()) {
+ auto doc = cursor->next();
+ ASSERT_BSONOBJ_EQ(doc, srcDocs[resCnt]);
+ ++resCnt;
+ }
+ ASSERT_EQ(resCnt, nDocs);
+
+    // The second explain request. This verifies that virtual collections are cleaned up after
+    // a multi-batch result for an aggregation request is exhausted.
+ verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj);
+}
+
+TEST_F(ExternalDataSourceCommandsTest, SimpleMatchAggRequest) {
+ const auto nDocs = std::rand() % 100 + 1;
+ std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs);
+ // Expected results for {$match: {a: {$lt: 5}}}.
+ std::vector<BSONObj> expectedDocs;
+ std::for_each(srcDocs.begin(), srcDocs.end(), [&](const BSONObj& doc) {
+ if (doc["a"].Int() < 5) {
+ expectedDocs.emplace_back(doc);
+ }
+ });
+ PipeWaiter pw;
+
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter("named_pipe1");
+ pw.notify();
+ pipeWriter.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [{$match: {a: {$lt: 5}}}],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+
+ auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj);
+ ASSERT_OK(swAggReq.getStatus());
+ auto swCursor = DBClientCursor::fromAggregationRequest(
+ &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false);
+ ASSERT_OK(swCursor.getStatus());
+
+ auto cursor = std::move(swCursor.getValue());
+ int resCnt = 0;
+ while (cursor->more()) {
+ auto doc = cursor->next();
+ ASSERT_BSONOBJ_EQ(doc, expectedDocs[resCnt]);
+ ++resCnt;
+ }
+ ASSERT_EQ(resCnt, expectedDocs.size());
+
+ // The second explain request. This verifies that virtual collections are cleaned up after
+ // the aggregation request is done.
+ verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj);
+}
+
+TEST_F(ExternalDataSourceCommandsTest, ScanOverRandomInvalidDataAggRequest) {
+ const auto nDocs = std::rand() % 100 + 1;
+ std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs);
+ PipeWaiter pw;
+
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter("named_pipe1");
+ pw.notify();
+ const size_t failPoint = std::rand() % nDocs;
+ pipeWriter.open();
+ for (size_t i = 0; i < srcDocs.size(); ++i) {
+ if (i == failPoint) {
+                // Intentionally pushes invalid data at the fail point so that an error happens on
+                // the reader side.
+ pipeWriter.write(srcDocs[i].objdata(), srcDocs[i].objsize() / 2);
+ } else {
+ pipeWriter.write(srcDocs[i].objdata(), srcDocs[i].objsize());
+ }
+ }
+ pipeWriter.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [{$match: {a: {$lt: 5}}}],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+
+ BSONObj res;
+ ASSERT_FALSE(client.runCommand(kDatabaseName, aggCmdObj.getOwned(), res));
+ ASSERT_EQ(res["ok"].Number(), 0.0);
+ // The fail point is randomly chosen and different error codes are expected, depending on the
+ // chosen fail point.
+ ASSERT_NE(ErrorCodes::Error(res["code"].Int()), ErrorCodes::OK);
+
+ // The second explain request. This verifies that virtual collections are cleaned up after
+ // the aggregation request fails.
+ verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj);
+}
+
+TEST_F(ExternalDataSourceCommandsTest, ScanOverRandomInvalidDataAtSecondBatchAggRequest) {
+ // This 'nDocs' causes a cursor to be created for a simple scan aggregate command.
+ const auto nDocs = std::rand() % 100 + 102; // 201 >= nDocs >= 102
+ std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs);
+ PipeWaiter pw;
+
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter("named_pipe1");
+ pw.notify();
+ // The fail point occurs at the second batch.
+ const size_t failPoint = 101 + std::rand() % (nDocs - 101); // 200 >= failPoint >= 101
+ pipeWriter.open();
+ for (size_t i = 0; i < srcDocs.size(); ++i) {
+ if (i == failPoint) {
+                // Intentionally pushes invalid data at the fail point so that an error happens on
+                // the reader side.
+ pipeWriter.write(srcDocs[i].objdata(), srcDocs[i].objsize() / 2);
+ } else {
+ pipeWriter.write(srcDocs[i].objdata(), srcDocs[i].objsize());
+ }
+ }
+ pipeWriter.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+
+ auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj);
+ ASSERT_OK(swAggReq.getStatus());
+ auto swCursor = DBClientCursor::fromAggregationRequest(
+ &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false);
+ ASSERT_OK(swCursor.getStatus());
+
+ auto cursor = std::move(swCursor.getValue());
+ int resCnt = 0;
+ bool errorOccurred = false;
+ try {
+ while (cursor->more()) {
+ auto doc = cursor->next();
+ ASSERT_BSONOBJ_EQ(doc, srcDocs[resCnt]);
+ ++resCnt;
+ }
+ } catch (const DBException& ex) {
+ errorOccurred = true;
+ ASSERT_NE(ex.code(), ErrorCodes::OK);
+ }
+ ASSERT_TRUE(errorOccurred);
+
+ // The second explain request. This verifies that virtual collections are cleaned up after
+ // the getMore request for the aggregation results fails.
+ verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj);
+}
+
+TEST_F(ExternalDataSourceCommandsTest, KillCursorAfterAggRequest) {
+ // This 'nDocs' causes a cursor to be created for a simple scan aggregate command.
+ const auto nDocs = std::rand() % 100 + 102;
+ std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs);
+ PipeWaiter pw;
+
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter("named_pipe1");
+ pw.notify();
+ pipeWriter.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+
+ // The first request.
+ BSONObj res;
+ ASSERT_TRUE(client.runCommand(kDatabaseName, aggCmdObj.getOwned(), res));
+
+ // Sanity checks of result.
+ ASSERT_EQ(res["ok"].Number(), 1.0);
+ ASSERT_TRUE(res.hasField("cursor") && res["cursor"].Obj().hasField("firstBatch"));
+
+ // The default batch size is 101 and results can be returned through multiple batches. cursor.id
+ // != 0 means that a cursor is created.
+ auto cursorId = res["cursor"]["id"].Long();
+ ASSERT_TRUE(res["cursor"].Obj().hasField("id") && cursorId != 0);
+
+ // Kills the cursor.
+ auto killCursorCmdObj = BSON("killCursors"
+ << "coll"
+ << "cursors" << BSON_ARRAY(cursorId));
+ ASSERT_TRUE(client.runCommand(kDatabaseName, killCursorCmdObj.getOwned(), res));
+ ASSERT_EQ(res["ok"].Number(), 1.0);
+ auto cursorsKilled = res["cursorsKilled"].Array();
+ ASSERT_TRUE(cursorsKilled.size() == 1 && cursorsKilled[0].Long() == cursorId);
+
+ // The second explain request. This verifies that virtual collections are cleaned up after
+ // the cursor for the aggregate request is killed.
+ verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj);
+}
+
+TEST_F(ExternalDataSourceCommandsTest, SimpleScanAndUnionWithMultipleSourcesAggRequest) {
+ const auto nDocs = std::rand() % 100 + 1;
+ std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs);
+ PipeWaiter pw;
+
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter1("named_pipe1");
+ pw.notify();
+ pipeWriter1.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter1.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter1.close();
+
+ NamedPipeOutput pipeWriter2("named_pipe2");
+ pipeWriter2.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter2.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter2.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ // An aggregate request with a simple scan and $unionWith stage. $_externalDataSources option
+ // defines multiple data sources.
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll1",
+ pipeline: [{$unionWith: "coll2"}],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll1",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }, {
+ collName: "coll2",
+ dataSources: [{url: "file://named_pipe2", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+
+ auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj);
+ ASSERT_OK(swAggReq.getStatus());
+ auto swCursor = DBClientCursor::fromAggregationRequest(
+ &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false);
+ ASSERT_OK(swCursor.getStatus());
+
+ auto cursor = std::move(swCursor.getValue());
+ int resCnt = 0;
+ while (cursor->more()) {
+ auto doc = cursor->next();
+ // Simple scan from 'coll1' first and then $unionWith from 'coll2'.
+ ASSERT_BSONOBJ_EQ(doc, srcDocs[resCnt % nDocs]);
+ ++resCnt;
+ }
+ ASSERT_EQ(resCnt, nDocs * 2);
+
+ auto explainAggCmdObj = fromjson(R"(
+{
+ aggregate: "coll1",
+ pipeline: [{$unionWith: "coll2"}],
+ explain: true,
+ $_externalDataSources: [{
+ collName: "coll1",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }, {
+ collName: "coll2",
+ dataSources: [{url: "file://named_pipe2", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+
+ // The second explain request. This verifies that virtual collections are cleaned up after
+ // the aggregation request is done.
+ verifyExplainAggCommand(client, explainAggCmdObj);
+}
+
+TEST_F(ExternalDataSourceCommandsTest, GroupAggRequest) {
+ std::vector<BSONObj> srcDocs = {
+ fromjson(R"(
+ {
+ "_id" : 1,
+ "item" : "a",
+ "quantity" : 2
+ })"),
+ fromjson(R"(
+ {
+ "_id" : 2,
+ "item" : "b",
+ "quantity" : 1
+ })"),
+ fromjson(R"(
+ {
+ "_id" : 3,
+ "item" : "a",
+ "quantity" : 5
+ })"),
+ fromjson(R"(
+ {
+ "_id" : 4,
+ "item" : "b",
+ "quantity" : 10
+ })"),
+ fromjson(R"(
+ {
+ "_id" : 5,
+ "item" : "c",
+ "quantity" : 10
+ })"),
+ };
+ PipeWaiter pw;
+
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter("named_pipe1");
+ pw.notify();
+ pipeWriter.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll",
+ pipeline: [{$group: {_id: "$item", o: {$sum: "$quantity"}}}],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+ std::vector<BSONObj> expectedRes = {
+ fromjson(R"(
+ {
+ "_id" : "a",
+ "o" : 7
+ })"),
+ fromjson(R"(
+ {
+ "_id" : "b",
+ "o" : 11
+ })"),
+ fromjson(R"(
+ {
+ "_id" : "c",
+ "o" : 10
+ })"),
+ };
+
+ auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj);
+ ASSERT_OK(swAggReq.getStatus());
+ auto swCursor = DBClientCursor::fromAggregationRequest(
+ &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false);
+ ASSERT_OK(swCursor.getStatus());
+
+ auto cursor = std::move(swCursor.getValue());
+ int resCnt = 0;
+ while (cursor->more()) {
+ auto doc = cursor->next();
+        // The result set is pretty small, so we use a linear search of the vector.
+ ASSERT_TRUE(
+ std::find_if(expectedRes.begin(), expectedRes.end(), [&](const BSONObj& expectedObj) {
+ return expectedObj.objsize() == doc.objsize() &&
+ std::memcmp(expectedObj.objdata(), doc.objdata(), expectedObj.objsize()) == 0;
+ }) != expectedRes.end());
+ ++resCnt;
+ }
+ ASSERT_EQ(resCnt, expectedRes.size());
+
+ // The second explain request. This verifies that virtual collections are cleaned up after
+ // the aggregation request is done.
+ verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj);
+}
+
+TEST_F(ExternalDataSourceCommandsTest, LookupAggRequest) {
+ std::vector<BSONObj> srcDocs = {
+ fromjson(R"(
+ {
+ "a" : 1,
+ "data" : "abcd"
+ })"),
+ fromjson(R"(
+ {
+ "a" : 2,
+ "data" : "efgh"
+ })"),
+ fromjson(R"(
+ {
+ "a" : 3,
+ "data" : "ijkl"
+ })"),
+ };
+ PipeWaiter pw;
+
+    // For the $lookup stage, data must be available in both named pipes simultaneously because
+    // $lookup reads data from both collections, so we use two different named pipes and push
+    // data into the inner side first. To avoid a race condition, notify the reader side after
+    // both named pipes are created. This order is geared toward the hash join algorithm.
+ stdx::thread producer([&] {
+ NamedPipeOutput pipeWriter2("named_pipe2");
+ NamedPipeOutput pipeWriter1("named_pipe1");
+ pw.notify();
+
+ // Pushes data into the inner side (== coll2 with named_pipe2) first because the hash join
+ // builds the inner (or build) side first.
+ pipeWriter2.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter2.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter2.close();
+
+ pipeWriter1.open();
+ for (auto&& srcDoc : srcDocs) {
+ pipeWriter1.write(srcDoc.objdata(), srcDoc.objsize());
+ }
+ pipeWriter1.close();
+ });
+ ON_BLOCK_EXIT([&] { producer.join(); });
+
+ // Gives some time to the producer so that it can initialize a named pipe.
+ pw.wait();
+
+ DBDirectClient client(_opCtx);
+ auto aggCmdObj = fromjson(R"(
+{
+ aggregate: "coll1",
+ pipeline: [{$lookup: {from: "coll2", localField: "a", foreignField: "a", as: "o"}}],
+ cursor: {},
+ $_externalDataSources: [{
+ collName: "coll1",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }, {
+ collName: "coll2",
+ dataSources: [{url: "file://named_pipe2", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+ std::vector<BSONObj> expectedRes = {
+ fromjson(R"(
+ {
+ "a" : 1,
+ "data" : "abcd",
+ "o" : [{"a": 1, "data": "abcd"}]
+ })"),
+ fromjson(R"(
+ {
+ "a" : 2,
+ "data" : "efgh",
+ "o" : [{"a": 2, "data": "efgh"}]
+ })"),
+ fromjson(R"(
+ {
+ "a" : 3,
+ "data" : "ijkl",
+ "o" : [{"a": 3, "data": "ijkl"}]
+ })"),
+ };
+
+ auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj);
+ ASSERT_OK(swAggReq.getStatus());
+ auto swCursor = DBClientCursor::fromAggregationRequest(
+ &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false);
+ ASSERT_OK(swCursor.getStatus());
+
+ auto cursor = std::move(swCursor.getValue());
+ int resCnt = 0;
+ while (cursor->more()) {
+ auto doc = cursor->next();
+        // The result set is pretty small, so we use a linear search of the vector.
+ ASSERT_TRUE(
+ std::find_if(expectedRes.begin(), expectedRes.end(), [&](const BSONObj& expectedObj) {
+ return expectedObj.objsize() == doc.objsize() &&
+ std::memcmp(expectedObj.objdata(), doc.objdata(), expectedObj.objsize()) == 0;
+ }) != expectedRes.end());
+ ++resCnt;
+ }
+ ASSERT_EQ(resCnt, expectedRes.size());
+
+ auto explainAggCmdObj = fromjson(R"(
+{
+ aggregate: "coll1",
+ pipeline: [{$lookup: {from: "coll2", localField: "a", foreignField: "a", as: "o"}}],
+ explain: true,
+ $_externalDataSources: [{
+ collName: "coll1",
+ dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}]
+ }, {
+ collName: "coll2",
+ dataSources: [{url: "file://named_pipe2", storageType: "pipe", fileType: "bson"}]
+ }]
+}
+ )");
+
+ // The second explain request. This verifies that virtual collections are cleaned up after
+ // the aggregation request is done.
+ verifyExplainAggCommand(client, explainAggCmdObj);
+}
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/db/commands/external_data_source_scope_guard.cpp b/src/mongo/db/commands/external_data_source_scope_guard.cpp
new file mode 100644
index 00000000000..f9a29fcba67
--- /dev/null
+++ b/src/mongo/db/commands/external_data_source_scope_guard.cpp
@@ -0,0 +1,87 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/commands/external_data_source_scope_guard.h"
+
+#include "mongo/db/catalog/create_collection.h"
+#include "mongo/db/catalog/drop_collection.h"
+#include "mongo/db/catalog/virtual_collection_options.h"
+#include "mongo/db/drop_gen.h"
+#include "mongo/db/namespace_string.h"
+#include "mongo/db/pipeline/external_data_source_option_gen.h"
+#include "mongo/logv2/log.h"
+#include "mongo/util/destructor_guard.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery
+
+namespace mongo {
+ExternalDataSourceScopeGuard::ExternalDataSourceScopeGuard(
+ OperationContext* opCtx,
+ const std::vector<std::pair<NamespaceString, std::vector<ExternalDataSourceInfo>>>&
+ usedExternalDataSources)
+ : _opCtx(opCtx) {
+    // In case any virtual collection could not be created and the dtor does not get a chance
+    // to run, clean up the collections that have already been created at that
+    // point.
+ ScopeGuard dropVcollGuard([&] { dropVirtualCollections(); });
+
+ for (auto&& [extDataSourceNss, dataSources] : usedExternalDataSources) {
+ VirtualCollectionOptions vopts(dataSources);
+ uassertStatusOK(createVirtualCollection(opCtx, extDataSourceNss, vopts));
+ _toBeDroppedVirtualCollections.emplace_back(extDataSourceNss);
+ }
+
+ dropVcollGuard.dismiss();
+}
+
+void ExternalDataSourceScopeGuard::dropVirtualCollections() noexcept {
+    // The move constructor sets '_opCtx' to null when ownership is moved to the other object, which
+ // means this object must not try to drop collections. There's nothing to drop if '_opCtx' is
+ // null.
+ if (!_opCtx) {
+ return;
+ }
+
+    // This function is called in the context of a destructor or an exception, so guard it against
+    // any exceptions.
+ DESTRUCTOR_GUARD({
+ for (auto&& nss : _toBeDroppedVirtualCollections) {
+ DropReply reply;
+ auto status =
+ dropCollection(_opCtx,
+ nss,
+ &reply,
+ DropCollectionSystemCollectionMode::kDisallowSystemCollectionDrops);
+ if (!status.isOK()) {
+ LOGV2_ERROR(6968700, "Failed to drop an external data source", "coll"_attr = nss);
+ }
+ }
+ });
+}
+} // namespace mongo
diff --git a/src/mongo/db/commands/external_data_source_scope_guard.h b/src/mongo/db/commands/external_data_source_scope_guard.h
new file mode 100644
index 00000000000..21b48890432
--- /dev/null
+++ b/src/mongo/db/commands/external_data_source_scope_guard.h
@@ -0,0 +1,84 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/clientcursor.h"
+#include "mongo/db/namespace_string.h"
+#include "mongo/db/operation_context.h"
+#include "mongo/db/pipeline/external_data_source_option_gen.h"
+
+namespace mongo {
+/**
+ * This class makes sure that virtual collections that are created for external data sources are
+ * dropped when it's destroyed.
+ */
+class ExternalDataSourceScopeGuard {
+public:
+ // Makes ExternalDataSourceScopeGuard a decoration of ClientCursor.
+ static const ClientCursor::Decoration<std::shared_ptr<ExternalDataSourceScopeGuard>> get;
+
+    // Updates the operation context of the decorated ExternalDataSourceScopeGuard object of
+    // 'cursor' so that it can drop virtual collections in the new 'opCtx'.
+ static void updateOperationContext(const ClientCursor* cursor, OperationContext* opCtx) {
+ if (auto self = get(cursor); self) {
+            self->_opCtx = opCtx;
+ }
+ }
+
+ ExternalDataSourceScopeGuard() : _opCtx(nullptr), _toBeDroppedVirtualCollections() {}
+
+ ExternalDataSourceScopeGuard(
+ OperationContext* opCtx,
+ const std::vector<std::pair<NamespaceString, std::vector<ExternalDataSourceInfo>>>&
+ usedExternalDataSources);
+
+    // It does not make sense to support a copy ctor because this object must drop created virtual
+ // collections.
+ ExternalDataSourceScopeGuard(const ExternalDataSourceScopeGuard&) = delete;
+
+ ExternalDataSourceScopeGuard(ExternalDataSourceScopeGuard&& other) noexcept
+ : _opCtx(other._opCtx),
+ _toBeDroppedVirtualCollections(std::move(other._toBeDroppedVirtualCollections)) {
+        // Ownership of the created virtual collections is moved to this object and the other
+        // object must not try to drop them anymore.
+ other._opCtx = nullptr;
+ }
+
+ ~ExternalDataSourceScopeGuard() {
+ dropVirtualCollections();
+ }
+
+private:
+ void dropVirtualCollections() noexcept;
+
+ OperationContext* _opCtx;
+ std::vector<NamespaceString> _toBeDroppedVirtualCollections;
+};
+} // namespace mongo
diff --git a/src/mongo/db/commands/fail_point_cmd.cpp b/src/mongo/db/commands/fail_point_cmd.cpp
index 7e868735f5b..a067522d6fe 100644
--- a/src/mongo/db/commands/fail_point_cmd.cpp
+++ b/src/mongo/db/commands/fail_point_cmd.cpp
@@ -93,6 +93,10 @@ public:
return Status::OK();
}
+ bool allowedWithSecurityToken() const final {
+ return true;
+ }
+
std::string help() const override {
return "modifies the settings of a fail point";
}
diff --git a/src/mongo/db/commands/find_and_modify.cpp b/src/mongo/db/commands/find_and_modify.cpp
index 1ba9efa0a75..ec70924da40 100644
--- a/src/mongo/db/commands/find_and_modify.cpp
+++ b/src/mongo/db/commands/find_and_modify.cpp
@@ -60,6 +60,7 @@
#include "mongo/db/repl/replication_coordinator.h"
#include "mongo/db/s/collection_sharding_state.h"
#include "mongo/db/s/operation_sharding_state.h"
+#include "mongo/db/s/query_analysis_writer.h"
#include "mongo/db/stats/counters.h"
#include "mongo/db/stats/resource_consumption_metrics.h"
#include "mongo/db/stats/top.h"
@@ -76,6 +77,7 @@
namespace mongo {
namespace {
+MONGO_FAIL_POINT_DEFINE(failAllFindAndModify);
MONGO_FAIL_POINT_DEFINE(hangBeforeFindAndModifyPerformsUpdate);
/**
@@ -682,6 +684,16 @@ write_ops::FindAndModifyCommandReply CmdFindAndModify::Invocation::typedRun(
}
}
+ if (analyze_shard_key::supportsPersistingSampledQueries() && request().getSampleId()) {
+ analyze_shard_key::QueryAnalysisWriter::get(opCtx)
+ .addFindAndModifyQuery(request())
+ .getAsync([](auto) {});
+ }
+
+ if (MONGO_unlikely(failAllFindAndModify.shouldFail())) {
+ uasserted(ErrorCodes::InternalError, "failAllFindAndModify failpoint active!");
+ }
+
const bool inTransaction = opCtx->inMultiDocumentTransaction();
// Although usually the PlanExecutor handles WCE internally, it will throw WCEs when it
@@ -711,6 +723,7 @@ write_ops::FindAndModifyCommandReply CmdFindAndModify::Invocation::typedRun(
if (opCtx->getTxnNumber()) {
updateRequest.setStmtIds({stmtId});
}
+ updateRequest.setSampleId(req.getSampleId());
const ExtensionsCallbackReal extensionsCallback(
opCtx, &updateRequest.getNamespaceString());
diff --git a/src/mongo/db/commands/find_cmd.cpp b/src/mongo/db/commands/find_cmd.cpp
index d048a2ec616..1188981c991 100644
--- a/src/mongo/db/commands/find_cmd.cpp
+++ b/src/mongo/db/commands/find_cmd.cpp
@@ -56,6 +56,7 @@
#include "mongo/db/query/get_executor.h"
#include "mongo/db/query/query_knobs_gen.h"
#include "mongo/db/repl/replication_coordinator.h"
+#include "mongo/db/s/query_analysis_writer.h"
#include "mongo/db/service_context.h"
#include "mongo/db/stats/counters.h"
#include "mongo/db/stats/resource_consumption_metrics.h"
@@ -65,6 +66,7 @@
#include "mongo/db/transaction/transaction_participant.h"
#include "mongo/logv2/log.h"
#include "mongo/rpc/get_status_from_command_result.h"
+#include "mongo/util/assert_util.h"
#include "mongo/util/database_name_util.h"
#include "mongo/util/fail_point.h"
@@ -362,6 +364,8 @@ public:
// execution tree with an EOFStage.
const auto& collection = ctx->getCollection();
+ cq->setUseCqfIfEligible(true);
+
// Get the execution plan for the query.
bool permitYield = true;
auto exec =
@@ -433,10 +437,10 @@ public:
// The presence of a term in the request indicates that this is an internal replication
// oplog read request.
if (term && isOplogNss) {
- // We do not want to take tickets for internal (replication) oplog reads. Stalling
- // on ticket acquisition can cause complicated deadlocks. Primaries may depend on
- // data reaching secondaries in order to proceed; and secondaries may get stalled
- // replicating because of an inability to acquire a read ticket.
+ // We do not want to wait to take tickets for internal (replication) oplog reads.
+ // Stalling on ticket acquisition can cause complicated deadlocks. Primaries may
+ // depend on data reaching secondaries in order to proceed; and secondaries may get
+ // stalled replicating because of an inability to acquire a read ticket.
opCtx->lockState()->setAdmissionPriority(AdmissionContext::Priority::kImmediate);
}
@@ -481,6 +485,16 @@ public:
.expectedUUID(findCommand->getCollectionUUID()));
const auto& nss = ctx->getNss();
+ if (analyze_shard_key::supportsPersistingSampledQueries() &&
+ findCommand->getSampleId()) {
+ analyze_shard_key::QueryAnalysisWriter::get(opCtx)
+ .addFindQuery(*findCommand->getSampleId(),
+ nss,
+ findCommand->getFilter(),
+ findCommand->getCollation())
+ .getAsync([](auto) {});
+ }
+
// Going forward this operation must never ignore interrupt signals while waiting for
// lock acquisition. This InterruptibleLockGuard will ensure that waiting for lock
// re-acquisition after yielding will not ignore interrupt signals. This is necessary to
@@ -542,10 +556,10 @@ public:
const auto& findCommand = cq->getFindCommandRequest();
auto viewAggregationCommand =
uassertStatusOK(query_request_helper::asAggregationCommand(findCommand));
-
- BSONObj aggResult = CommandHelpers::runCommandDirectly(
- opCtx,
- OpMsgRequest::fromDBAndBody(_dbName.db(), std::move(viewAggregationCommand)));
+ auto aggRequest =
+ OpMsgRequestBuilder::create(_dbName, std::move(viewAggregationCommand));
+ aggRequest.validatedTenancyScope = _request.validatedTenancyScope;
+ BSONObj aggResult = CommandHelpers::runCommandDirectly(opCtx, aggRequest);
auto status = getStatusFromCommandResult(aggResult);
if (status.code() == ErrorCodes::InvalidPipelineOperator) {
uasserted(ErrorCodes::InvalidPipelineOperator,
@@ -572,6 +586,8 @@ public:
opCtx->recoveryUnit()->setReadOnce(true);
}
+ cq->setUseCqfIfEligible(true);
+
// Get the execution plan for the query.
bool permitYield = true;
auto exec =
diff --git a/src/mongo/db/commands/generic.cpp b/src/mongo/db/commands/generic.cpp
index c5f8dde7e74..9bc1717f6e8 100644
--- a/src/mongo/db/commands/generic.cpp
+++ b/src/mongo/db/commands/generic.cpp
@@ -173,6 +173,10 @@ public:
return false;
}
+ bool allowedWithSecurityToken() const final {
+ return true;
+ }
+
bool run(OperationContext* opCtx,
const DatabaseName&,
const BSONObj& cmdObj,
diff --git a/src/mongo/db/commands/getmore_cmd.cpp b/src/mongo/db/commands/getmore_cmd.cpp
index 3d6c73db66f..09aa73cc315 100644
--- a/src/mongo/db/commands/getmore_cmd.cpp
+++ b/src/mongo/db/commands/getmore_cmd.cpp
@@ -40,6 +40,7 @@
#include "mongo/db/client.h"
#include "mongo/db/clientcursor.h"
#include "mongo/db/commands.h"
+#include "mongo/db/commands/external_data_source_scope_guard.h"
#include "mongo/db/curop.h"
#include "mongo/db/curop_failpoint_helpers.h"
#include "mongo/db/cursor_manager.h"
@@ -485,6 +486,10 @@ public:
setUpOperationContextStateForGetMore(
opCtx, *cursorPin.getCursor(), _cmd, disableAwaitDataFailpointActive);
+ // Update opCtx of the decorated ExternalDataSourceScopeGuard object so that it can drop
+ // virtual collections in the new 'opCtx'.
+ ExternalDataSourceScopeGuard::updateOperationContext(cursorPin.getCursor(), opCtx);
+
// On early return, typically due to a failed assertion, delete the cursor.
ScopeGuard cursorDeleter([&] { cursorPin.deleteUnderlying(); });
@@ -717,10 +722,10 @@ public:
// internal clients (see checkAuthForGetMore).
curOp->debug().isReplOplogGetMore = true;
- // We do not want to take tickets for internal (replication) oplog reads. Stalling
- // on ticket acquisition can cause complicated deadlocks. Primaries may depend on
- // data reaching secondaries in order to proceed; and secondaries may get stalled
- // replicating because of an inability to acquire a read ticket.
+ // We do not want to wait to take tickets for internal (replication) oplog reads.
+ // Stalling on ticket acquisition can cause complicated deadlocks. Primaries may
+ // depend on data reaching secondaries in order to proceed; and secondaries may get
+ // stalled replicating because of an inability to acquire a read ticket.
opCtx->lockState()->setAdmissionPriority(AdmissionContext::Priority::kImmediate);
}
diff --git a/src/mongo/db/commands/parameters.cpp b/src/mongo/db/commands/parameters.cpp
index 0cfb0f6dcf2..0829fcd0e32 100644
--- a/src/mongo/db/commands/parameters.cpp
+++ b/src/mongo/db/commands/parameters.cpp
@@ -238,6 +238,11 @@ public:
h += "{ getParameter:'*' } or { getParameter:{allParameters: true} } to get everything\n";
return h;
}
+
+ bool allowedWithSecurityToken() const final {
+ return true;
+ }
+
bool run(OperationContext* opCtx,
const DatabaseName& dbName,
const BSONObj& cmdObj,
diff --git a/src/mongo/db/commands/pipeline_command.cpp b/src/mongo/db/commands/pipeline_command.cpp
index b2fb0341568..13037602698 100644
--- a/src/mongo/db/commands/pipeline_command.cpp
+++ b/src/mongo/db/commands/pipeline_command.cpp
@@ -35,8 +35,8 @@
#include "mongo/db/auth/authorization_checks.h"
#include "mongo/db/auth/authorization_session.h"
#include "mongo/db/catalog/create_collection.h"
-#include "mongo/db/catalog/virtual_collection_options.h"
#include "mongo/db/commands.h"
+#include "mongo/db/commands/external_data_source_scope_guard.h"
#include "mongo/db/commands/run_aggregate.h"
#include "mongo/db/namespace_string.h"
#include "mongo/db/pipeline/aggregate_command_gen.h"
@@ -46,7 +46,7 @@
#include "mongo/db/pipeline/pipeline.h"
#include "mongo/db/query/query_knobs_gen.h"
#include "mongo/idl/idl_parser.h"
-#include "mongo/stdx/unordered_set.h"
+#include "mongo/util/assert_util.h"
#include "mongo/util/database_name_util.h"
namespace mongo {
@@ -202,18 +202,20 @@ public:
CommandHelpers::handleMarkKillOnClientDisconnect(
opCtx, !Pipeline::aggHasWriteStage(_request.body));
- // TODO SERVER-69687 Create a virtual collection per each used external data source.
- for (auto&& [collName, dataSources] : _usedExternalDataSources) {
- VirtualCollectionOptions vopts(dataSources);
- }
-
+ // Create virtual collections and drop them when the aggregate command is done. Conceptually,
+ // ownership of the virtual collections is moved to the runAggregate() function together with
+ // 'dropVcollGuard' so that it can clean up the virtual collections when it's done with
+ // them. ExternalDataSourceScopeGuard also takes care of the situation where any
+ // collection could not be created.
+ ExternalDataSourceScopeGuard dropVcollGuard(opCtx, _usedExternalDataSources);
uassertStatusOK(runAggregate(opCtx,
_aggregationRequest.getNamespace(),
_aggregationRequest,
_liteParsedPipeline,
_request.body,
_privileges,
- reply));
+ reply,
+ std::move(dropVcollGuard)));
// The aggregate command's response is unstable when 'explain' or 'exchange' fields are
// set.
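The ownership transfer described in the comment above follows a common C++ move-into-callee pattern: the command constructs the guard, moves it into the function that runs the aggregation, and the callee's scope decides when the virtual collections get dropped. A minimal standalone sketch (hypothetical ScopeGuard/runAggregateLike names, not the MongoDB types) might look like this:

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    // Hypothetical stand-in for a guard that drops virtual collections when destroyed.
    class ScopeGuard {
    public:
        explicit ScopeGuard(std::vector<std::string> collections)
            : _collections(std::move(collections)) {}

        ScopeGuard(ScopeGuard&& other) noexcept : _collections(std::move(other._collections)) {
            other._collections.clear();  // The moved-from guard no longer owns anything.
        }

        ~ScopeGuard() {
            for (const auto& coll : _collections)
                std::cout << "dropping virtual collection " << coll << "\n";
        }

    private:
        std::vector<std::string> _collections;
    };

    // Stand-in for runAggregate(): it receives the guard by value, so cleanup happens when the
    // callee, not the caller, is done with the virtual collections.
    void runAggregateLike(ScopeGuard guard) {
        std::cout << "running aggregation over virtual collections\n";
    }  // guard destroyed here, collections dropped

    int main() {
        ScopeGuard dropVcollGuard({"$$external1", "$$external2"});
        runAggregateLike(std::move(dropVcollGuard));  // ownership moves with the call
        std::cout << "command finished\n";
    }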
@@ -231,14 +233,16 @@ public:
void explain(OperationContext* opCtx,
ExplainOptions::Verbosity verbosity,
rpc::ReplyBuilderInterface* result) override {
-
+ // See run() method for details.
+ ExternalDataSourceScopeGuard dropVcollGuard(opCtx, _usedExternalDataSources);
uassertStatusOK(runAggregate(opCtx,
_aggregationRequest.getNamespace(),
_aggregationRequest,
_liteParsedPipeline,
_request.body,
_privileges,
- result));
+ result,
+ std::move(dropVcollGuard)));
}
void doCheckAuthorization(OperationContext* opCtx) const override {
diff --git a/src/mongo/db/commands/profile_common.cpp b/src/mongo/db/commands/profile_common.cpp
index 54223b8f5a7..ba546e62067 100644
--- a/src/mongo/db/commands/profile_common.cpp
+++ b/src/mongo/db/commands/profile_common.cpp
@@ -133,6 +133,7 @@ bool ProfileCmdBase::run(OperationContext* opCtx,
newState.append("filter"_sd, newSettings.filter->serialize());
}
attrs.add("to", newState.obj());
+ attrs.add("db", dbName.db());
LOGV2(48742, "Profiler settings changed", attrs);
}
diff --git a/src/mongo/db/commands/run_aggregate.cpp b/src/mongo/db/commands/run_aggregate.cpp
index eb32a6c220b..cb15cf82247 100644
--- a/src/mongo/db/commands/run_aggregate.cpp
+++ b/src/mongo/db/commands/run_aggregate.cpp
@@ -42,6 +42,7 @@
#include "mongo/db/change_stream_change_collection_manager.h"
#include "mongo/db/change_stream_pre_images_collection_manager.h"
#include "mongo/db/change_stream_serverless_helpers.h"
+#include "mongo/db/commands/external_data_source_scope_guard.h"
#include "mongo/db/curop.h"
#include "mongo/db/cursor_manager.h"
#include "mongo/db/db_raii.h"
@@ -81,6 +82,7 @@
#include "mongo/db/repl/speculative_majority_read_info.h"
#include "mongo/db/s/collection_sharding_state.h"
#include "mongo/db/s/operation_sharding_state.h"
+#include "mongo/db/s/query_analysis_writer.h"
#include "mongo/db/s/sharding_state.h"
#include "mongo/db/service_context.h"
#include "mongo/db/stats/resource_consumption_metrics.h"
@@ -659,7 +661,7 @@ Status runAggregate(OperationContext* opCtx,
const BSONObj& cmdObj,
const PrivilegeVector& privileges,
rpc::ReplyBuilderInterface* result) {
- return runAggregate(opCtx, nss, request, {request}, cmdObj, privileges, result);
+ return runAggregate(opCtx, nss, request, {request}, cmdObj, privileges, result, {});
}
Status runAggregate(OperationContext* opCtx,
@@ -668,7 +670,8 @@ Status runAggregate(OperationContext* opCtx,
const LiteParsedPipeline& liteParsedPipeline,
const BSONObj& cmdObj,
const PrivilegeVector& privileges,
- rpc::ReplyBuilderInterface* result) {
+ rpc::ReplyBuilderInterface* result,
+ ExternalDataSourceScopeGuard externalDataSourceGuard) {
// Perform some validations on the LiteParsedPipeline and request before continuing with the
// aggregation command.
@@ -945,6 +948,15 @@ Status runAggregate(OperationContext* opCtx,
}
}
+ if (analyze_shard_key::supportsPersistingSampledQueries() && request.getSampleId()) {
+ analyze_shard_key::QueryAnalysisWriter::get(opCtx)
+ .addAggregateQuery(*request.getSampleId(),
+ expCtx->ns,
+ pipeline->getInitialQuery(),
+ expCtx->getCollatorBSON())
+ .getAsync([](auto) {});
+ }
+
// If the aggregate command supports encrypted collections, do rewrites of the pipeline to
// support querying against encrypted fields.
if (shouldDoFLERewrite(request)) {
@@ -1021,6 +1033,8 @@ Status runAggregate(OperationContext* opCtx,
p.deleteUnderlying();
}
});
+ auto extDataSrcGuard =
+ std::make_shared<ExternalDataSourceScopeGuard>(std::move(externalDataSourceGuard));
for (auto&& exec : execs) {
ClientCursorParams cursorParams(
std::move(exec),
@@ -1038,6 +1052,10 @@ Status runAggregate(OperationContext* opCtx,
pin->incNBatches();
cursors.emplace_back(pin.getCursor());
+ // All cursors share ownership of 'extDataSrcGuard', and when the last cursor is destroyed,
+ // 'extDataSrcGuard' is destroyed as well and the created virtual collections are dropped by
+ // the destructor of ExternalDataSourceScopeGuard.
+ ExternalDataSourceScopeGuard::get(pin.getCursor()) = extDataSrcGuard;
pins.emplace_back(std::move(pin));
}
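Handing the guard to every cursor via a shared_ptr means the drop happens exactly once, when the last co-owning cursor goes away, regardless of which getMore or cleanup path destroys it. A minimal sketch of that lifetime pattern, with hypothetical ExtDataSourceGuard/Cursor types standing in for the real decorated cursor:

    #include <iostream>
    #include <memory>
    #include <vector>

    // Hypothetical guard whose destructor drops the virtual collections.
    struct ExtDataSourceGuard {
        ~ExtDataSourceGuard() { std::cout << "last owner gone: dropping virtual collections\n"; }
    };

    // Hypothetical cursor that keeps the guard alive for as long as it exists.
    struct Cursor {
        std::shared_ptr<ExtDataSourceGuard> guard;
    };

    int main() {
        auto guard = std::make_shared<ExtDataSourceGuard>();
        std::vector<Cursor> cursors{{guard}, {guard}};  // every cursor co-owns the guard
        guard.reset();                                  // the local reference is no longer needed

        cursors.pop_back();  // guard still alive: one cursor remains
        std::cout << "one cursor left\n";
        cursors.pop_back();  // destroys the last owner, which drops the collections
    }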
diff --git a/src/mongo/db/commands/run_aggregate.h b/src/mongo/db/commands/run_aggregate.h
index 0bb86ac91b0..33c03e80cdb 100644
--- a/src/mongo/db/commands/run_aggregate.h
+++ b/src/mongo/db/commands/run_aggregate.h
@@ -32,6 +32,7 @@
#include "mongo/bson/bsonobj.h"
#include "mongo/bson/bsonobjbuilder.h"
#include "mongo/db/auth/privilege.h"
+#include "mongo/db/commands/external_data_source_scope_guard.h"
#include "mongo/db/namespace_string.h"
#include "mongo/db/operation_context.h"
#include "mongo/db/pipeline/aggregate_command_gen.h"
@@ -57,7 +58,8 @@ Status runAggregate(OperationContext* opCtx,
const LiteParsedPipeline& liteParsedPipeline,
const BSONObj& cmdObj,
const PrivilegeVector& privileges,
- rpc::ReplyBuilderInterface* result);
+ rpc::ReplyBuilderInterface* result,
+ ExternalDataSourceScopeGuard externalDataSourceGuard);
/**
* Convenience version that internally constructs the LiteParsedPipeline.
diff --git a/src/mongo/db/commands/server_status_command.cpp b/src/mongo/db/commands/server_status_command.cpp
index 62433e57e62..521e4a49010 100644
--- a/src/mongo/db/commands/server_status_command.cpp
+++ b/src/mongo/db/commands/server_status_command.cpp
@@ -78,6 +78,10 @@ public:
return Status::OK();
}
+ bool allowedWithSecurityToken() const final {
+ return true;
+ }
+
bool run(OperationContext* opCtx,
const DatabaseName& dbName,
const BSONObj& cmdObj,
diff --git a/src/mongo/db/commands/tenant_migration_recipient_cmds.idl b/src/mongo/db/commands/tenant_migration_recipient_cmds.idl
index 96bfcd775cd..c8f58df0c0c 100644
--- a/src/mongo/db/commands/tenant_migration_recipient_cmds.idl
+++ b/src/mongo/db/commands/tenant_migration_recipient_cmds.idl
@@ -97,6 +97,7 @@ commands:
strict: true
namespace: ignored
api_version: ""
+ reply_type: recipientSyncDataResponse
inline_chained_structs: true
chained_structs:
MigrationRecipientCommonData: MigrationRecipientCommonData
@@ -144,6 +145,7 @@ commands:
recipientForgetMigration:
description: "Parser for the 'recipientForgetMigration' command."
command_name: recipientForgetMigration
+ reply_type: OkReply
strict: true
namespace: ignored
api_version: ""
diff --git a/src/mongo/db/concurrency/lock_state.h b/src/mongo/db/concurrency/lock_state.h
index 9d37c7e0c50..9dd80b25110 100644
--- a/src/mongo/db/concurrency/lock_state.h
+++ b/src/mongo/db/concurrency/lock_state.h
@@ -421,16 +421,15 @@ public:
};
/**
- * RAII-style class to set the priority for the ticket acquisition mechanism when acquiring a global
+ * RAII-style class to set the priority for the ticket admission mechanism when acquiring a global
* lock.
*/
-class SetTicketAquisitionPriorityForLock {
+class SetAdmissionPriorityForLock {
public:
- SetTicketAquisitionPriorityForLock(const SetTicketAquisitionPriorityForLock&) = delete;
- SetTicketAquisitionPriorityForLock& operator=(const SetTicketAquisitionPriorityForLock&) =
- delete;
- explicit SetTicketAquisitionPriorityForLock(OperationContext* opCtx,
- AdmissionContext::Priority priority)
+ SetAdmissionPriorityForLock(const SetAdmissionPriorityForLock&) = delete;
+ SetAdmissionPriorityForLock& operator=(const SetAdmissionPriorityForLock&) = delete;
+ explicit SetAdmissionPriorityForLock(OperationContext* opCtx,
+ AdmissionContext::Priority priority)
: _opCtx(opCtx), _originalPriority(opCtx->lockState()->getAdmissionPriority()) {
uassert(ErrorCodes::IllegalOperation,
"It is illegal for an operation to demote a high priority to a lower priority "
@@ -440,7 +439,7 @@ public:
_opCtx->lockState()->setAdmissionPriority(priority);
}
- ~SetTicketAquisitionPriorityForLock() {
+ ~SetAdmissionPriorityForLock() {
_opCtx->lockState()->setAdmissionPriority(_originalPriority);
}
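The renamed SetAdmissionPriorityForLock keeps the same RAII shape: remember the current priority, install the new one, and restore the original in the destructor so the override cannot leak past its scope. A rough standalone approximation, where ScopedAdmissionPriority, LockState and Priority are invented stand-ins rather than the real classes:

    #include <cassert>
    #include <iostream>

    enum class Priority { kLow, kNormal, kImmediate };

    // Hypothetical lock-state holder; the real class reads and writes the locker on the
    // OperationContext instead.
    struct LockState {
        Priority priority = Priority::kNormal;
    };

    class ScopedAdmissionPriority {
    public:
        ScopedAdmissionPriority(LockState* lockState, Priority priority)
            : _lockState(lockState), _original(lockState->priority) {
            // Mirrors the uassert: a high-priority operation must not be demoted.
            assert(_original != Priority::kImmediate || priority == Priority::kImmediate);
            _lockState->priority = priority;
        }
        ScopedAdmissionPriority(const ScopedAdmissionPriority&) = delete;
        ScopedAdmissionPriority& operator=(const ScopedAdmissionPriority&) = delete;
        ~ScopedAdmissionPriority() {
            _lockState->priority = _original;  // restored even on early return or exception
        }

    private:
        LockState* const _lockState;
        const Priority _original;
    };

    int main() {
        LockState lockState;
        {
            ScopedAdmissionPriority low(&lockState, Priority::kLow);
            std::cout << "priority lowered for the index-build scan\n";
        }
        std::cout << "priority restored: " << (lockState.priority == Priority::kNormal) << "\n";
    }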
diff --git a/src/mongo/db/concurrency/lock_state_test.cpp b/src/mongo/db/concurrency/lock_state_test.cpp
index 607c2c076f7..0062fe990ae 100644
--- a/src/mongo/db/concurrency/lock_state_test.cpp
+++ b/src/mongo/db/concurrency/lock_state_test.cpp
@@ -1244,20 +1244,19 @@ TEST_F(LockerImplTest, SetTicketAcquisitionForLockRAIIType) {
ASSERT_TRUE(opCtx->lockState()->shouldWaitForTicket());
{
- SetTicketAquisitionPriorityForLock setTicketAquisition(
- opCtx.get(), AdmissionContext::Priority::kImmediate);
+ SetAdmissionPriorityForLock setTicketAquisition(opCtx.get(),
+ AdmissionContext::Priority::kImmediate);
ASSERT_FALSE(opCtx->lockState()->shouldWaitForTicket());
}
ASSERT_TRUE(opCtx->lockState()->shouldWaitForTicket());
- // If ticket acquisitions are disabled on the lock state, the RAII type has no effect.
opCtx->lockState()->setAdmissionPriority(AdmissionContext::Priority::kImmediate);
ASSERT_FALSE(opCtx->lockState()->shouldWaitForTicket());
{
- SetTicketAquisitionPriorityForLock setTicketAquisition(
- opCtx.get(), AdmissionContext::Priority::kImmediate);
+ SetAdmissionPriorityForLock setTicketAquisition(opCtx.get(),
+ AdmissionContext::Priority::kImmediate);
ASSERT_FALSE(opCtx->lockState()->shouldWaitForTicket());
}
diff --git a/src/mongo/db/curop.cpp b/src/mongo/db/curop.cpp
index 73ddcb6827e..9e6b89c1af4 100644
--- a/src/mongo/db/curop.cpp
+++ b/src/mongo/db/curop.cpp
@@ -278,7 +278,7 @@ void CurOp::setGenericCursor_inlock(GenericCursor gc) {
void CurOp::_finishInit(OperationContext* opCtx, CurOpStack* stack) {
_stack = stack;
- _tickSource = SystemTickSource::get();
+ _tickSource = globalSystemTickSource();
if (opCtx) {
_stack->push(opCtx, this);
diff --git a/src/mongo/db/db_raii.cpp b/src/mongo/db/db_raii.cpp
index 5976628a7f1..7bc3b31f24b 100644
--- a/src/mongo/db/db_raii.cpp
+++ b/src/mongo/db/db_raii.cpp
@@ -38,6 +38,7 @@
#include "mongo/db/s/collection_sharding_state.h"
#include "mongo/db/s/operation_sharding_state.h"
#include "mongo/db/storage/snapshot_helper.h"
+#include "mongo/db/storage/storage_parameters_gen.h"
#include "mongo/logv2/log.h"
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage
@@ -641,7 +642,7 @@ AutoGetCollectionForRead::AutoGetCollectionForRead(OperationContext* opCtx,
}
}
-AutoGetCollectionForReadLockFree::EmplaceHelper::EmplaceHelper(
+AutoGetCollectionForReadLockFreeLegacy::EmplaceHelper::EmplaceHelper(
OperationContext* opCtx,
CollectionCatalogStasher& catalogStasher,
const NamespaceStringOrUUID& nsOrUUID,
@@ -653,7 +654,7 @@ AutoGetCollectionForReadLockFree::EmplaceHelper::EmplaceHelper(
_options(std::move(options)),
_isLockFreeReadSubOperation(isLockFreeReadSubOperation) {}
-void AutoGetCollectionForReadLockFree::EmplaceHelper::emplace(
+void AutoGetCollectionForReadLockFreeLegacy::EmplaceHelper::emplace(
boost::optional<AutoGetCollectionLockFree>& autoColl) const {
autoColl.emplace(
_opCtx,
@@ -706,7 +707,7 @@ void AutoGetCollectionForReadLockFree::EmplaceHelper::emplace(
_options);
}
-AutoGetCollectionForReadLockFree::AutoGetCollectionForReadLockFree(
+AutoGetCollectionForReadLockFreeLegacy::AutoGetCollectionForReadLockFreeLegacy(
OperationContext* opCtx,
const NamespaceStringOrUUID& nsOrUUID,
AutoGetCollection::Options options)
@@ -749,11 +750,22 @@ AutoGetCollectionForReadLockFree::AutoGetCollectionForReadLockFree(
options._secondaryNssOrUUIDs);
}
-AutoGetCollectionForReadMaybeLockFree::AutoGetCollectionForReadMaybeLockFree(
+AutoGetCollectionForReadLockFree::AutoGetCollectionForReadLockFree(
OperationContext* opCtx,
const NamespaceStringOrUUID& nsOrUUID,
AutoGetCollection::Options options) {
+ if (feature_flags::gPointInTimeCatalogLookups.isEnabledAndIgnoreFCV()) {
+ _impl.emplace<AutoGetCollectionForReadLockFreePITCatalog>(
+ opCtx, nsOrUUID, std::move(options));
+ } else {
+ _impl.emplace<AutoGetCollectionForReadLockFreeLegacy>(opCtx, nsOrUUID, std::move(options));
+ }
+}
+AutoGetCollectionForReadMaybeLockFree::AutoGetCollectionForReadMaybeLockFree(
+ OperationContext* opCtx,
+ const NamespaceStringOrUUID& nsOrUUID,
+ AutoGetCollection::Options options) {
if (supportsLockFreeRead(opCtx)) {
_autoGetLockFree.emplace(opCtx, nsOrUUID, std::move(options));
} else {
@@ -808,7 +820,6 @@ AutoGetCollectionForReadCommandBase<AutoGetCollectionForReadType>::
_autoCollForRead.getNss().dbName()),
options._deadline,
options._secondaryNssOrUUIDs) {
-
hangBeforeAutoGetShardVersionCheck.executeIf(
[&](auto&) { hangBeforeAutoGetShardVersionCheck.pauseWhileSet(opCtx); },
[&](const BSONObj& data) {
@@ -1050,7 +1061,7 @@ BlockSecondaryReadsDuringBatchApplication_DONT_USE::
template class AutoGetCollectionForReadBase<AutoGetCollection, EmplaceAutoGetCollectionForRead>;
template class AutoGetCollectionForReadCommandBase<AutoGetCollectionForRead>;
template class AutoGetCollectionForReadBase<AutoGetCollectionLockFree,
- AutoGetCollectionForReadLockFree::EmplaceHelper>;
+ AutoGetCollectionForReadLockFreeLegacy::EmplaceHelper>;
template class AutoGetCollectionForReadCommandBase<AutoGetCollectionForReadLockFree>;
} // namespace mongo
diff --git a/src/mongo/db/db_raii.h b/src/mongo/db/db_raii.h
index 6e5c2874ff5..0ae4b1c24bc 100644
--- a/src/mongo/db/db_raii.h
+++ b/src/mongo/db/db_raii.h
@@ -33,6 +33,8 @@
#include "mongo/db/catalog_raii.h"
#include "mongo/db/stats/top.h"
+#include "mongo/stdx/variant.h"
+#include "mongo/util/overloaded_visitor.h"
#include "mongo/util/timer.h"
namespace mongo {
@@ -195,24 +197,14 @@ private:
* Same as AutoGetCollectionForRead above except does not take collection, database or rstl locks.
* Takes the global lock and may take the PBWM, same as AutoGetCollectionForRead. Ensures a
* consistent in-memory and on-disk view of the storage catalog.
+ *
+ * This implementation does not use the PIT catalog.
*/
-class AutoGetCollectionForReadLockFree {
+class AutoGetCollectionForReadLockFreeLegacy {
public:
- AutoGetCollectionForReadLockFree(OperationContext* opCtx,
- const NamespaceStringOrUUID& nsOrUUID,
- AutoGetCollection::Options = {});
-
- explicit operator bool() const {
- return static_cast<bool>(getCollection());
- }
-
- const Collection* operator->() const {
- return getCollection().get();
- }
-
- const CollectionPtr& operator*() const {
- return getCollection();
- }
+ AutoGetCollectionForReadLockFreeLegacy(OperationContext* opCtx,
+ const NamespaceStringOrUUID& nsOrUUID,
+ AutoGetCollection::Options = {});
const CollectionPtr& getCollection() const {
return _autoGetCollectionForReadBase->getCollection();
@@ -226,11 +218,6 @@ public:
return _autoGetCollectionForReadBase->getNss();
}
- /**
- * Indicates whether any namespace in 'secondaryNssOrUUIDs' is a view or sharded.
- *
- * The secondary namespaces won't be checked if getCollection() returns nullptr.
- */
bool isAnySecondaryNamespaceAViewOrSharded() const {
return _secondaryNssIsAViewOrSharded;
}
@@ -272,6 +259,112 @@ private:
};
/**
+ * Same as AutoGetCollectionForRead above except does not take collection, database or rstl locks.
+ * Takes the global lock and may take the PBWM, same as AutoGetCollectionForRead. Ensures a
+ * consistent in-memory and on-disk view of the storage catalog.
+ *
+ * This implementation uses the point-in-time (PIT) catalog.
+ */
+class AutoGetCollectionForReadLockFreePITCatalog {
+public:
+ AutoGetCollectionForReadLockFreePITCatalog(OperationContext* opCtx,
+ const NamespaceStringOrUUID& nsOrUUID,
+ AutoGetCollection::Options options = {})
+ : _impl(opCtx, nsOrUUID, std::move(options)) {}
+
+ const CollectionPtr& getCollection() const {
+ return _impl.getCollection();
+ }
+
+ const ViewDefinition* getView() const {
+ return _impl.getView();
+ }
+
+ const NamespaceString& getNss() const {
+ return _impl.getNss();
+ }
+
+ bool isAnySecondaryNamespaceAViewOrSharded() const {
+ return _impl.isAnySecondaryNamespaceAViewOrSharded();
+ }
+
+private:
+ // TODO (SERVER-68271): Replace with new implementation using the PIT catalog.
+ AutoGetCollectionForReadLockFreeLegacy _impl;
+};
+
+/**
+ * Same as AutoGetCollectionForRead above except does not take collection, database or rstl locks.
+ * Takes the global lock and may take the PBWM, same as AutoGetCollectionForRead. Ensures a
+ * consistent in-memory and on-disk view of the storage catalog.
+ */
+class AutoGetCollectionForReadLockFree {
+public:
+ AutoGetCollectionForReadLockFree(OperationContext* opCtx,
+ const NamespaceStringOrUUID& nsOrUUID,
+ AutoGetCollection::Options = {});
+
+ explicit operator bool() const {
+ return static_cast<bool>(getCollection());
+ }
+
+ const Collection* operator->() const {
+ return getCollection().get();
+ }
+
+ const CollectionPtr& operator*() const {
+ return getCollection();
+ }
+
+ const CollectionPtr& getCollection() const {
+ return stdx::visit(
+ OverloadedVisitor{
+ [](auto&& impl) -> const CollectionPtr& { return impl.getCollection(); },
+ [](stdx::monostate) -> const CollectionPtr& { MONGO_UNREACHABLE; },
+ },
+ _impl);
+ }
+
+ const ViewDefinition* getView() const {
+ return stdx::visit(
+ OverloadedVisitor{[](auto&& impl) { return impl.getView(); },
+ [](stdx::monostate) -> const ViewDefinition* { MONGO_UNREACHABLE; }},
+ _impl);
+ }
+
+ const NamespaceString& getNss() const {
+ return stdx::visit(
+ OverloadedVisitor{[](auto&& impl) -> const NamespaceString& { return impl.getNss(); },
+ [](stdx::monostate) -> const NamespaceString& { MONGO_UNREACHABLE; }},
+ _impl);
+ }
+
+ /**
+ * Indicates whether any namespace in 'secondaryNssOrUUIDs' is a view or sharded.
+ *
+ * The secondary namespaces won't be checked if getCollection() returns nullptr.
+ */
+ bool isAnySecondaryNamespaceAViewOrSharded() const {
+ return stdx::visit(
+ OverloadedVisitor{
+ [](auto&& impl) { return impl.isAnySecondaryNamespaceAViewOrSharded(); },
+ [](stdx::monostate) -> bool { MONGO_UNREACHABLE; }},
+ _impl);
+ }
+
+private:
+ // If the gPointInTimeCatalogLookups feature flag is enabled, this will contain an instance of
+ // AutoGetCollectionForReadLockFreePITCatalog. Otherwise, it will contain an instance of
+ // AutoGetCollectionForReadLockFreeLegacy. Note that stdx::monostate is required for default
+ // construction, since these other types are not movable, but after construction, the value
+ // should never be set to stdx::monostate.
+ stdx::variant<stdx::monostate,
+ AutoGetCollectionForReadLockFreeLegacy,
+ AutoGetCollectionForReadLockFreePITCatalog>
+ _impl;
+};
+
+/**
* Creates either an AutoGetCollectionForRead or AutoGetCollectionForReadLockFree depending on
* whether a lock-free read is supported.
*/
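The variant-based dispatch above can be illustrated with plain std::variant and an overload-set visitor. The sketch below uses hypothetical LegacyImpl/PitCatalogImpl/Facade types and assumes, as the comment states, that monostate exists only to allow default construction and is never visited after the constructor runs:

    #include <cassert>
    #include <iostream>
    #include <string>
    #include <variant>

    // Two hypothetical implementations with the same surface area.
    struct LegacyImpl {
        std::string name() const { return "legacy"; }
    };
    struct PitCatalogImpl {
        std::string name() const { return "point-in-time catalog"; }
    };

    // Lets a set of lambdas act as one visitor (the same idea as OverloadedVisitor).
    template <class... Ts>
    struct Overloaded : Ts... {
        using Ts::operator()...;
    };
    template <class... Ts>
    Overloaded(Ts...) -> Overloaded<Ts...>;

    class Facade {
    public:
        explicit Facade(bool useNewImpl) {
            // monostate is only the default-constructed state; it is replaced immediately.
            if (useNewImpl)
                _impl.emplace<PitCatalogImpl>();
            else
                _impl.emplace<LegacyImpl>();
        }

        std::string name() const {
            return std::visit(Overloaded{[](const auto& impl) { return impl.name(); },
                                         [](std::monostate) -> std::string {
                                             assert(false);  // never reached after construction
                                             return {};
                                         }},
                              _impl);
        }

    private:
        std::variant<std::monostate, LegacyImpl, PitCatalogImpl> _impl;
    };

    int main() {
        std::cout << Facade(false).name() << "\n";
        std::cout << Facade(true).name() << "\n";
    }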
diff --git a/src/mongo/db/dbdirectclient.h b/src/mongo/db/dbdirectclient.h
index 315c17a8bdc..c0f5bf23400 100644
--- a/src/mongo/db/dbdirectclient.h
+++ b/src/mongo/db/dbdirectclient.h
@@ -55,6 +55,7 @@ public:
using DBClientBase::find;
using DBClientBase::insert;
using DBClientBase::remove;
+ using DBClientBase::runCommand;
using DBClientBase::update;
std::unique_ptr<DBClientCursor> find(FindCommandRequest findRequest,
diff --git a/src/mongo/db/exec/batched_delete_stage.cpp b/src/mongo/db/exec/batched_delete_stage.cpp
index b12a5c0d73b..7c1141e0f99 100644
--- a/src/mongo/db/exec/batched_delete_stage.cpp
+++ b/src/mongo/db/exec/batched_delete_stage.cpp
@@ -325,7 +325,8 @@ long long BatchedDeleteStage::_commitBatch(WorkingSetID* out,
// Start a WUOW with 'groupOplogEntries' which groups a delete batch into a single timestamp
// and oplog entry.
- WriteUnitOfWork wuow(opCtx(), true /* groupOplogEntries */);
+ WriteUnitOfWork wuow(opCtx(),
+ _stagedDeletesBuffer.size() > 1U ? true : false /* groupOplogEntries */);
for (; *bufferOffset < _stagedDeletesBuffer.size(); ++*bufferOffset) {
if (MONGO_unlikely(throwWriteConflictExceptionInBatchedDeleteStage.shouldFail())) {
throwWriteConflictException(
diff --git a/src/mongo/db/exec/bucket_unpacker.cpp b/src/mongo/db/exec/bucket_unpacker.cpp
index d2ad79f0e03..6a819181df9 100644
--- a/src/mongo/db/exec/bucket_unpacker.cpp
+++ b/src/mongo/db/exec/bucket_unpacker.cpp
@@ -41,6 +41,7 @@
#include "mongo/db/matcher/expression_parser.h"
#include "mongo/db/matcher/expression_tree.h"
#include "mongo/db/matcher/extensions_callback_noop.h"
+#include "mongo/db/matcher/rewrite_expr.h"
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/timeseries/timeseries_options.h"
@@ -130,9 +131,9 @@ std::unique_ptr<MatchExpression> makeOr(std::vector<std::unique_ptr<MatchExpress
return std::make_unique<OrMatchExpression>(std::move(nontrivial));
}
-std::unique_ptr<MatchExpression> handleIneligible(IneligiblePredicatePolicy policy,
- const MatchExpression* matchExpr,
- StringData message) {
+BucketSpec::BucketPredicate handleIneligible(IneligiblePredicatePolicy policy,
+ const MatchExpression* matchExpr,
+ StringData message) {
switch (policy) {
case IneligiblePredicatePolicy::kError:
uasserted(
@@ -140,7 +141,7 @@ std::unique_ptr<MatchExpression> handleIneligible(IneligiblePredicatePolicy poli
"Error translating non-metadata time-series predicate to operate on buckets: " +
message + ": " + matchExpr->serialize().toString());
case IneligiblePredicatePolicy::kIgnore:
- return nullptr;
+ return {};
}
MONGO_UNREACHABLE_TASSERT(5916307);
}
@@ -204,40 +205,32 @@ std::unique_ptr<MatchExpression> createTypeEqualityPredicate(
return makeOr(std::move(typeEqualityPredicates));
}
-std::unique_ptr<MatchExpression> createComparisonPredicate(
- const ComparisonMatchExpressionBase* matchExpr,
+boost::optional<StringData> checkComparisonPredicateErrors(
+ const MatchExpression* matchExpr,
+ const StringData matchExprPath,
+ const BSONElement& matchExprData,
const BucketSpec& bucketSpec,
- int bucketMaxSpanSeconds,
- ExpressionContext::CollationMatchesDefault collationMatchesDefault,
- boost::intrusive_ptr<ExpressionContext> pExpCtx,
- bool haveComputedMetaField,
- bool includeMetaField,
- bool assumeNoMixedSchemaData,
- IneligiblePredicatePolicy policy) {
+ ExpressionContext::CollationMatchesDefault collationMatchesDefault) {
using namespace timeseries;
- const auto matchExprPath = matchExpr->path();
- const auto matchExprData = matchExpr->getData();
-
// The control field's min and max are chosen using a field-order insensitive comparator, while
// MatchExpressions use a comparator that treats field-order as significant. Because of this we
// will not perform this optimization on queries with operands of compound types.
if (matchExprData.type() == BSONType::Object || matchExprData.type() == BSONType::Array)
- return handleIneligible(policy, matchExpr, "operand can't be an object or array"_sd);
+ return "operand can't be an object or array"_sd;
// MatchExpressions have special comparison semantics regarding null, in that {$eq: null} will
// match all documents where the field is either null or missing. Because this is different
// from both the comparison semantics that InternalExprComparison expressions and the control's
// min and max fields use, we will not perform this optimization on queries with null operands.
if (matchExprData.type() == BSONType::jstNULL)
- return handleIneligible(policy, matchExpr, "can't handle {$eq: null}"_sd);
+ return "can't handle {$eq: null}"_sd;
// The control field's min and max are chosen based on the collation of the collection. If the
// query's collation does not match the collection's collation and the query operand is a
// string or compound type (skipped above) we will not perform this optimization.
if (collationMatchesDefault == ExpressionContext::CollationMatchesDefault::kNo &&
matchExprData.type() == BSONType::String) {
- return handleIneligible(
- policy, matchExpr, "can't handle string comparison with a non-default collation"_sd);
+ return "can't handle string comparison with a non-default collation"_sd;
}
// This function only handles time and measurement predicates--not metadata.
@@ -252,19 +245,45 @@ std::unique_ptr<MatchExpression> createComparisonPredicate(
// We must avoid mapping predicates on fields computed via $addFields or a computed $project.
if (bucketSpec.fieldIsComputed(matchExprPath.toString())) {
- return handleIneligible(policy, matchExpr, "can't handle a computed field");
+ return "can't handle a computed field"_sd;
}
const auto isTimeField = (matchExprPath == bucketSpec.timeField());
if (isTimeField && matchExprData.type() != BSONType::Date) {
// Users are not allowed to insert non-date measurements into time field. So this query
// would not match anything. We do not need to optimize for this case.
- return handleIneligible(
- policy,
- matchExpr,
- "This predicate will never be true, because the time field always contains a Date");
+ return "This predicate will never be true, because the time field always contains a Date"_sd;
+ }
+
+ return boost::none;
+}
+
+std::unique_ptr<MatchExpression> createComparisonPredicate(
+ const ComparisonMatchExpressionBase* matchExpr,
+ const BucketSpec& bucketSpec,
+ int bucketMaxSpanSeconds,
+ ExpressionContext::CollationMatchesDefault collationMatchesDefault,
+ boost::intrusive_ptr<ExpressionContext> pExpCtx,
+ bool haveComputedMetaField,
+ bool includeMetaField,
+ bool assumeNoMixedSchemaData,
+ IneligiblePredicatePolicy policy) {
+ using namespace timeseries;
+ const auto matchExprPath = matchExpr->path();
+ const auto matchExprData = matchExpr->getData();
+
+ const auto error = checkComparisonPredicateErrors(
+ matchExpr, matchExprPath, matchExprData, bucketSpec, collationMatchesDefault);
+ if (error) {
+ return handleIneligible(policy, matchExpr, *error).loosePredicate;
}
+ const auto isTimeField = (matchExprPath == bucketSpec.timeField());
+ auto minPath = std::string{kControlMinFieldNamePrefix} + matchExprPath;
+ const StringData minPathStringData(minPath);
+ auto maxPath = std::string{kControlMaxFieldNamePrefix} + matchExprPath;
+ const StringData maxPathStringData(maxPath);
+
BSONObj minTime;
BSONObj maxTime;
if (isTimeField) {
@@ -273,11 +292,6 @@ std::unique_ptr<MatchExpression> createComparisonPredicate(
maxTime = BSON("" << timeField + Seconds(bucketMaxSpanSeconds));
}
- const auto minPath = std::string{kControlMinFieldNamePrefix} + matchExprPath;
- const StringData minPathStringData(minPath);
- const auto maxPath = std::string{kControlMaxFieldNamePrefix} + matchExprPath;
- const StringData maxPathStringData(maxPath);
-
switch (matchExpr->matchType()) {
case MatchExpression::EQ:
case MatchExpression::INTERNAL_EXPR_EQ:
@@ -481,9 +495,108 @@ std::unique_ptr<MatchExpression> createComparisonPredicate(
MONGO_UNREACHABLE_TASSERT(5348303);
}
+std::unique_ptr<MatchExpression> createTightComparisonPredicate(
+ const ComparisonMatchExpressionBase* matchExpr,
+ const BucketSpec& bucketSpec,
+ ExpressionContext::CollationMatchesDefault collationMatchesDefault) {
+ using namespace timeseries;
+ const auto matchExprPath = matchExpr->path();
+ const auto matchExprData = matchExpr->getData();
+
+ const auto error = checkComparisonPredicateErrors(
+ matchExpr, matchExprPath, matchExprData, bucketSpec, collationMatchesDefault);
+ if (error) {
+ return handleIneligible(BucketSpec::IneligiblePredicatePolicy::kIgnore, matchExpr, *error)
+ .loosePredicate;
+ }
+
+ // We have to disable the tight predicate for measurement fields. There might be missing
+ // values in the measurements, and the control fields ignore them on insertion, so we cannot use
+ // the bucket min and max to determine a property that holds for all events in the bucket. For
+ // measurement fields there is a further problem: if the control value is an array, we cannot
+ // generate the tight predicate because it would be implicitly mapped over the array elements.
+ if (matchExprPath != bucketSpec.timeField()) {
+ return handleIneligible(BucketSpec::IneligiblePredicatePolicy::kIgnore,
+ matchExpr,
+ "can't create tight predicate on non-time field")
+ .tightPredicate;
+ }
+
+ auto minPath = std::string{kControlMinFieldNamePrefix} + matchExprPath;
+ const StringData minPathStringData(minPath);
+ auto maxPath = std::string{kControlMaxFieldNamePrefix} + matchExprPath;
+ const StringData maxPathStringData(maxPath);
+
+ switch (matchExpr->matchType()) {
+ // All events satisfy $eq if bucket min and max both satisfy $eq.
+ case MatchExpression::EQ:
+ return makePredicate(
+ MatchExprPredicate<EqualityMatchExpression>(minPathStringData, matchExprData),
+ MatchExprPredicate<EqualityMatchExpression>(maxPathStringData, matchExprData));
+ case MatchExpression::INTERNAL_EXPR_EQ:
+ return makePredicate(
+ MatchExprPredicate<InternalExprEqMatchExpression>(minPathStringData, matchExprData),
+ MatchExprPredicate<InternalExprEqMatchExpression>(maxPathStringData,
+ matchExprData));
+
+ // All events satisfy $gt if bucket min satisfies $gt.
+ case MatchExpression::GT:
+ return std::make_unique<GTMatchExpression>(minPathStringData, matchExprData);
+ case MatchExpression::INTERNAL_EXPR_GT:
+ return std::make_unique<InternalExprGTMatchExpression>(minPathStringData,
+ matchExprData);
+
+ // All events satisfy $gte if bucket min satisfies $gte.
+ case MatchExpression::GTE:
+ return std::make_unique<GTEMatchExpression>(minPathStringData, matchExprData);
+ case MatchExpression::INTERNAL_EXPR_GTE:
+ return std::make_unique<InternalExprGTEMatchExpression>(minPathStringData,
+ matchExprData);
+
+ // All events satisfy $lt if bucket max satisfies $lt.
+ case MatchExpression::LT:
+ return std::make_unique<LTMatchExpression>(maxPathStringData, matchExprData);
+ case MatchExpression::INTERNAL_EXPR_LT:
+ return std::make_unique<InternalExprLTMatchExpression>(maxPathStringData,
+ matchExprData);
+
+ // All events satisfy $lte if bucket max satisfies $lte.
+ case MatchExpression::LTE:
+ return std::make_unique<LTEMatchExpression>(maxPathStringData, matchExprData);
+ case MatchExpression::INTERNAL_EXPR_LTE:
+ return std::make_unique<InternalExprLTEMatchExpression>(maxPathStringData,
+ matchExprData);
+
+ default:
+ MONGO_UNREACHABLE_TASSERT(7026901);
+ }
+}
+
+std::unique_ptr<MatchExpression> createTightExprComparisonPredicate(
+ const ExprMatchExpression* matchExpr,
+ const BucketSpec& bucketSpec,
+ ExpressionContext::CollationMatchesDefault collationMatchesDefault,
+ boost::intrusive_ptr<ExpressionContext> pExpCtx) {
+ using namespace timeseries;
+ auto rewriteMatchExpr = RewriteExpr::rewrite(matchExpr->getExpression(), pExpCtx->getCollator())
+ .releaseMatchExpression();
+ if (rewriteMatchExpr &&
+ ComparisonMatchExpressionBase::isInternalExprComparison(rewriteMatchExpr->matchType())) {
+ auto compareMatchExpr =
+ checked_cast<const ComparisonMatchExpressionBase*>(rewriteMatchExpr.get());
+ return createTightComparisonPredicate(
+ compareMatchExpr, bucketSpec, collationMatchesDefault);
+ }
+
+ return handleIneligible(BucketSpec::IneligiblePredicatePolicy::kIgnore,
+ matchExpr,
+ "can't handle non-comparison $expr match expression")
+ .tightPredicate;
+}
+
} // namespace
-std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField(
+BucketSpec::BucketPredicate BucketSpec::createPredicatesOnBucketLevelField(
const MatchExpression* matchExpr,
const BucketSpec& bucketSpec,
int bucketMaxSpanSeconds,
@@ -516,39 +629,61 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField(
if (!includeMetaField)
return handleIneligible(policy, matchExpr, "cannot handle an excluded meta field");
- auto result = matchExpr->shallowClone();
+ auto looseResult = matchExpr->shallowClone();
expression::applyRenamesToExpression(
- result.get(),
+ looseResult.get(),
{{bucketSpec.metaField().value(), timeseries::kBucketMetaFieldName.toString()}});
- return result;
+ auto tightResult = looseResult->shallowClone();
+ return {std::move(looseResult), std::move(tightResult)};
}
if (matchExpr->matchType() == MatchExpression::AND) {
auto nextAnd = static_cast<const AndMatchExpression*>(matchExpr);
- auto andMatchExpr = std::make_unique<AndMatchExpression>();
-
+ auto looseAndExpression = std::make_unique<AndMatchExpression>();
+ auto tightAndExpression = std::make_unique<AndMatchExpression>();
for (size_t i = 0; i < nextAnd->numChildren(); i++) {
- if (auto child = createPredicatesOnBucketLevelField(nextAnd->getChild(i),
- bucketSpec,
- bucketMaxSpanSeconds,
- collationMatchesDefault,
- pExpCtx,
- haveComputedMetaField,
- includeMetaField,
- assumeNoMixedSchemaData,
- policy)) {
- andMatchExpr->add(std::move(child));
+ auto child = createPredicatesOnBucketLevelField(nextAnd->getChild(i),
+ bucketSpec,
+ bucketMaxSpanSeconds,
+ collationMatchesDefault,
+ pExpCtx,
+ haveComputedMetaField,
+ includeMetaField,
+ assumeNoMixedSchemaData,
+ policy);
+ if (child.loosePredicate) {
+ looseAndExpression->add(std::move(child.loosePredicate));
+ }
+
+ if (tightAndExpression && child.tightPredicate) {
+ tightAndExpression->add(std::move(child.tightPredicate));
+ } else {
+ // For the tight expression, null means always false, so we can short-circuit here.
+ tightAndExpression = nullptr;
}
}
- if (andMatchExpr->numChildren() == 1) {
- return andMatchExpr->releaseChild(0);
+
+ // For a loose predicate, if we are unable to generate an expression we can just treat it as
+ // always true or an empty AND. This is because we are trying to generate a predicate that
+ // will match the superset of our actual results.
+ std::unique_ptr<MatchExpression> looseExpression = nullptr;
+ if (looseAndExpression->numChildren() == 1) {
+ looseExpression = looseAndExpression->releaseChild(0);
+ } else if (looseAndExpression->numChildren() > 1) {
+ looseExpression = std::move(looseAndExpression);
}
- if (andMatchExpr->numChildren() > 0) {
- return andMatchExpr;
+
+ // For a tight predicate, if we are unable to generate an expression we can just treat it as
+ // always false. This is because we are trying to generate a predicate that will match the
+ // subset of our actual results.
+ std::unique_ptr<MatchExpression> tightExpression = nullptr;
+ if (tightAndExpression && tightAndExpression->numChildren() == 1) {
+ tightExpression = tightAndExpression->releaseChild(0);
+ } else {
+ tightExpression = std::move(tightAndExpression);
}
- // No error message here: an empty AND is valid.
- return nullptr;
+ return {std::move(looseExpression), std::move(tightExpression)};
} else if (matchExpr->matchType() == MatchExpression::OR) {
// Given {$or: [A, B]}, suppose A, B can be pushed down as A', B'.
// If an event matches {$or: [A, B]} then either:
@@ -556,9 +691,9 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField(
// - it matches B, which means any bucket containing it matches B'
// So {$or: [A', B']} will capture all the buckets we need to satisfy {$or: [A, B]}.
auto nextOr = static_cast<const OrMatchExpression*>(matchExpr);
- auto result = std::make_unique<OrMatchExpression>();
+ auto looseOrExpression = std::make_unique<OrMatchExpression>();
+ auto tightOrExpression = std::make_unique<OrMatchExpression>();
- bool alwaysTrue = false;
for (size_t i = 0; i < nextOr->numChildren(); i++) {
auto child = createPredicatesOnBucketLevelField(nextOr->getChild(i),
bucketSpec,
@@ -569,41 +704,76 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField(
includeMetaField,
assumeNoMixedSchemaData,
policy);
- if (child) {
- result->add(std::move(child));
+ if (looseOrExpression && child.loosePredicate) {
+ looseOrExpression->add(std::move(child.loosePredicate));
} else {
- // Since this argument is always-true, the entire OR is always-true.
- alwaysTrue = true;
+ // For the loose expression, null means always true, so we can short-circuit here.
+ looseOrExpression = nullptr;
+ }
- // Only short circuit if we're uninterested in reporting errors.
- if (policy == IneligiblePredicatePolicy::kIgnore)
- break;
+ // For the tight predicate, we give a tighter bound so that the events in the bucket
+ // either all match A or all match B.
+ if (child.tightPredicate) {
+ tightOrExpression->add(std::move(child.tightPredicate));
}
}
- if (alwaysTrue)
- return nullptr;
- // No special case for an empty OR: returning nullptr would be incorrect because it
- // means 'always-true', here.
- return result;
+ // For a loose predicate, if we are unable to generate an expression we can just treat it as
+ // always true. This is because we are trying to generate a predicate that will match the
+ // superset of our actual results.
+ std::unique_ptr<MatchExpression> looseExpression = nullptr;
+ if (looseOrExpression && looseOrExpression->numChildren() == 1) {
+ looseExpression = looseOrExpression->releaseChild(0);
+ } else {
+ looseExpression = std::move(looseOrExpression);
+ }
+
+ // For a tight predicate, if we are unable to generate an expression we can just treat it as
+ // always false or an empty OR. This is because we are trying to generate a predicate that
+ // will match the subset of our actual results.
+ std::unique_ptr<MatchExpression> tightExpression = nullptr;
+ if (tightOrExpression->numChildren() == 1) {
+ tightExpression = tightOrExpression->releaseChild(0);
+ } else if (tightOrExpression->numChildren() > 1) {
+ tightExpression = std::move(tightOrExpression);
+ }
+
+ return {std::move(looseExpression), std::move(tightExpression)};
} else if (ComparisonMatchExpression::isComparisonMatchExpression(matchExpr) ||
ComparisonMatchExpressionBase::isInternalExprComparison(matchExpr->matchType())) {
- return createComparisonPredicate(
- checked_cast<const ComparisonMatchExpressionBase*>(matchExpr),
- bucketSpec,
- bucketMaxSpanSeconds,
- collationMatchesDefault,
- pExpCtx,
- haveComputedMetaField,
- includeMetaField,
- assumeNoMixedSchemaData,
- policy);
+ return {
+ createComparisonPredicate(checked_cast<const ComparisonMatchExpressionBase*>(matchExpr),
+ bucketSpec,
+ bucketMaxSpanSeconds,
+ collationMatchesDefault,
+ pExpCtx,
+ haveComputedMetaField,
+ includeMetaField,
+ assumeNoMixedSchemaData,
+ policy),
+ createTightComparisonPredicate(
+ checked_cast<const ComparisonMatchExpressionBase*>(matchExpr),
+ bucketSpec,
+ collationMatchesDefault)};
+ } else if (matchExpr->matchType() == MatchExpression::EXPRESSION) {
+ return {
+ // The loose predicate will be pushed before the unpacking, which will be inspected by
+ // the query planner. Since the classic planner doesn't handle the $expr expression, we
+ // don't generate the loose predicate.
+ nullptr,
+ createTightExprComparisonPredicate(checked_cast<const ExprMatchExpression*>(matchExpr),
+ bucketSpec,
+ collationMatchesDefault,
+ pExpCtx)};
} else if (matchExpr->matchType() == MatchExpression::GEO) {
auto& geoExpr = static_cast<const GeoMatchExpression*>(matchExpr)->getGeoExpression();
if (geoExpr.getPred() == GeoExpression::WITHIN ||
geoExpr.getPred() == GeoExpression::INTERSECT) {
- return std::make_unique<InternalBucketGeoWithinMatchExpression>(
- geoExpr.getGeometryPtr(), geoExpr.getField());
+ return {std::make_unique<InternalBucketGeoWithinMatchExpression>(
+ geoExpr.getGeometryPtr(), geoExpr.getField()),
+ nullptr};
}
} else if (matchExpr->matchType() == MatchExpression::EXISTS) {
if (assumeNoMixedSchemaData) {
@@ -613,7 +783,7 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField(
std::string{timeseries::kControlMinFieldNamePrefix} + matchExpr->path())));
result->add(std::make_unique<ExistsMatchExpression>(StringData(
std::string{timeseries::kControlMaxFieldNamePrefix} + matchExpr->path())));
- return result;
+ return {std::move(result), nullptr};
} else {
// At time of writing, we only pass 'kError' when creating a partial index, and
// we know the collection will have no mixed-schema buckets by the time the index is
@@ -622,7 +792,7 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField(
"Can't push down {$exists: true} when the collection may have mixed-schema "
"buckets.",
policy != IneligiblePredicatePolicy::kError);
- return nullptr;
+ return {};
}
} else if (matchExpr->matchType() == MatchExpression::MATCH_IN) {
// {a: {$in: [X, Y]}} is equivalent to {$or: [ {a: X}, {a: Y} ]}.
@@ -664,11 +834,11 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField(
}
}
if (alwaysTrue)
- return nullptr;
+ return {};
// As above, no special case for an empty IN: returning nullptr would be incorrect because
// it means 'always-true', here.
- return result;
+ return {std::move(result), nullptr};
}
return handleIneligible(policy, matchExpr, "can't handle this predicate");
}
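The AND/OR handling above relies on an asymmetric reading of a missing predicate: a loose predicate that cannot be generated acts as always-true (safe to drop from an AND, but it poisons an OR), while a missing tight predicate acts as always-false (it poisons an AND, but is safe to drop from an OR). A toy model of that combination logic, with invented Bucket/Pred/looseAnd/tightAnd names rather than the real MatchExpression machinery, and only the AND case sketched:

    #include <functional>
    #include <iostream>
    #include <optional>
    #include <vector>

    // Toy model of a pushed-down bucket predicate. An empty optional means "could not be
    // generated": for a loose predicate that reads as always-true, for a tight predicate as
    // always-false, which makes the AND/OR combination rules asymmetric.
    struct Bucket { int min, max; };
    using Pred = std::function<bool(const Bucket&)>;
    using MaybePred = std::optional<Pred>;

    MaybePred looseAnd(const std::vector<MaybePred>& children) {
        std::vector<Pred> usable;
        for (const auto& c : children)
            if (c) usable.push_back(*c);          // a missing loose child is simply dropped
        if (usable.empty()) return std::nullopt;  // empty AND: always true
        return [usable](const Bucket& b) {
            for (const auto& p : usable)
                if (!p(b)) return false;
            return true;
        };
    }

    MaybePred tightAnd(const std::vector<MaybePred>& children) {
        for (const auto& c : children)
            if (!c) return std::nullopt;  // one missing tight child poisons the whole AND
        return [children](const Bucket& b) {
            for (const auto& c : children)
                if (!(*c)(b)) return false;
            return true;
        };
    }

    int main() {
        MaybePred gt5 = [](const Bucket& b) { return b.max > 5; };
        MaybePred missing;  // e.g. a child predicate that could not be pushed down

        Bucket b{1, 10};
        auto loose = looseAnd({gt5, missing});
        auto tight = tightAnd({gt5, missing});
        std::cout << "loose matches: " << (!loose || (*loose)(b)) << "\n";   // 1 (broad filter)
        std::cout << "tight usable:  " << static_cast<bool>(tight) << "\n";  // 0 (no guarantee)
    }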
@@ -713,9 +883,9 @@ std::pair<bool, BSONObj> BucketSpec::pushdownPredicate(
BucketSpec{
tsOptions.getTimeField().toString(),
metaField.map([](StringData s) { return s.toString(); }),
- // Since we are operating on a collection, not a query-result, there are no
- // inclusion/exclusion projections we need to apply to the buckets before
- // unpacking.
+ // Since we are operating on a collection, not a query-result,
+ // there are no inclusion/exclusion projections we need to apply
+ // to the buckets before unpacking.
{},
// And there are no computed projections.
{},
@@ -727,6 +897,7 @@ std::pair<bool, BSONObj> BucketSpec::pushdownPredicate(
includeMetaField,
assumeNoMixedSchemaData,
policy)
+ .loosePredicate
: nullptr;
BSONObjBuilder result;
@@ -1230,9 +1401,10 @@ Document BucketUnpacker::extractSingleMeasurement(int j) {
return measurement.freeze();
}
-void BucketUnpacker::reset(BSONObj&& bucket) {
+void BucketUnpacker::reset(BSONObj&& bucket, bool bucketMatchedQuery) {
_unpackingImpl.reset();
_bucket = std::move(bucket);
+ _bucketMatchedQuery = bucketMatchedQuery;
uassert(5346510, "An empty bucket cannot be unpacked", !_bucket.isEmpty());
auto&& dataRegion = _bucket.getField(timeseries::kBucketDataFieldName).Obj();
diff --git a/src/mongo/db/exec/bucket_unpacker.h b/src/mongo/db/exec/bucket_unpacker.h
index 8f7f8210618..3a9813eb87a 100644
--- a/src/mongo/db/exec/bucket_unpacker.h
+++ b/src/mongo/db/exec/bucket_unpacker.h
@@ -127,14 +127,29 @@ public:
kError,
};
+ struct BucketPredicate {
+ // A loose predicate is a predicate which returns true when any measure of a bucket
+ // matches.
+ std::unique_ptr<MatchExpression> loosePredicate;
+
+ // A tight predicate is a predicate which returns true when all measures of a bucket
+ // match.
+ std::unique_ptr<MatchExpression> tightPredicate;
+ };
+
/**
* Takes a predicate after $_internalUnpackBucket on a bucketed field as an argument and
- * attempts to map it to a new predicate on the 'control' field. For example, the predicate
- * {a: {$gt: 5}} will generate the predicate {control.max.a: {$_internalExprGt: 5}}, which will
- * be added before the $_internalUnpackBucket stage.
+ * attempts to map it to new predicates on the 'control' field. There will be a 'loose'
+ * predicate that matches if at least one event in the bucket matches, and a 'tight' predicate
+ * that matches only if every event in the bucket matches. For example, the event-level
+ * predicate {a: {$gt: 5}} will generate the loose predicate
+ * {control.max.a: {$_internalExprGt: 5}} and the tight predicate
+ * {control.min.a: {$_internalExprGt: 5}}. The loose predicate will be added before the
+ * $_internalUnpackBucket stage to filter out buckets with no matches. The tight predicate will
+ * be used to evaluate the predicate at the bucket level and avoid unnecessary event-level
+ * evaluation.
*
- * If the original predicate is on the bucket's timeField we may also create a new predicate
- * on the '_id' field to assist in index utilization. For example, the predicate
+ * If the original predicate is on the bucket's timeField we may also create a new loose
+ * predicate on the '_id' field to assist in index utilization. For example, the predicate
* {time: {$lt: new Date(...)}} will generate the following predicate:
* {$and: [
* {_id: {$lt: ObjectId(...)}},
@@ -147,7 +162,7 @@ public:
* When using IneligiblePredicatePolicy::kIgnore, if the predicate can't be pushed down, it
* returns null. When using IneligiblePredicatePolicy::kError it raises a user error.
*/
- static std::unique_ptr<MatchExpression> createPredicatesOnBucketLevelField(
+ static BucketPredicate createPredicatesOnBucketLevelField(
const MatchExpression* matchExpr,
const BucketSpec& bucketSpec,
int bucketMaxSpanSeconds,
@@ -269,7 +284,7 @@ public:
/**
* This resets the unpacker to prepare to unpack a new bucket described by the given document.
*/
- void reset(BSONObj&& bucket);
+ void reset(BSONObj&& bucket, bool bucketMatchedQuery = false);
Behavior behavior() const {
return _unpackerBehavior;
@@ -283,6 +298,10 @@ public:
return _bucket;
}
+ bool bucketMatchedQuery() const {
+ return _bucketMatchedQuery;
+ }
+
bool includeMetaField() const {
return _includeMetaField;
}
@@ -350,6 +369,9 @@ private:
bool _hasNext = false;
+ // A flag used to mark that the entire bucket matches the following $match predicate.
+ bool _bucketMatchedQuery = false;
+
// A flag used to mark that the timestamp value should be materialized in measurements.
bool _includeTimeField{false};
diff --git a/src/mongo/db/exec/collection_scan.cpp b/src/mongo/db/exec/collection_scan.cpp
index 1bee034cd0c..6cbacb0c997 100644
--- a/src/mongo/db/exec/collection_scan.cpp
+++ b/src/mongo/db/exec/collection_scan.cpp
@@ -206,8 +206,6 @@ PlanStage::StageState CollectionScan::doWork(WorkingSetID* out) {
<< "recordId: " << recordIdToSeek);
}
}
-
- return PlanStage::NEED_TIME;
}
if (_lastSeenId.isNull() && _params.direction == CollectionScanParams::FORWARD &&
diff --git a/src/mongo/db/exec/exclusion_projection_executor.cpp b/src/mongo/db/exec/exclusion_projection_executor.cpp
index 9823ed1b125..ad7d7ca9ddb 100644
--- a/src/mongo/db/exec/exclusion_projection_executor.cpp
+++ b/src/mongo/db/exec/exclusion_projection_executor.cpp
@@ -38,9 +38,10 @@ std::pair<BSONObj, bool> ExclusionNode::extractProjectOnFieldAndRename(const Str
BSONObjBuilder extractedExclusion;
// Check for a projection directly on 'oldName'. For example, {oldName: 0}.
- if (auto it = _projectedFields.find(oldName); it != _projectedFields.end()) {
+ if (auto it = _projectedFieldsSet.find(oldName); it != _projectedFieldsSet.end()) {
extractedExclusion.append(newName, false);
- _projectedFields.erase(it);
+ _projectedFieldsSet.erase(it);
+ _projectedFields.remove(std::string(oldName));
}
// Check for a projection on subfields of 'oldName'. For example, {oldName: {a: 0, b: 0}}.
diff --git a/src/mongo/db/exec/inclusion_projection_executor.cpp b/src/mongo/db/exec/inclusion_projection_executor.cpp
index 84117710063..30e06a76d82 100644
--- a/src/mongo/db/exec/inclusion_projection_executor.cpp
+++ b/src/mongo/db/exec/inclusion_projection_executor.cpp
@@ -68,7 +68,7 @@ void FastPathEligibleInclusionNode::_applyProjections(BSONObj bson, BSONObjBuild
const auto bsonElement{it.next()};
const auto fieldName{bsonElement.fieldNameStringData()};
- if (_projectedFields.find(fieldName) != _projectedFields.end()) {
+ if (_projectedFieldsSet.find(fieldName) != _projectedFieldsSet.end()) {
bob->append(bsonElement);
--nFieldsNeeded;
} else if (auto childIt = _children.find(fieldName); childIt != _children.end()) {
@@ -169,7 +169,8 @@ std::pair<BSONObj, bool> InclusionNode::extractComputedProjectionsInProject(
if (std::get<2>(expressionSpec)) {
// Replace the expression with an inclusion projected field.
- _projectedFields.insert(fieldName);
+ auto it = _projectedFields.insert(_projectedFields.end(), fieldName);
+ _projectedFieldsSet.insert(StringData(*it));
_expressions.erase(fieldName);
// Only computed projections at the beginning of the list were marked to become
// projected fields. The new projected field is at the beginning of the
diff --git a/src/mongo/db/exec/projection_node.cpp b/src/mongo/db/exec/projection_node.cpp
index 00fcf70946f..bbd2ebcac6b 100644
--- a/src/mongo/db/exec/projection_node.cpp
+++ b/src/mongo/db/exec/projection_node.cpp
@@ -48,7 +48,8 @@ void ProjectionNode::addProjectionForPath(const FieldPath& path) {
void ProjectionNode::_addProjectionForPath(const FieldPath& path) {
makeOptimizationsStale();
if (path.getPathLength() == 1) {
- _projectedFields.insert(path.fullPath());
+ auto it = _projectedFields.insert(_projectedFields.end(), path.fullPath());
+ _projectedFieldsSet.insert(StringData(*it));
return;
}
// FieldPath can't be empty, so it is safe to obtain the first path component here.
@@ -141,7 +142,7 @@ void ProjectionNode::applyProjections(const Document& inputDoc, MutableDocument*
while (it.more()) {
auto fieldName = it.fieldName();
- if (_projectedFields.find(fieldName) != _projectedFields.end()) {
+ if (_projectedFieldsSet.find(fieldName) != _projectedFieldsSet.end()) {
outputProjectedField(
fieldName, applyLeafProjectionToValue(it.next().second), outputDoc);
++projectedFields;
@@ -280,7 +281,7 @@ void ProjectionNode::serialize(boost::optional<ExplainOptions::Verbosity> explai
const bool projVal = !applyLeafProjectionToValue(Value(true)).missing();
// Always put "_id" first if it was projected (implicitly or explicitly).
- if (_projectedFields.find("_id") != _projectedFields.end()) {
+ if (_projectedFieldsSet.find("_id") != _projectedFieldsSet.end()) {
output->addField("_id", Value(projVal));
}
diff --git a/src/mongo/db/exec/projection_node.h b/src/mongo/db/exec/projection_node.h
index 82cae8b145b..61e9b98b89d 100644
--- a/src/mongo/db/exec/projection_node.h
+++ b/src/mongo/db/exec/projection_node.h
@@ -29,6 +29,8 @@
#pragma once
+#include <list>
+
#include "mongo/db/exec/projection_executor.h"
#include "mongo/db/query/projection_policies.h"
@@ -176,7 +178,14 @@ protected:
StringMap<std::unique_ptr<ProjectionNode>> _children;
StringMap<boost::intrusive_ptr<Expression>> _expressions;
- StringSet _projectedFields;
+
+ // List of the projected fields in the order in which they were specified.
+ std::list<std::string> _projectedFields;
+
+ // Set of projected fields. Note that the _projectedFields list actually owns the strings, and
+ // this StringDataSet simply holds views of those strings.
+ StringDataSet _projectedFieldsSet;
+
ProjectionPolicies _policies;
std::string _pathToNode;
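The _projectedFields list plus _projectedFieldsSet pair is the usual trick for an insertion-ordered set with O(1) lookup: the std::list owns the strings and never relocates them, so the StringData views held by the set stay valid as fields are appended. A small self-contained sketch of the same idea using std::string_view in place of StringData, where OrderedFieldSet is an invented name:

    #include <iostream>
    #include <list>
    #include <string>
    #include <string_view>
    #include <unordered_set>

    class OrderedFieldSet {
    public:
        void add(std::string field) {
            // The list node owns the string; the set only stores a view into it.
            auto it = _fields.insert(_fields.end(), std::move(field));
            _fieldSet.insert(std::string_view(*it));
        }

        bool contains(std::string_view field) const {
            return _fieldSet.count(field) > 0;
        }

        void print() const {
            for (const auto& f : _fields)
                std::cout << f << " ";
            std::cout << "\n";
        }

    private:
        std::list<std::string> _fields;                  // declaration order, owns the strings
        std::unordered_set<std::string_view> _fieldSet;  // fast lookup over views
    };

    int main() {
        OrderedFieldSet fields;
        fields.add("_id");
        fields.add("b");
        fields.add("a");
        fields.print();                                   // _id b a  (insertion order preserved)
        std::cout << fields.contains("a") << "\n";        // 1
        std::cout << fields.contains("missing") << "\n";  // 0
    }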
diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript
index 05cf5d9a505..7fc2ba50718 100644
--- a/src/mongo/db/exec/sbe/SConscript
+++ b/src/mongo/db/exec/sbe/SConscript
@@ -240,6 +240,7 @@ env.Library(
],
LIBDEPS=[
"$BUILD_DIR/mongo/db/auth/authmocks",
+ "$BUILD_DIR/mongo/db/query/optimizer/unit_test_utils",
'$BUILD_DIR/mongo/db/query/query_test_service_context',
'$BUILD_DIR/mongo/db/query_exec',
'$BUILD_DIR/mongo/db/service_context_test_fixture',
diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
index 67868053dd9..55767ac4b3f 100644
--- a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
+++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp
@@ -31,12 +31,12 @@
#include "mongo/db/exec/sbe/abt/sbe_abt_test_util.h"
#include "mongo/db/pipeline/abt/document_source_visitor.h"
#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
-#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
#include "mongo/db/query/optimizer/explain.h"
#include "mongo/db/query/optimizer/metadata_factory.h"
#include "mongo/db/query/optimizer/opt_phase_manager.h"
#include "mongo/db/query/optimizer/rewrites/const_eval.h"
#include "mongo/db/query/optimizer/rewrites/path_lower.h"
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
#include "mongo/unittest/unittest.h"
namespace mongo::optimizer {
@@ -379,8 +379,10 @@ TEST_F(NodeSBE, Lower1) {
true /*hasRID*/),
prefixId);
- OptPhaseManager phaseManager(
- OptPhaseManager::getAllRewritesSet(), prefixId, {{}}, DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager(OptPhaseManager::getAllRewritesSet(),
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
phaseManager.optimize(tree);
auto env = VariableEnvironment::build(tree);
@@ -464,15 +466,10 @@ TEST_F(NodeSBE, RequireRID) {
true /*hasRID*/),
prefixId);
- OptPhaseManager phaseManager(OptPhaseManager::getAllRewritesSet(),
- prefixId,
- true /*requireRID*/,
- {{{"test", createScanDef({}, {})}}},
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManagerRequireRID(OptPhaseManager::getAllRewritesSet(),
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
phaseManager.optimize(tree);
auto env = VariableEnvironment::build(tree);
diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp
index 7cc3dcc60ea..3f9c8517878 100644
--- a/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp
+++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp
@@ -37,6 +37,7 @@
#include "mongo/db/query/cqf_command_utils.h"
#include "mongo/db/query/optimizer/explain.h"
#include "mongo/db/query/optimizer/opt_phase_manager.h"
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
#include "mongo/db/query/plan_executor.h"
#include "mongo/db/query/plan_executor_factory.h"
#include "mongo/logv2/log.h"
@@ -114,8 +115,9 @@ std::vector<BSONObj> runSBEAST(OperationContext* opCtx,
OPTIMIZER_DEBUG_LOG(
6264807, 5, "SBE translated ABT", "explain"_attr = ExplainGenerator::explainV2(tree));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
OptPhaseManager::getAllRewritesSet(), prefixId, {{}}, DebugInfo::kDefaultForTests);
+
phaseManager.optimize(tree);
OPTIMIZER_DEBUG_LOG(
diff --git a/src/mongo/db/exec/update_stage.cpp b/src/mongo/db/exec/update_stage.cpp
index ca7d4fb5ef5..a7ed460eb2f 100644
--- a/src/mongo/db/exec/update_stage.cpp
+++ b/src/mongo/db/exec/update_stage.cpp
@@ -239,6 +239,7 @@ BSONObj UpdateStage::transformAndUpdate(const Snapshotted<BSONObj>& oldObj,
if (!request->explain()) {
args.stmtIds = request->getStmtIds();
+ args.sampleId = request->getSampleId();
args.update = logObj;
if (_isUserInitiatedWrite) {
auto scopedCss = CollectionShardingState::assertCollectionLockedAndAcquire(
diff --git a/src/mongo/db/index_builds_coordinator.cpp b/src/mongo/db/index_builds_coordinator.cpp
index bb4f0239bf1..1c4e7ff8e3c 100644
--- a/src/mongo/db/index_builds_coordinator.cpp
+++ b/src/mongo/db/index_builds_coordinator.cpp
@@ -2615,7 +2615,7 @@ void IndexBuildsCoordinator::_scanCollectionAndInsertSortedKeysIntoIndex(
// impact on user operations. Other steps of the index builds such as the draining phase have
// normal priority because index builds are required to eventually catch-up with concurrent
// writers. Otherwise we risk never finishing the index build.
- SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow);
+ SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow);
{
indexBuildsSSS.scanCollection.addAndFetch(1);
@@ -2650,7 +2650,7 @@ void IndexBuildsCoordinator::_insertSortedKeysIntoIndexForResume(
// impact on user operations. Other steps of the index builds such as the draining phase have
// normal priority because index builds are required to eventually catch-up with concurrent
// writers. Otherwise we risk never finishing the index build.
- SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow);
+ SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow);
{
Lock::DBLock autoDb(opCtx, replState->dbName, MODE_IX);
const NamespaceStringOrUUID dbAndUUID(replState->dbName, replState->collectionUUID);
diff --git a/src/mongo/db/mirror_maestro.cpp b/src/mongo/db/mirror_maestro.cpp
index 172b7ce6bda..cf801a8947a 100644
--- a/src/mongo/db/mirror_maestro.cpp
+++ b/src/mongo/db/mirror_maestro.cpp
@@ -32,6 +32,7 @@
#include "mongo/db/mirror_maestro.h"
+#include "mongo/rpc/get_status_from_command_result.h"
#include <cmath>
#include <cstdlib>
#include <utility>
@@ -72,9 +73,10 @@ constexpr auto kMirroredReadsParamName = "mirrorReads"_sd;
constexpr auto kMirroredReadsSeenKey = "seen"_sd;
constexpr auto kMirroredReadsSentKey = "sent"_sd;
-constexpr auto kMirroredReadsReceivedKey = "received"_sd;
+constexpr auto kMirroredReadsProcessedAsSecondaryKey = "processedAsSecondary"_sd;
constexpr auto kMirroredReadsResolvedKey = "resolved"_sd;
constexpr auto kMirroredReadsResolvedBreakdownKey = "resolvedBreakdown"_sd;
+constexpr auto kMirroredReadsSucceededKey = "succeeded"_sd;
constexpr auto kMirroredReadsPendingKey = "pending"_sd;
MONGO_FAIL_POINT_DEFINE(mirrorMaestroExpectsResponse);
@@ -185,12 +187,13 @@ public:
BSONObjBuilder section;
section.append(kMirroredReadsSeenKey, seen.loadRelaxed());
section.append(kMirroredReadsSentKey, sent.loadRelaxed());
- section.append(kMirroredReadsReceivedKey, received.loadRelaxed());
+ section.append(kMirroredReadsProcessedAsSecondaryKey, processedAsSecondary.loadRelaxed());
if (MONGO_unlikely(mirrorMaestroExpectsResponse.shouldFail())) {
// We only can see if the command resolved if we got a response
section.append(kMirroredReadsResolvedKey, resolved.loadRelaxed());
section.append(kMirroredReadsResolvedBreakdownKey, resolvedBreakdown.toBSON());
+ section.append(kMirroredReadsSucceededKey, succeeded.loadRelaxed());
}
if (MONGO_unlikely(mirrorMaestroTracksPending.shouldFail())) {
section.append(kMirroredReadsPendingKey, pending.loadRelaxed());
@@ -232,13 +235,22 @@ public:
ResolvedBreakdownByHost resolvedBreakdown;
+ // Counts the number of operations (as primary) recognized as "to be mirrored".
AtomicWord<CounterT> seen;
+ // Counts the number of remote requests (for mirroring as primary) sent over the network.
AtomicWord<CounterT> sent;
+ // Counts the number of responses (as primary) from secondaries after mirrored operations.
AtomicWord<CounterT> resolved;
- // Counts the number of operations that are scheduled to be mirrored, but haven't yet been sent.
+ // Counts the number of responses (as primary) of successful mirrored operations. Disabled by
+ // default, hidden behind the mirrorMaestroExpectsResponse fail point.
+ AtomicWord<CounterT> succeeded;
+ // Counts the number of operations (as primary) that are scheduled to be mirrored, but
+ // haven't yet been sent. Disabled by default, hidden behind the mirrorMaestroTracksPending
+ // fail point.
AtomicWord<CounterT> pending;
- // Counts the number of mirrored operations received by this node as a secondary.
- AtomicWord<CounterT> received;
+ // Counts the number of mirrored operations processed successfully by this node as a
+ // secondary. Disabled by default, hidden behind the mirrorMaestroExpectsResponse fail point.
+ AtomicWord<CounterT> processedAsSecondary;
} gMirroredReadsSection;
auto parseMirroredReadsParameters(const BSONObj& obj) {
@@ -306,7 +318,7 @@ void MirrorMaestro::tryMirrorRequest(OperationContext* opCtx) noexcept {
void MirrorMaestro::onReceiveMirroredRead(OperationContext* opCtx) noexcept {
const auto& invocation = CommandInvocation::get(opCtx);
if (MONGO_unlikely(invocation->isMirrored())) {
- gMirroredReadsSection.received.fetchAndAddRelaxed(1);
+ gMirroredReadsSection.processedAsSecondary.fetchAndAddRelaxed(1);
}
}
@@ -418,6 +430,11 @@ void MirrorMaestroImpl::_mirror(const std::vector<HostAndPort>& hosts,
// Count both failed and successful reads as resolved
gMirroredReadsSection.resolved.fetchAndAdd(1);
gMirroredReadsSection.resolvedBreakdown.onResponseReceived(host);
+
+ if (getStatusFromCommandResult(args.response.data).isOK()) {
+ gMirroredReadsSection.succeeded.fetchAndAdd(1);
+ }
+
LOGV2_DEBUG(
31457, 4, "Response received", "host"_attr = host, "response"_attr = args.response);
diff --git a/src/mongo/db/mongod_main.cpp b/src/mongo/db/mongod_main.cpp
index 65b6d06f112..127f694b9d5 100644
--- a/src/mongo/db/mongod_main.cpp
+++ b/src/mongo/db/mongod_main.cpp
@@ -142,6 +142,7 @@
#include "mongo/db/s/op_observer_sharding_impl.h"
#include "mongo/db/s/periodic_sharded_index_consistency_checker.h"
#include "mongo/db/s/query_analysis_op_observer.h"
+#include "mongo/db/s/query_analysis_writer.h"
#include "mongo/db/s/rename_collection_participant_service.h"
#include "mongo/db/s/resharding/resharding_coordinator_service.h"
#include "mongo/db/s/resharding/resharding_donor_service.h"
@@ -856,6 +857,10 @@ ExitCode _initAndListen(ServiceContext* serviceContext, int listenPort) {
auto catalog = std::make_unique<StatsCatalog>(serviceContext, std::move(cacheLoader));
StatsCatalog::set(serviceContext, std::move(catalog));
+ if (analyze_shard_key::supportsPersistingSampledQueriesIgnoreFCV()) {
+ analyze_shard_key::QueryAnalysisWriter::get(serviceContext).onStartup();
+ }
+
// MessageServer::run will return when exit code closes its socket and we don't need the
// operation context anymore
startupOpCtx.reset();
@@ -1174,6 +1179,8 @@ void setUpObservers(ServiceContext* serviceContext) {
std::make_unique<OplogWriterTransactionProxy>(std::make_unique<OplogWriterImpl>())));
opObserverRegistry->addObserver(std::make_unique<ShardServerOpObserver>());
opObserverRegistry->addObserver(std::make_unique<ReshardingOpObserver>());
+ opObserverRegistry->addObserver(
+ std::make_unique<analyze_shard_key::QueryAnalysisOpObserver>());
opObserverRegistry->addObserver(std::make_unique<repl::TenantMigrationDonorOpObserver>());
opObserverRegistry->addObserver(
std::make_unique<repl::TenantMigrationRecipientOpObserver>());
@@ -1314,6 +1321,11 @@ void shutdownTask(const ShutdownTaskArgs& shutdownArgs) {
lsc->joinOnShutDown();
}
+ if (analyze_shard_key::supportsPersistingSampledQueriesIgnoreFCV()) {
+ LOGV2(7047303, "Shutting down the QueryAnalysisWriter");
+ analyze_shard_key::QueryAnalysisWriter::get(serviceContext).onShutdown();
+ }
+
// Shutdown the TransportLayer so that new connections aren't accepted
if (auto tl = serviceContext->getTransportLayer()) {
LOGV2_OPTIONS(
diff --git a/src/mongo/db/namespace_string.cpp b/src/mongo/db/namespace_string.cpp
index a45f99552f5..dab24a3c482 100644
--- a/src/mongo/db/namespace_string.cpp
+++ b/src/mongo/db/namespace_string.cpp
@@ -195,6 +195,12 @@ const NamespaceString NamespaceString::kGlobalIndexClonerNamespace(
const NamespaceString NamespaceString::kConfigQueryAnalyzersNamespace(NamespaceString::kConfigDb,
"queryAnalyzers");
+const NamespaceString NamespaceString::kConfigSampledQueriesNamespace(NamespaceString::kConfigDb,
+ "sampledQueries");
+
+const NamespaceString NamespaceString::kConfigSampledQueriesDiffNamespace(
+ NamespaceString::kConfigDb, "sampledQueriesDiff");
+
NamespaceString NamespaceString::parseFromStringExpectTenantIdInMultitenancyMode(StringData ns) {
if (!gMultitenancySupport) {
return NamespaceString(boost::none, ns);
@@ -495,12 +501,12 @@ bool NamespaceString::isSystemStatsCollection() const {
}
NamespaceString NamespaceString::makeTimeseriesBucketsNamespace() const {
- return {db(), kTimeseriesBucketsCollectionPrefix.toString() + coll()};
+ return {dbName(), kTimeseriesBucketsCollectionPrefix.toString() + coll()};
}
NamespaceString NamespaceString::getTimeseriesViewNamespace() const {
invariant(isTimeseriesBucketsCollection(), ns());
- return {db(), coll().substr(kTimeseriesBucketsCollectionPrefix.size())};
+ return {dbName(), coll().substr(kTimeseriesBucketsCollectionPrefix.size())};
}
bool NamespaceString::isImplicitlyReplicated() const {
diff --git a/src/mongo/db/namespace_string.h b/src/mongo/db/namespace_string.h
index 904bac9f99d..3c746ad0ddf 100644
--- a/src/mongo/db/namespace_string.h
+++ b/src/mongo/db/namespace_string.h
@@ -273,6 +273,12 @@ public:
// Namespace used for storing query analyzer settings.
static const NamespaceString kConfigQueryAnalyzersNamespace;
+ // Namespace used for storing sampled queries.
+ static const NamespaceString kConfigSampledQueriesNamespace;
+
+ // Namespace used for storing the diffs for sampled update queries.
+ static const NamespaceString kConfigSampledQueriesDiffNamespace;
+
/**
* Constructs an empty NamespaceString.
*/
diff --git a/src/mongo/db/operation_context.cpp b/src/mongo/db/operation_context.cpp
index c17d9c2d7b3..24c92e19f6b 100644
--- a/src/mongo/db/operation_context.cpp
+++ b/src/mongo/db/operation_context.cpp
@@ -80,7 +80,7 @@ OperationContext::OperationContext(Client* client, OperationIdSlot&& opIdSlot)
: _client(client),
_opId(std::move(opIdSlot)),
_elapsedTime(client ? client->getServiceContext()->getTickSource()
- : SystemTickSource::get()) {}
+ : globalSystemTickSource()) {}
OperationContext::~OperationContext() {
releaseOperationKey();
diff --git a/src/mongo/db/ops/SConscript b/src/mongo/db/ops/SConscript
index 1c0819eec70..5f3d894c3db 100644
--- a/src/mongo/db/ops/SConscript
+++ b/src/mongo/db/ops/SConscript
@@ -48,6 +48,7 @@ env.Library(
'$BUILD_DIR/mongo/db/record_id_helpers',
'$BUILD_DIR/mongo/db/repl/oplog',
'$BUILD_DIR/mongo/db/repl/repl_coordinator_interface',
+ '$BUILD_DIR/mongo/db/s/query_analysis_writer',
'$BUILD_DIR/mongo/db/shard_role',
'$BUILD_DIR/mongo/db/stats/counters',
'$BUILD_DIR/mongo/db/stats/server_read_concern_write_concern_metrics',
diff --git a/src/mongo/db/ops/update_request.h b/src/mongo/db/ops/update_request.h
index ed775c9e15e..e635ce7f156 100644
--- a/src/mongo/db/ops/update_request.h
+++ b/src/mongo/db/ops/update_request.h
@@ -76,7 +76,7 @@ public:
};
UpdateRequest(const write_ops::UpdateOpEntry& updateOp = write_ops::UpdateOpEntry())
- : _updateOp(updateOp) {}
+ : _updateOp(updateOp), _sampleId(updateOp.getSampleId()) {}
void setNamespaceString(const NamespaceString& nsString) {
_nsString = nsString;
@@ -257,6 +257,14 @@ public:
return _stmtIds;
}
+ void setSampleId(boost::optional<UUID> sampleId) {
+ _sampleId = sampleId;
+ }
+
+ const boost::optional<UUID>& getSampleId() const {
+ return _sampleId;
+ }
+
std::string toString() const {
StringBuilder builder;
builder << " query: " << getQuery();
@@ -314,6 +322,9 @@ private:
// The statement ids of this request.
std::vector<StmtId> _stmtIds = {kUninitializedStmtId};
+ // The unique sample id for this request if it has been chosen for sampling.
+ boost::optional<UUID> _sampleId;
+
// Flags controlling the update.
// God bypasses _id checking and index generation. It is only used on behalf of system
diff --git a/src/mongo/db/ops/write_ops.idl b/src/mongo/db/ops/write_ops.idl
index a0567a957f8..be734dae3a6 100644
--- a/src/mongo/db/ops/write_ops.idl
+++ b/src/mongo/db/ops/write_ops.idl
@@ -259,6 +259,11 @@ structs:
type: object
optional: true
stability: stable
+ sampleId:
+ description: "The unique sample id for the operation if it has been chosen for sampling."
+ type: uuid
+ optional: true
+ stability: unstable
DeleteOpEntry:
description: "Parser for the entries in the 'deletes' array of a delete command."
@@ -285,6 +290,11 @@ structs:
type: object
optional: true
stability: stable
+ sampleId:
+ description: "The unique sample id for the operation if it has been chosen for sampling."
+ type: uuid
+ optional: true
+ stability: unstable
FindAndModifyLastError:
description: "Contains execution details for the findAndModify command"
@@ -543,3 +553,8 @@ commands:
description: "Indicates whether the operation is a mirrored read"
type: optionalBool
stability: unstable
+ sampleId:
+ description: "The unique sample id for the operation if it has been chosen for sampling."
+ type: uuid
+ optional: true
+ stability: unstable
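Taken together with the update_request.h change above, the new optional sampleId field rides along the normal write path: the IDL generator should produce getSampleId()/setSampleId() accessors on the op entries, UpdateRequest copies the value in its constructor, and update_stage.cpp forwards it as args.sampleId. A hedged sketch of building an update op that carries a sample id (assumes the generated accessors and the usual BSON/UUID helpers; not code from the patch):

    #include "mongo/bson/bsonobjbuilder.h"
    #include "mongo/db/namespace_string.h"
    #include "mongo/db/ops/write_ops.h"
    #include "mongo/util/uuid.h"

    namespace mongo {
    write_ops::UpdateCommandRequest makeSampledUpdateCmd() {
        write_ops::UpdateCommandRequest cmd(NamespaceString("test.coll"));
        write_ops::UpdateOpEntry op;
        op.setQ(BSON("x" << 1));
        op.setU(write_ops::UpdateModification::parseFromClassicUpdate(
            BSON("$set" << BSON("y" << 2))));
        op.setSampleId(UUID::gen());  // accessor assumed to be generated from the new IDL field
        cmd.setUpdates({op});
        return cmd;
    }
    }  // namespace mongo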
diff --git a/src/mongo/db/ops/write_ops_exec.cpp b/src/mongo/db/ops/write_ops_exec.cpp
index f3a37117b00..e98a5931dd8 100644
--- a/src/mongo/db/ops/write_ops_exec.cpp
+++ b/src/mongo/db/ops/write_ops_exec.cpp
@@ -67,6 +67,7 @@
#include "mongo/db/repl/tenant_migration_decoration.h"
#include "mongo/db/s/collection_sharding_state.h"
#include "mongo/db/s/operation_sharding_state.h"
+#include "mongo/db/s/query_analysis_writer.h"
#include "mongo/db/stats/counters.h"
#include "mongo/db/stats/server_write_concern_metrics.h"
#include "mongo/db/stats/top.h"
@@ -677,7 +678,7 @@ WriteResult performInserts(OperationContext* opCtx,
bool containsRetry = false;
ON_BLOCK_EXIT([&] { updateRetryStats(opCtx, containsRetry); });
- size_t stmtIdIndex = 0;
+ size_t nextOpIndex = 0;
size_t bytesInBatch = 0;
std::vector<InsertStatement> batch;
const size_t maxBatchSize = internalInsertMaxBatchSize.load();
@@ -685,10 +686,11 @@ WriteResult performInserts(OperationContext* opCtx,
batch.reserve(std::min(wholeOp.getDocuments().size(), maxBatchSize));
for (auto&& doc : wholeOp.getDocuments()) {
+ const auto currentOpIndex = nextOpIndex++;
const bool isLastDoc = (&doc == &wholeOp.getDocuments().back());
bool containsDotsAndDollarsField = false;
auto fixedDoc = fixDocumentForInsert(opCtx, doc, &containsDotsAndDollarsField);
- const StmtId stmtId = getStmtIdForWriteOp(opCtx, wholeOp, stmtIdIndex++);
+ const StmtId stmtId = getStmtIdForWriteOp(opCtx, wholeOp, currentOpIndex);
const bool wasAlreadyExecuted = opCtx->isRetryableWrite() &&
txnParticipant.checkStatementExecutedNoOplogEntryFetch(opCtx, stmtId);
@@ -1041,7 +1043,7 @@ WriteResult performUpdates(OperationContext* opCtx,
bool containsRetry = false;
ON_BLOCK_EXIT([&] { updateRetryStats(opCtx, containsRetry); });
- size_t stmtIdIndex = 0;
+ size_t nextOpIndex = 0;
WriteResult out;
out.results.reserve(wholeOp.getUpdates().size());
@@ -1054,7 +1056,8 @@ WriteResult performUpdates(OperationContext* opCtx,
// updates.
bool forgoOpCounterIncrements = false;
for (auto&& singleOp : wholeOp.getUpdates()) {
- const auto stmtId = getStmtIdForWriteOp(opCtx, wholeOp, stmtIdIndex++);
+ const auto currentOpIndex = nextOpIndex++;
+ const auto stmtId = getStmtIdForWriteOp(opCtx, wholeOp, currentOpIndex);
if (opCtx->isRetryableWrite()) {
if (auto entry = txnParticipant.checkStatementExecuted(opCtx, stmtId)) {
containsRetry = true;
@@ -1081,6 +1084,13 @@ WriteResult performUpdates(OperationContext* opCtx,
finishCurOp(opCtx, &*curOp);
}
});
+
+ if (analyze_shard_key::supportsPersistingSampledQueries() && singleOp.getSampleId()) {
+ analyze_shard_key::QueryAnalysisWriter::get(opCtx)
+ .addUpdateQuery(wholeOp, currentOpIndex)
+ .getAsync([](auto) {});
+ }
+
try {
lastOpFixer.startingOp();
@@ -1276,7 +1286,7 @@ WriteResult performDeletes(OperationContext* opCtx,
bool containsRetry = false;
ON_BLOCK_EXIT([&] { updateRetryStats(opCtx, containsRetry); });
- size_t stmtIdIndex = 0;
+ size_t nextOpIndex = 0;
WriteResult out;
out.results.reserve(wholeOp.getDeletes().size());
@@ -1286,7 +1296,8 @@ WriteResult performDeletes(OperationContext* opCtx,
wholeOp.getLegacyRuntimeConstants().value_or(Variables::generateRuntimeConstants(opCtx));
for (auto&& singleOp : wholeOp.getDeletes()) {
- const auto stmtId = getStmtIdForWriteOp(opCtx, wholeOp, stmtIdIndex++);
+ const auto currentOpIndex = nextOpIndex++;
+ const auto stmtId = getStmtIdForWriteOp(opCtx, wholeOp, currentOpIndex);
if (opCtx->isRetryableWrite() &&
txnParticipant.checkStatementExecutedNoOplogEntryFetch(opCtx, stmtId)) {
containsRetry = true;
@@ -1316,6 +1327,13 @@ WriteResult performDeletes(OperationContext* opCtx,
&hangBeforeChildRemoveOpIsPopped, opCtx, "hangBeforeChildRemoveOpIsPopped");
}
});
+
+ if (analyze_shard_key::supportsPersistingSampledQueries() && singleOp.getSampleId()) {
+ analyze_shard_key::QueryAnalysisWriter::get(opCtx)
+ .addDeleteQuery(wholeOp, currentOpIndex)
+ .getAsync([](auto) {});
+ }
+
try {
lastOpFixer.startingOp();
out.results.push_back(performSingleDeleteOp(opCtx,
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index 2347bd26c10..b8b17ac1ab9 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -503,6 +503,22 @@ env.Library(
LIBDEPS_PRIVATE=[],
)
+env.Benchmark(
+ target='abt_translation_bm',
+ source=[
+ 'abt/abt_translate_bm_fixture.cpp',
+ 'abt/abt_translate_cq_bm.cpp',
+ 'abt/abt_translate_pipeline_bm.cpp',
+ ],
+ LIBDEPS=[
+ '$BUILD_DIR/mongo/db/pipeline/pipeline',
+ '$BUILD_DIR/mongo/db/query/canonical_query',
+ '$BUILD_DIR/mongo/db/query/query_test_service_context',
+ '$BUILD_DIR/mongo/unittest/unittest',
+ '$BUILD_DIR/mongo/util/processinfo',
+ ],
+)
+
env.CppUnitTest(
target='db_pipeline_test',
source=[
diff --git a/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.cpp b/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.cpp
new file mode 100644
index 00000000000..6b00208dbbd
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.cpp
@@ -0,0 +1,211 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/pipeline/abt/abt_translate_bm_fixture.h"
+
+#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/db/json.h"
+
+namespace mongo {
+namespace {
+BSONArray buildArray(int size) {
+ BSONArrayBuilder builder;
+ for (int i = 0; i < size; i++) {
+ builder.append(i);
+ }
+ return builder.arr();
+}
+
+std::string getField(int index) {
+ static constexpr StringData kViableChars = "abcdefghijklmnopqrstuvwxyz"_sd;
+ invariant(size_t(index) < kViableChars.size());
+ return std::string(1, kViableChars[index]);
+}
+
+BSONObj buildSimpleBSONSpec(int nFields, bool isMatch, bool isExclusion = false) {
+ BSONObjBuilder spec;
+ for (auto i = 0; i < nFields; i++) {
+ int val = isMatch ? i : (isExclusion ? 0 : 1);
+ spec.append(getField(i), val);
+ }
+ return spec.obj();
+}
+
+/**
+ * Builds a filter BSON with 'nFields' simple equality predicates.
+ */
+BSONObj buildSimpleMatchSpec(int nFields) {
+ return buildSimpleBSONSpec(nFields, true /*isMatch*/);
+}
+
+/**
+ * Builds a projection BSON with 'nFields' simple inclusions or exclusions, depending on the
+ * 'isExclusion' parameter.
+ */
+BSONObj buildSimpleProjectSpec(int nFields, bool isExclusion) {
+ return buildSimpleBSONSpec(nFields, false /*isMatch*/, isExclusion);
+}
+
+BSONObj buildNestedBSONSpec(int depth, bool isExclusion, int offset) {
+ std::string field;
+ for (auto i = 0; i < depth - 1; i++) {
+ field += getField(offset + i) += ".";
+ }
+ field += getField(offset + depth);
+
+ return BSON(field << (isExclusion ? 0 : 1));
+}
+
+/**
+ * Builds a BSON representing a predicate on one dotted path, where the field has depth 'depth'.
+ */
+BSONObj buildNestedMatchSpec(int depth, int offset = 0) {
+ return buildNestedBSONSpec(depth, false /*isExclusion*/, offset);
+}
+
+/**
+ * Builds a BSON representing a projection on one dotted path, where the field has depth 'depth'.
+ */
+BSONObj buildNestedProjectSpec(int depth, bool isExclusion, int offset = 0) {
+ return buildNestedBSONSpec(depth, isExclusion, offset);
+}
+} // namespace
+
+void ABTTranslateBenchmarkFixture::benchmarkMatch(benchmark::State& state) {
+ auto match = buildSimpleMatchSpec(1);
+ benchmarkABTTranslate(state, match, BSONObj());
+}
+void ABTTranslateBenchmarkFixture::benchmarkMatchTwoFields(benchmark::State& state) {
+ auto match = buildSimpleMatchSpec(2);
+ benchmarkABTTranslate(state, match, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkMatchTwentyFields(benchmark::State& state) {
+ auto match = buildSimpleMatchSpec(20);
+ benchmarkABTTranslate(state, match, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkMatchDepthTwo(benchmark::State& state) {
+ auto match = buildNestedMatchSpec(2);
+ benchmarkABTTranslate(state, match, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkMatchDepthTwenty(benchmark::State& state) {
+ auto match = buildNestedMatchSpec(20);
+ benchmarkABTTranslate(state, match, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkMatchGtLt(benchmark::State& state) {
+ auto match = fromjson("{a: {$gt: -12, $lt: 5}}");
+ benchmarkABTTranslate(state, match, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkMatchIn(benchmark::State& state) {
+ auto match = BSON("a" << BSON("$in" << buildArray(10)));
+ benchmarkABTTranslate(state, match, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkMatchInLarge(benchmark::State& state) {
+ auto match = BSON("a" << BSON("$in" << buildArray(1000)));
+ benchmarkABTTranslate(state, match, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkMatchElemMatch(benchmark::State& state) {
+ auto match = fromjson("{a: {$elemMatch: {b: {$eq: 2}, c: {$lt: 3}}}}");
+ benchmarkABTTranslate(state, match, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkMatchComplex(benchmark::State& state) {
+ auto match = fromjson(
+ "{$and: ["
+ "{'a.b': {$not: {$eq: 2}}},"
+ "{'b.c': {$lte: {$eq: 'str'}}},"
+ "{$or: [{'c.d' : {$eq: 3}}, {'d.e': {$eq: 4}}]},"
+ "{$or: ["
+ "{'e.f': {$gt: 4}},"
+ "{$and: ["
+ "{'f.g': {$not: {$eq: 1}}},"
+ "{'g.h': {$eq: 3}}"
+ "]}"
+ "]}"
+ "]}}");
+ benchmarkABTTranslate(state, match, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkProjectExclude(benchmark::State& state) {
+ auto project = buildSimpleProjectSpec(1, true /*isExclusion*/);
+ benchmarkABTTranslate(state, project, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkProjectInclude(benchmark::State& state) {
+ auto project = buildSimpleProjectSpec(1, false /*isExclusion*/);
+ benchmarkABTTranslate(state, project, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkProjectIncludeTwoFields(benchmark::State& state) {
+ auto project = buildSimpleProjectSpec(2, false /*isExclusion*/);
+ benchmarkABTTranslate(state, project, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkProjectIncludeTwentyFields(benchmark::State& state) {
+ auto project = buildSimpleProjectSpec(20, false /*isExclusion*/);
+ benchmarkABTTranslate(state, project, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkProjectIncludeDepthTwo(benchmark::State& state) {
+ auto project = buildNestedProjectSpec(2, false /*isExclusion*/);
+ benchmarkABTTranslate(state, project, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkProjectIncludeDepthTwenty(benchmark::State& state) {
+ auto project = buildNestedProjectSpec(20, false /*isExclusion*/);
+ benchmarkABTTranslate(state, project, BSONObj());
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkTwoStages(benchmark::State& state) {
+ // Builds a match on a nested field and then excludes that nested field.
+ std::vector<BSONObj> pipeline;
+ pipeline.push_back(BSON("$match" << buildNestedMatchSpec(3)));
+ pipeline.push_back(BSON("$project" << buildNestedProjectSpec(3, true /*isExclusion*/)));
+ benchmarkABTTranslate(state, pipeline);
+}
+
+void ABTTranslateBenchmarkFixture::benchmarkTwentyStages(benchmark::State& state) {
+ // Builds a sequence of alternating $match and $project stages which match on a nested field and
+ // then exclude that field.
+ std::vector<BSONObj> pipeline;
+ for (int i = 0; i < 10; i++) {
+ pipeline.push_back(BSON("$match" << buildNestedMatchSpec(3, i)));
+ pipeline.push_back(BSON("$project" << buildNestedProjectSpec(3, true /*exclusion*/, i)));
+ }
+ benchmarkABTTranslate(state, pipeline);
+}
+
+
+} // namespace mongo
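For orientation, the specs produced by the helpers above look like this (worked out by hand from the builders, not emitted by the patch):

    buildSimpleMatchSpec(2)           -> { a: 0, b: 1 }   // one equality predicate per field
    buildSimpleProjectSpec(2, false)  -> { a: 1, b: 1 }   // inclusion projection
    buildSimpleProjectSpec(2, true)   -> { a: 0, b: 0 }   // exclusion projection
    buildNestedMatchSpec(3)           -> { "a.b.d": 1 }   // note: the last component comes from getField(offset + depth),
                                                          // so index offset + depth - 1 ("c" here) is skipped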
diff --git a/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.h b/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.h
new file mode 100644
index 00000000000..0ab40ed52fe
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.h
@@ -0,0 +1,137 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/platform/basic.h"
+
+#include <benchmark/benchmark.h>
+
+#include "mongo/bson/bsonobj.h"
+
+namespace mongo {
+
+class ABTTranslateBenchmarkFixture : public benchmark::Fixture {
+public:
+ virtual void benchmarkABTTranslate(benchmark::State& state,
+ BSONObj matchSpec,
+ BSONObj projectSpec) = 0;
+ virtual void benchmarkABTTranslate(benchmark::State& state,
+ const std::vector<BSONObj>& pipeline) = 0;
+
+ void benchmarkMatch(benchmark::State& state);
+ void benchmarkMatchTwoFields(benchmark::State& state);
+ void benchmarkMatchTwentyFields(benchmark::State& state);
+
+ void benchmarkMatchDepthTwo(benchmark::State& state);
+ void benchmarkMatchDepthTwenty(benchmark::State& state);
+
+ void benchmarkMatchGtLt(benchmark::State& state);
+ void benchmarkMatchIn(benchmark::State& state);
+ void benchmarkMatchInLarge(benchmark::State& state);
+ void benchmarkMatchElemMatch(benchmark::State& state);
+ void benchmarkMatchComplex(benchmark::State& state);
+
+ void benchmarkProjectExclude(benchmark::State& state);
+ void benchmarkProjectInclude(benchmark::State& state);
+ void benchmarkProjectIncludeTwoFields(benchmark::State& state);
+ void benchmarkProjectIncludeTwentyFields(benchmark::State& state);
+
+ void benchmarkProjectIncludeDepthTwo(benchmark::State& state);
+ void benchmarkProjectIncludeDepthTwenty(benchmark::State& state);
+
+ void benchmarkTwoStages(benchmark::State& state);
+ void benchmarkTwentyStages(benchmark::State& state);
+};
+
+// These benchmarks cover some simple queries which are currently CQF-eligible. As more support
+// is added to CQF, more benchmarks may be added here as needed.
+#define BENCHMARK_MQL_TRANSLATION(Fixture) \
+ \
+ BENCHMARK_F(Fixture, Match)(benchmark::State & state) { \
+ benchmarkMatch(state); \
+ } \
+ BENCHMARK_F(Fixture, MatchTwoFields)(benchmark::State & state) { \
+ benchmarkMatchTwoFields(state); \
+ } \
+ BENCHMARK_F(Fixture, MatchTwentyFields)(benchmark::State & state) { \
+ benchmarkMatchTwentyFields(state); \
+ } \
+ BENCHMARK_F(Fixture, MatchDepthTwo)(benchmark::State & state) { \
+ benchmarkMatchDepthTwo(state); \
+ } \
+ BENCHMARK_F(Fixture, MatchDepthTwenty)(benchmark::State & state) { \
+ benchmarkMatchDepthTwenty(state); \
+ } \
+ BENCHMARK_F(Fixture, MatchGtLt)(benchmark::State & state) { \
+ benchmarkMatchGtLt(state); \
+ } \
+ BENCHMARK_F(Fixture, MatchIn)(benchmark::State & state) { \
+ benchmarkMatchIn(state); \
+ } \
+ BENCHMARK_F(Fixture, MatchInLarge)(benchmark::State & state) { \
+ benchmarkMatchInLarge(state); \
+ } \
+ BENCHMARK_F(Fixture, MatchElemMatch)(benchmark::State & state) { \
+ benchmarkMatchElemMatch(state); \
+ } \
+ BENCHMARK_F(Fixture, MatchComplex)(benchmark::State & state) { \
+ benchmarkMatchComplex(state); \
+ } \
+ BENCHMARK_F(Fixture, ProjectExclude)(benchmark::State & state) { \
+ benchmarkProjectExclude(state); \
+ } \
+ BENCHMARK_F(Fixture, ProjectInclude)(benchmark::State & state) { \
+ benchmarkProjectInclude(state); \
+ } \
+ BENCHMARK_F(Fixture, ProjectIncludeTwoFields)(benchmark::State & state) { \
+ benchmarkProjectIncludeTwoFields(state); \
+ } \
+ BENCHMARK_F(Fixture, ProjectIncludeTwentyFields)(benchmark::State & state) { \
+ benchmarkProjectIncludeTwentyFields(state); \
+ } \
+ BENCHMARK_F(Fixture, ProjectIncludeDepthTwo)(benchmark::State & state) { \
+ benchmarkProjectIncludeDepthTwo(state); \
+ } \
+ BENCHMARK_F(Fixture, ProjectIncludeDepthTwenty)(benchmark::State & state) { \
+ benchmarkProjectIncludeDepthTwenty(state); \
+ }
+
+// Queries which are expressed as pipelines should be added here because they cannot go through
+// find translation.
+#define BENCHMARK_MQL_PIPELINE_TRANSLATION(Fixture) \
+ \
+ BENCHMARK_F(Fixture, TwoStages)(benchmark::State & state) { \
+ benchmarkTwoStages(state); \
+ } \
+ BENCHMARK_F(Fixture, TwentyStages)(benchmark::State & state) { \
+ benchmarkTwentyStages(state); \
+ }
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/abt/abt_translate_cq_bm.cpp b/src/mongo/db/pipeline/abt/abt_translate_cq_bm.cpp
new file mode 100644
index 00000000000..8603ea2d7c0
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/abt_translate_cq_bm.cpp
@@ -0,0 +1,88 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <benchmark/benchmark.h>
+
+#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/db/pipeline/abt/abt_translate_bm_fixture.h"
+#include "mongo/db/pipeline/abt/canonical_query_translation.h"
+#include "mongo/db/query/canonical_query.h"
+#include "mongo/db/query/query_test_service_context.h"
+
+namespace mongo::optimizer {
+namespace {
+/**
+ * Benchmarks translation from CanonicalQuery to ABT.
+ */
+class CanonicalQueryABTTranslate : public ABTTranslateBenchmarkFixture {
+public:
+ CanonicalQueryABTTranslate() {}
+
+ void benchmarkABTTranslate(benchmark::State& state,
+ const std::vector<BSONObj>& pipeline) override final {
+ state.SkipWithError("Find translation fixture cannot translate a pieline");
+ return;
+ }
+
+ void benchmarkABTTranslate(benchmark::State& state,
+ BSONObj matchSpec,
+ BSONObj projectSpec) override final {
+ QueryTestServiceContext testServiceContext;
+ auto opCtx = testServiceContext.makeOperationContext();
+ auto nss = NamespaceString("test.bm");
+
+ Metadata metadata{{}};
+ PrefixId prefixId;
+ std::string scanProjName = prefixId.getNextId("scan");
+
+ auto findCommand = std::make_unique<FindCommandRequest>(nss);
+ findCommand->setFilter(matchSpec);
+ findCommand->setProjection(projectSpec);
+ auto cq = CanonicalQuery::canonicalize(opCtx.get(), std::move(findCommand));
+ if (!cq.isOK()) {
+ state.SkipWithError("Canonical query could not be created");
+ return;
+ }
+
+ // This is where recording starts.
+ for (auto keepRunning : state) {
+ benchmark::DoNotOptimize(
+ translateCanonicalQueryToABT(metadata,
+ *cq.getValue(),
+ scanProjName,
+ make<ScanNode>(scanProjName, "collection"),
+ prefixId));
+ benchmark::ClobberMemory();
+ }
+ }
+};
+
+BENCHMARK_MQL_TRANSLATION(CanonicalQueryABTTranslate)
+} // namespace
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/abt_translate_pipeline_bm.cpp b/src/mongo/db/pipeline/abt/abt_translate_pipeline_bm.cpp
new file mode 100644
index 00000000000..f878c494c13
--- /dev/null
+++ b/src/mongo/db/pipeline/abt/abt_translate_pipeline_bm.cpp
@@ -0,0 +1,90 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <benchmark/benchmark.h>
+
+#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/db/pipeline/abt/abt_translate_bm_fixture.h"
+#include "mongo/db/pipeline/abt/document_source_visitor.h"
+#include "mongo/db/pipeline/expression_context_for_test.h"
+#include "mongo/db/query/query_test_service_context.h"
+
+namespace mongo::optimizer {
+namespace {
+/**
+ * Benchmarks translation from optimized Pipeline to ABT.
+ */
+class PipelineABTTranslateBenchmark : public ABTTranslateBenchmarkFixture {
+public:
+ PipelineABTTranslateBenchmark() {}
+
+ void benchmarkABTTranslate(benchmark::State& state,
+ BSONObj matchSpec,
+ BSONObj projectSpec) override final {
+ std::vector<BSONObj> pipeline;
+ if (!matchSpec.isEmpty()) {
+ pipeline.push_back(BSON("$match" << matchSpec));
+ }
+ if (!projectSpec.isEmpty()) {
+ pipeline.push_back(BSON("$project" << projectSpec));
+ }
+ benchmarkABTTranslate(state, pipeline);
+ }
+
+ void benchmarkABTTranslate(benchmark::State& state,
+ const std::vector<BSONObj>& pipeline) override final {
+ QueryTestServiceContext testServiceContext;
+ auto opCtx = testServiceContext.makeOperationContext();
+ auto expCtx = new ExpressionContextForTest(opCtx.get(), NamespaceString("test.bm"));
+
+ Metadata metadata{{}};
+ PrefixId prefixId;
+ std::string scanProjName = prefixId.getNextId("scan");
+
+ std::unique_ptr<Pipeline, PipelineDeleter> parsedPipeline =
+ Pipeline::parse(pipeline, expCtx);
+ parsedPipeline->optimizePipeline();
+
+ // This is where recording starts.
+ for (auto keepRunning : state) {
+ benchmark::DoNotOptimize(
+ translatePipelineToABT(metadata,
+ *parsedPipeline,
+ scanProjName,
+ make<ScanNode>(scanProjName, "collection"),
+ prefixId));
+ benchmark::ClobberMemory();
+ }
+ }
+};
+
+BENCHMARK_MQL_TRANSLATION(PipelineABTTranslateBenchmark)
+BENCHMARK_MQL_PIPELINE_TRANSLATION(PipelineABTTranslateBenchmark)
+} // namespace
+} // namespace mongo::optimizer
diff --git a/src/mongo/db/pipeline/abt/expr_algebrizer_context.h b/src/mongo/db/pipeline/abt/expr_algebrizer_context.h
index 7cbf740c2e2..7c852281078 100644
--- a/src/mongo/db/pipeline/abt/expr_algebrizer_context.h
+++ b/src/mongo/db/pipeline/abt/expr_algebrizer_context.h
@@ -31,6 +31,7 @@
#include <stack>
+#include "mongo/db/matcher/expression_path.h"
#include "mongo/db/query/optimizer/node.h"
#include "mongo/db/query/optimizer/utils/utils.h"
@@ -73,17 +74,37 @@ public:
*/
std::string getNextId(const std::string& key);
- void enterElemMatch() {
- _elemMatchCount++;
+ void enterElemMatch(const MatchExpression::MatchType matchType) {
+ _elemMatchStack.push_back(matchType);
}
void exitElemMatch() {
tassert(6809501, "Attempting to exit out of elemMatch that was not entered", inElemMatch());
- _elemMatchCount--;
+ _elemMatchStack.pop_back();
}
bool inElemMatch() {
- return _elemMatchCount > 0;
+ return !_elemMatchStack.empty();
+ }
+
+ /**
+ * Returns whether the current $elemMatch should consider its path for translation. This
+ * function assumes that 'enterElemMatch' has been called before visiting the current
+ * expression.
+ */
+ bool shouldGeneratePathForElemMatch() const {
+ return _elemMatchStack.size() == 1 ||
+ _elemMatchStack[_elemMatchStack.size() - 2] ==
+ MatchExpression::MatchType::ELEM_MATCH_OBJECT;
+ }
+
+ /**
+ * Returns true if the current expression should consider its path for translation based on
+ * whether it's contained within an ElemMatchObjectExpression.
+ */
+ bool shouldGeneratePath() const {
+ return _elemMatchStack.empty() ||
+ _elemMatchStack.back() == MatchExpression::MatchType::ELEM_MATCH_OBJECT;
}
private:
@@ -103,8 +124,9 @@ private:
// child expressions.
std::stack<ABT> _stack;
- // Track whether the vistor is currently under an $elemMatch node.
- int _elemMatchCount{0};
+ // Used to track expressions contained under an $elemMatch. Each entry is either an
+ // ELEM_MATCH_OBJECT or ELEM_MATCH_VALUE.
+ std::vector<MatchExpression::MatchType> _elemMatchStack;
};
} // namespace mongo::optimizer
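The effect of tracking the $elemMatch kind, rather than a plain counter, can be seen on two small predicates (hand-derived from the logic above; the $elemMatch itself still generates its path in both cases because the stack has size 1 when shouldGeneratePathForElemMatch() is consulted):

    {a: {$elemMatch: {b: {$gt: 1}}}}     // ELEM_MATCH_OBJECT on the stack: the inner predicate still emits PathGet("b", ...)
    {a: {$elemMatch: {$gt: 1, $lt: 5}}}  // ELEM_MATCH_VALUE on the stack: the path-less children skip path generation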
diff --git a/src/mongo/db/pipeline/abt/match_expression_visitor.cpp b/src/mongo/db/pipeline/abt/match_expression_visitor.cpp
index e422dc24d63..70be6fd3651 100644
--- a/src/mongo/db/pipeline/abt/match_expression_visitor.cpp
+++ b/src/mongo/db/pipeline/abt/match_expression_visitor.cpp
@@ -73,11 +73,11 @@ public:
ABTMatchExpressionPreVisitor(ExpressionAlgebrizerContext& ctx) : _ctx(ctx) {}
void visit(const ElemMatchObjectMatchExpression* expr) override {
- _ctx.enterElemMatch();
+ _ctx.enterElemMatch(expr->matchType());
}
void visit(const ElemMatchValueMatchExpression* expr) override {
- _ctx.enterElemMatch();
+ _ctx.enterElemMatch(expr->matchType());
}
private:
@@ -135,8 +135,8 @@ public:
assertSupportedPathExpression(expr);
ABT result = make<PathDefault>(Constant::boolean(false));
- if (!expr->path().empty()) {
- result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result));
+ if (shouldGeneratePath(expr)) {
+ result = translateFieldRef(*(expr->fieldRef()), std::move(result));
}
_ctx.push(std::move(result));
}
@@ -227,9 +227,9 @@ public:
maybeComposePath<PathComposeA>(result, make<PathDefault>(Constant::boolean(true)));
}
- // The path can be empty if we are within an $elemMatch. In this case elemMatch would
- // insert a traverse.
- if (!expr->path().empty()) {
+ // Do not insert a traverse if within an $elemMatch; traversal will be handled by the
+ // $elemMatch expression itself.
+ if (shouldGeneratePath(expr)) {
// When the path we are comparing is a path to an array, the comparison is
// considered true if it evaluates to true for the array itself or for any of the
// array’s elements. 'result' evaluates comparison on the array elements, and
@@ -249,7 +249,7 @@ public:
make<Constant>(tagArraysOnly, valArraysOnly)));
arrOnlyGuard.reset();
}
- result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result));
+ result = translateFieldRef(*(expr->fieldRef()), std::move(result));
}
_ctx.push(std::move(result));
}
@@ -426,14 +426,8 @@ public:
make<FunctionCall>("getArraySize", makeSeq(make<Variable>(lambdaProjName))),
Constant::int64(expr->getData()))));
- if (!expr->path().empty()) {
- // No traverse.
- result = translateFieldPath(
- FieldPath(expr->path().toString()),
- std::move(result),
- [](const std::string& fieldName, const bool /*isLastElement*/, ABT input) {
- return make<PathGet>(fieldName, std::move(input));
- });
+ if (shouldGeneratePath(expr)) {
+ result = translateFieldRef(*(expr->fieldRef()), std::move(result));
}
_ctx.push(std::move(result));
}
@@ -460,9 +454,7 @@ public:
makeSeq(make<Variable>(lambdaProjName),
Constant::int32(expr->typeSet().getBSONTypeMask())))));
- // The path can be empty if we are within an $elemMatch. In this case elemMatch would insert
- // a traverse.
- if (!expr->path().empty()) {
+ if (shouldGeneratePath(expr)) {
result = make<PathTraverse>(std::move(result), PathTraverse::kSingleLevel);
if (expr->typeSet().hasType(BSONType::Array)) {
// If we are testing against array type, insert a comparison against the
@@ -470,7 +462,7 @@ public:
result = make<PathComposeA>(make<PathArr>(), std::move(result));
}
- result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result));
+ result = translateFieldRef(*(expr->fieldRef()), std::move(result));
}
_ctx.push(std::move(result));
}
@@ -517,33 +509,13 @@ private:
// Make sure we consider only arrays fields on the path.
maybeComposePath(result, make<PathArr>());
- if (!expr->path().empty()) {
- result = translateFieldPath(
- FieldPath{expr->path().toString()},
- std::move(result),
- [&](const std::string& fieldName, const bool isLastElement, ABT input) {
- if (!isLastElement) {
- input = make<PathTraverse>(std::move(input), PathTraverse::kSingleLevel);
- }
- return make<PathGet>(fieldName, std::move(input));
- });
+ if (shouldGeneratePath(expr)) {
+ result = translateFieldRef(*(expr->fieldRef()), std::move(result));
}
_ctx.push(std::move(result));
}
- ABT generateFieldPath(const FieldPath& fieldPath, ABT initial) {
- return translateFieldPath(
- fieldPath,
- std::move(initial),
- [&](const std::string& fieldName, const bool isLastElement, ABT input) {
- if (!isLastElement) {
- input = make<PathTraverse>(std::move(input), PathTraverse::kSingleLevel);
- }
- return make<PathGet>(fieldName, std::move(input));
- });
- }
-
void assertSupportedPathExpression(const PathMatchExpression* expr) {
uassert(ErrorCodes::InternalErrorNotSupported,
"Expression contains a numeric path component",
@@ -610,9 +582,7 @@ private:
break;
}
- // The path can be empty if we are within an $elemMatch. In this case elemMatch would
- // insert a traverse.
- if (!expr->path().empty()) {
+ if (shouldGeneratePath(expr)) {
if (tag == sbe::value::TypeTags::Array || tag == sbe::value::TypeTags::MinKey ||
tag == sbe::value::TypeTags::MaxKey) {
// The behavior of PathTraverse when it encounters an array is to apply its subpath
@@ -628,8 +598,9 @@ private:
result = make<PathTraverse>(std::move(result), PathTraverse::kSingleLevel);
}
- result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result));
+ result = translateFieldRef(*(expr->fieldRef()), std::move(result));
}
+
_ctx.push(std::move(result));
}
@@ -659,6 +630,23 @@ private:
str::stream() << "Match expression is not supported: " << expr->matchType());
}
+ /**
+ * Returns whether the currently visiting expression should consider the path it's operating on
+ * and build the appropriate ABT. This can return false for expressions within an $elemMatch
+ * that operate against each value in an array (aka "elemMatch value").
+ */
+ bool shouldGeneratePath(const PathMatchExpression* expr) const {
+ // The only case where any expression, including $elemMatch, should ignore its path is if
+ // it's directly under a value $elemMatch. The 'elemMatchStack' includes 'expr' if it's an
+ // $elemMatch, so we need to look back an extra element.
+ if (expr->matchType() == MatchExpression::MatchType::ELEM_MATCH_OBJECT ||
+ expr->matchType() == MatchExpression::MatchType::ELEM_MATCH_VALUE) {
+ return _ctx.shouldGeneratePathForElemMatch();
+ }
+
+ return _ctx.shouldGeneratePath();
+ }
+
// If we are parsing a partial index filter, we don't allow agg expressions.
const bool _allowAggExpressions;
diff --git a/src/mongo/db/pipeline/abt/utils.cpp b/src/mongo/db/pipeline/abt/utils.cpp
index c4d6093cc1b..9911e090305 100644
--- a/src/mongo/db/pipeline/abt/utils.cpp
+++ b/src/mongo/db/pipeline/abt/utils.cpp
@@ -30,7 +30,7 @@
#include "mongo/db/pipeline/abt/utils.h"
#include "mongo/db/exec/sbe/values/bson.h"
-
+#include "mongo/db/query/optimizer/utils/utils.h"
namespace mongo::optimizer {
@@ -81,6 +81,39 @@ ABT translateFieldPath(const FieldPath& fieldPath,
return result;
}
+ABT translateFieldRef(const FieldRef& fieldRef, ABT initial) {
+ ABT result = std::move(initial);
+
+ const size_t fieldPathLength = fieldRef.numParts();
+
+ // Handle empty field paths separately.
+ if (fieldPathLength == 0) {
+ return make<PathGet>("", std::move(result));
+ }
+
+ for (size_t i = fieldPathLength; i-- > 0;) {
+ // A single empty field path will parse to a FieldRef with 0 parts but should
+ // logically be considered a single part with an empty string.
+ if (i != fieldPathLength - 1) {
+ // For field paths with empty elements such as 'x.', we should traverse the
+ // array 'x' but not reach into any sub-objects. So a predicate such as {'x.':
+ // {$eq: 5}} should match {x: [5]} and {x: {"": 5}} but not {x: [{"": 5}]}.
+ const bool trailingEmptyPath =
+ (fieldPathLength >= 2u && i == fieldPathLength - 2u) && (fieldRef[i + 1] == ""_sd);
+ if (trailingEmptyPath) {
+ auto arrCase = make<PathArr>();
+ maybeComposePath(arrCase, result.cast<PathGet>()->getPath());
+ maybeComposePath<PathComposeA>(result, arrCase);
+ } else {
+ result = make<PathTraverse>(std::move(result), PathTraverse::kSingleLevel);
+ }
+ }
+ result = make<PathGet>(fieldRef[i].toString(), std::move(result));
+ }
+
+ return result;
+}
+
std::pair<boost::optional<ABT>, bool> getMinMaxBoundForType(const bool isMin,
const sbe::value::TypeTags& tag) {
switch (tag) {
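Hand-tracing translateFieldRef above on a few inputs gives roughly the following (an illustration, not output from the patch; X stands for the 'initial' ABT):

    "a"    ->  PathGet "a" X                                     // single component: no traverse
    "a.b"  ->  PathGet "a" (PathTraverse [1] (PathGet "b" X))    // traverse inserted between components
    "x."   ->  PathGet "x" (PathComposeA (PathGet "" X) (PathComposeM PathArr X))
                                                                 // trailing empty component: the array case is added, per the comment above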
diff --git a/src/mongo/db/pipeline/abt/utils.h b/src/mongo/db/pipeline/abt/utils.h
index 3d7c2906979..54a255adff7 100644
--- a/src/mongo/db/pipeline/abt/utils.h
+++ b/src/mongo/db/pipeline/abt/utils.h
@@ -40,12 +40,22 @@ std::pair<sbe::value::TypeTags, sbe::value::Value> convertFrom(Value val);
using ABTFieldNameFn =
std::function<ABT(const std::string& fieldName, const bool isLastElement, ABT input)>;
+
+/**
+ * Translates an aggregation FieldPath by invoking the `fieldNameFn` for each path component.
+ */
ABT translateFieldPath(const FieldPath& fieldPath,
ABT initial,
const ABTFieldNameFn& fieldNameFn,
size_t skipFromStart = 0);
/**
+ * Translates a given FieldRef (typically used in a MatchExpression) with 'initial' as the input
+ * ABT.
+ */
+ABT translateFieldRef(const FieldRef& fieldRef, ABT initial);
+
+/**
* Return the minimum or maximum value for the "class" of values represented by the input
* constant. Used to support type bracketing.
* Return format is <min/max value, bool inclusive>
diff --git a/src/mongo/db/pipeline/aggregate_command.idl b/src/mongo/db/pipeline/aggregate_command.idl
index 5e3cfc5ba5d..9cb77333eef 100644
--- a/src/mongo/db/pipeline/aggregate_command.idl
+++ b/src/mongo/db/pipeline/aggregate_command.idl
@@ -286,4 +286,9 @@ commands:
type: array<ExternalDataSourceOption>
cpp_name: externalDataSources
optional: true
- stability: unstable
\ No newline at end of file
+ stability: unstable
+ sampleId:
+ description: "The unique sample id for the operation if it has been chosen for sampling."
+ type: uuid
+ optional: true
+ stability: unstable
diff --git a/src/mongo/db/pipeline/change_stream_expired_pre_image_remover.cpp b/src/mongo/db/pipeline/change_stream_expired_pre_image_remover.cpp
index f9c59747448..cb2afcf06e2 100644
--- a/src/mongo/db/pipeline/change_stream_expired_pre_image_remover.cpp
+++ b/src/mongo/db/pipeline/change_stream_expired_pre_image_remover.cpp
@@ -86,6 +86,7 @@ public:
}
void run() {
+ LOGV2(7080100, "Starting Change Stream Expired Pre-images Remover thread");
ThreadClient tc(name(), getGlobalServiceContext());
AuthorizationSession::get(cc())->grantInternalAuthorization(&cc());
diff --git a/src/mongo/db/pipeline/document_source_check_resume_token_test.cpp b/src/mongo/db/pipeline/document_source_check_resume_token_test.cpp
index e93e68aa0ca..d0fc5305029 100644
--- a/src/mongo/db/pipeline/document_source_check_resume_token_test.cpp
+++ b/src/mongo/db/pipeline/document_source_check_resume_token_test.cpp
@@ -206,10 +206,6 @@ protected:
if (!_collScan) {
_collScan = std::make_unique<CollectionScan>(
pExpCtx.get(), _collectionPtr, _params, &_ws, _filter.get());
- // The first call to doWork will create the cursor and return NEED_TIME. But it won't
- // actually scan any of the documents that are present in the mock cursor queue.
- ASSERT_EQ(_collScan->doWork(nullptr), PlanStage::NEED_TIME);
- ASSERT_EQ(_getNumDocsTested(), 0);
}
while (true) {
// If the next result is a pause, return it and don't collscan.
diff --git a/src/mongo/db/pipeline/document_source_internal_unpack_bucket.cpp b/src/mongo/db/pipeline/document_source_internal_unpack_bucket.cpp
index 9ef6d892a51..89959bb0b49 100644
--- a/src/mongo/db/pipeline/document_source_internal_unpack_bucket.cpp
+++ b/src/mongo/db/pipeline/document_source_internal_unpack_bucket.cpp
@@ -46,6 +46,7 @@
#include "mongo/db/matcher/expression_geo.h"
#include "mongo/db/matcher/expression_internal_bucket_geo_within.h"
#include "mongo/db/matcher/expression_internal_expr_comparison.h"
+#include "mongo/db/matcher/match_expression_dependencies.h"
#include "mongo/db/pipeline/accumulator_multi.h"
#include "mongo/db/pipeline/document_source_add_fields.h"
#include "mongo/db/pipeline/document_source_geo_near.h"
@@ -250,6 +251,35 @@ DocumentSourceInternalUnpackBucket::DocumentSourceInternalUnpackBucket(
_bucketUnpacker(std::move(bucketUnpacker)),
_bucketMaxSpanSeconds{bucketMaxSpanSeconds} {}
+DocumentSourceInternalUnpackBucket::DocumentSourceInternalUnpackBucket(
+ const boost::intrusive_ptr<ExpressionContext>& expCtx,
+ BucketUnpacker bucketUnpacker,
+ int bucketMaxSpanSeconds,
+ const boost::optional<BSONObj>& eventFilterBson,
+ const boost::optional<BSONObj>& wholeBucketFilterBson,
+ bool assumeNoMixedSchemaData)
+ : DocumentSourceInternalUnpackBucket(
+ expCtx, std::move(bucketUnpacker), bucketMaxSpanSeconds, assumeNoMixedSchemaData) {
+ if (eventFilterBson) {
+ _eventFilterBson = eventFilterBson->getOwned();
+ _eventFilter =
+ uassertStatusOK(MatchExpressionParser::parse(_eventFilterBson,
+ pExpCtx,
+ ExtensionsCallbackNoop(),
+ Pipeline::kAllowedMatcherFeatures));
+ _eventFilterDeps = {};
+ match_expression::addDependencies(_eventFilter.get(), &_eventFilterDeps);
+ }
+ if (wholeBucketFilterBson) {
+ _wholeBucketFilterBson = wholeBucketFilterBson->getOwned();
+ _wholeBucketFilter =
+ uassertStatusOK(MatchExpressionParser::parse(_wholeBucketFilterBson,
+ pExpCtx,
+ ExtensionsCallbackNoop(),
+ Pipeline::kAllowedMatcherFeatures));
+ }
+}
+
boost::intrusive_ptr<DocumentSource> DocumentSourceInternalUnpackBucket::createFromBsonInternal(
BSONElement specElem, const boost::intrusive_ptr<ExpressionContext>& expCtx) {
uassert(5346500,
@@ -267,6 +297,8 @@ boost::intrusive_ptr<DocumentSource> DocumentSourceInternalUnpackBucket::createF
auto bucketMaxSpanSeconds = 0;
auto assumeClean = false;
std::vector<std::string> computedMetaProjFields;
+ boost::optional<BSONObj> eventFilterBson;
+ boost::optional<BSONObj> wholeBucketFilterBson;
for (auto&& elem : specElem.embeddedObject()) {
auto fieldName = elem.fieldNameStringData();
if (fieldName == kInclude || fieldName == kExclude) {
@@ -360,6 +392,18 @@ boost::intrusive_ptr<DocumentSource> DocumentSourceInternalUnpackBucket::createF
<< " field must be a bool, got: " << elem.type(),
elem.type() == BSONType::Bool);
bucketSpec.setUsesExtendedRange(elem.boolean());
+ } else if (fieldName == kEventFilter) {
+ uassert(7026902,
+ str::stream() << kEventFilter
+ << " field must be an object, got: " << elem.type(),
+ elem.type() == BSONType::Object);
+ eventFilterBson = elem.Obj();
+ } else if (fieldName == kWholeBucketFilter) {
+ uassert(7026903,
+ str::stream() << kWholeBucketFilter
+ << " field must be an object, got: " << elem.type(),
+ elem.type() == BSONType::Object);
+ wholeBucketFilterBson = elem.Obj();
} else {
uasserted(5346506,
str::stream()
@@ -378,6 +422,8 @@ boost::intrusive_ptr<DocumentSource> DocumentSourceInternalUnpackBucket::createF
expCtx,
BucketUnpacker{std::move(bucketSpec), unpackerBehavior},
bucketMaxSpanSeconds,
+ eventFilterBson,
+ wholeBucketFilterBson,
assumeClean);
}
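
For context, createFromBsonInternal() above now forwards two optional object fields from the stage spec into the new constructor. A minimal sketch of such a spec, written in the fromjson style the unit tests in this diff already use (the concrete predicates are hypothetical and chosen only to show where each field goes; this snippet assumes the usual test-fixture includes and is not a standalone program):

    // Hypothetical stage spec; the field names come from kEventFilter / kWholeBucketFilter.
    const BSONObj spec = fromjson(
        "{$_internalUnpackBucket: {"
        "    exclude: [], timeField: 'time', bucketMaxSpanSeconds: 3600,"
        // Evaluated once against the bucket document as a whole:
        "    wholeBucketFilter: {meta: {$gt: 5}},"
        // Re-checked against each unpacked measurement:
        "    eventFilter: {a: {$lte: 4}}"
        "}}");
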
@@ -476,6 +522,13 @@ void DocumentSourceInternalUnpackBucket::serializeToArray(
out.addField(kIncludeMaxTimeAsMetadata, Value{_bucketUnpacker.includeMaxTimeAsMetadata()});
}
+ if (_wholeBucketFilter) {
+ out.addField(kWholeBucketFilter, Value{_wholeBucketFilter->serialize()});
+ }
+ if (_eventFilter) {
+ out.addField(kEventFilter, Value{_eventFilter->serialize()});
+ }
+
if (!explain) {
array.push_back(Value(DOC(getSourceName() << out.freeze())));
if (_sampleSize) {
@@ -491,25 +544,49 @@ void DocumentSourceInternalUnpackBucket::serializeToArray(
}
}
+boost::optional<Document> DocumentSourceInternalUnpackBucket::getNextMatchingMeasure() {
+ while (_bucketUnpacker.hasNext()) {
+ auto measure = _bucketUnpacker.getNext();
+ if (_eventFilter) {
+ // MatchExpression only takes BSON documents, so we have to make one. As an
+ // optimization, only serialize the fields we need to do the match.
+ BSONObj measureBson = _eventFilterDeps.needWholeDocument
+ ? measure.toBson()
+ : document_path_support::documentToBsonWithPaths(measure, _eventFilterDeps.fields);
+ if (_bucketUnpacker.bucketMatchedQuery() || _eventFilter->matchesBSON(measureBson)) {
+ return measure;
+ }
+ } else {
+ return measure;
+ }
+ }
+ return {};
+}
+
DocumentSource::GetNextResult DocumentSourceInternalUnpackBucket::doGetNext() {
tassert(5521502, "calling doGetNext() when '_sampleSize' is set is disallowed", !_sampleSize);
// Otherwise, fall back to unpacking every measurement in all buckets until the child stage is
// exhausted.
- if (_bucketUnpacker.hasNext()) {
- return _bucketUnpacker.getNext();
+ if (auto measure = getNextMatchingMeasure()) {
+ return GetNextResult(std::move(*measure));
}
auto nextResult = pSource->getNext();
- if (nextResult.isAdvanced()) {
+ while (nextResult.isAdvanced()) {
auto bucket = nextResult.getDocument().toBson();
- _bucketUnpacker.reset(std::move(bucket));
+ auto bucketMatchedQuery = _wholeBucketFilter && _wholeBucketFilter->matchesBSON(bucket);
+ _bucketUnpacker.reset(std::move(bucket), bucketMatchedQuery);
+
uassert(5346509,
str::stream() << "A bucket with _id "
<< _bucketUnpacker.bucket()[timeseries::kBucketIdFieldName].toString()
<< " contains an empty data region",
_bucketUnpacker.hasNext());
- return _bucketUnpacker.getNext();
+ if (auto measure = getNextMatchingMeasure()) {
+ return GetNextResult(std::move(*measure));
+ }
+ nextResult = pSource->getNext();
}
return nextResult;
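
The loop above amounts to a simple short-circuit: once a bucket is known to satisfy the whole-bucket filter, its measurements are emitted without re-running the event filter; otherwise each measurement is tested individually. A self-contained sketch of just that logic, with plain C++ types standing in for the unpacker and the match expression (names and values are illustrative only, not MongoDB APIs):

    #include <functional>
    #include <iostream>
    #include <vector>

    using Measurement = int;                                         // stand-in for a Document
    using EventPredicate = std::function<bool(const Measurement&)>;  // stand-in for _eventFilter

    std::vector<Measurement> unpackBucket(const std::vector<Measurement>& bucket,
                                          bool bucketMatchedWholeFilter,
                                          const EventPredicate& eventFilter) {
        std::vector<Measurement> out;
        for (const auto& m : bucket) {
            // Mirrors getNextMatchingMeasure(): a whole-bucket match admits every
            // measurement; otherwise the event filter runs per measurement.
            if (bucketMatchedWholeFilter || eventFilter(m)) {
                out.push_back(m);
            }
        }
        return out;
    }

    int main() {
        const std::vector<Measurement> bucket{1, 3, 5, 7};
        const EventPredicate lteFour = [](const Measurement& m) { return m <= 4; };
        std::cout << unpackBucket(bucket, false, lteFour).size() << "\n";  // 2: only 1 and 3 pass
        std::cout << unpackBucket(bucket, true, lteFour).size() << "\n";   // 4: whole bucket matched
        return 0;
    }
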
@@ -587,7 +664,8 @@ std::pair<BSONObj, bool> DocumentSourceInternalUnpackBucket::extractOrBuildProje
// Check for a viable inclusion $project after the $_internalUnpackBucket.
auto [existingProj, isInclusion] = getIncludeExcludeProjectAndType(std::next(itr)->get());
- if (isInclusion && !existingProj.isEmpty() && canInternalizeProjectObj(existingProj)) {
+ if (!_eventFilter && isInclusion && !existingProj.isEmpty() &&
+ canInternalizeProjectObj(existingProj)) {
container->erase(std::next(itr));
return {existingProj, isInclusion};
}
@@ -595,8 +673,7 @@ std::pair<BSONObj, bool> DocumentSourceInternalUnpackBucket::extractOrBuildProje
// Attempt to get an inclusion $project representing the root-level dependencies of the pipeline
// after the $_internalUnpackBucket. If this $project is not empty, then the dependency set was
// finite.
- Pipeline::SourceContainer restOfPipeline(std::next(itr), container->end());
- auto deps = Pipeline::getDependenciesForContainer(pExpCtx, restOfPipeline, boost::none);
+ auto deps = getRestPipelineDependencies(itr, container);
if (auto dependencyProj =
deps.toProjectionWithoutMetadata(DepsTracker::TruncateToRootLevel::yes);
!dependencyProj.isEmpty()) {
@@ -604,7 +681,7 @@ std::pair<BSONObj, bool> DocumentSourceInternalUnpackBucket::extractOrBuildProje
}
// Check for a viable exclusion $project after the $_internalUnpackBucket.
- if (!existingProj.isEmpty() && canInternalizeProjectObj(existingProj)) {
+ if (!_eventFilter && !existingProj.isEmpty() && canInternalizeProjectObj(existingProj)) {
container->erase(std::next(itr));
return {existingProj, isInclusion};
}
@@ -612,8 +689,7 @@ std::pair<BSONObj, bool> DocumentSourceInternalUnpackBucket::extractOrBuildProje
return {BSONObj{}, false};
}
-std::unique_ptr<MatchExpression>
-DocumentSourceInternalUnpackBucket::createPredicatesOnBucketLevelField(
+BucketSpec::BucketPredicate DocumentSourceInternalUnpackBucket::createPredicatesOnBucketLevelField(
const MatchExpression* matchExpr) const {
return BucketSpec::createPredicatesOnBucketLevelField(
matchExpr,
@@ -1005,6 +1081,16 @@ bool findSequentialDocumentCache(Pipeline::SourceContainer::iterator start,
return start != end;
}
+DepsTracker DocumentSourceInternalUnpackBucket::getRestPipelineDependencies(
+ Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) const {
+ auto deps = Pipeline::getDependenciesForContainer(
+ pExpCtx, Pipeline::SourceContainer{std::next(itr), container->end()}, boost::none);
+ if (_eventFilter) {
+ match_expression::addDependencies(_eventFilter.get(), &deps);
+ }
+ return deps;
+}
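
The helper above exists so that later projection rewrites cannot drop fields the pushed-down event filter still reads: the filter's dependencies are folded into the rest-of-pipeline dependency set. A small illustration of the effect, with hypothetical shapes that match what the updated optimize_pipeline tests assert further down:

    // Rest of pipeline:  [{$project: {_id: 1, x: 1}}]
    // eventFilter:       {a: {$lte: 4}}
    // Dependency set:    {_id, a, x}   // "a" is kept because the event filter needs it
    // Serialized stage:  {$_internalUnpackBucket: {include: ['_id', 'a', 'x'], ...,
    //                     eventFilter: {a: {$lte: 4}}}}
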
+
Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimizeAt(
Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) {
invariant(*itr == this);
@@ -1018,7 +1104,8 @@ Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimi
bool haveComputedMetaField = this->haveComputedMetaField();
// Before any other rewrites for the current stage, consider reordering with $sort.
- if (auto sortPtr = dynamic_cast<DocumentSourceSort*>(std::next(itr)->get())) {
+ if (auto sortPtr = dynamic_cast<DocumentSourceSort*>(std::next(itr)->get());
+ sortPtr && !_eventFilter) {
if (auto metaField = _bucketUnpacker.bucketSpec().metaField();
metaField && !haveComputedMetaField) {
if (checkMetadataSortReorder(sortPtr->getSortKeyPattern(), metaField.value())) {
@@ -1049,7 +1136,8 @@ Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimi
}
// Attempt to push geoNear on the metaField past $_internalUnpackBucket.
- if (auto nextNear = dynamic_cast<DocumentSourceGeoNear*>(std::next(itr)->get())) {
+ if (auto nextNear = dynamic_cast<DocumentSourceGeoNear*>(std::next(itr)->get());
+ nextNear && !_eventFilter) {
// Currently we only support geo indexes on the meta field, and we enforce this by
// requiring the key field to be set so we can check before we try to look up indexes.
auto keyField = nextNear->getKeyField();
@@ -1130,8 +1218,7 @@ Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimi
{
// Check if the rest of the pipeline needs any fields. For example we might only be
// interested in $count.
- auto deps = Pipeline::getDependenciesForContainer(
- pExpCtx, Pipeline::SourceContainer{std::next(itr), container->end()}, boost::none);
+ auto deps = getRestPipelineDependencies(itr, container);
if (deps.hasNoRequirements()) {
_bucketUnpacker.setBucketSpecAndBehavior({_bucketUnpacker.bucketSpec().timeField(),
_bucketUnpacker.bucketSpec().metaField(),
@@ -1151,31 +1238,65 @@ Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimi
}
// Attempt to optimize last-point type queries.
- if (!_triedLastpointRewrite && optimizeLastpoint(itr, container)) {
+ if (!_triedLastpointRewrite && !_eventFilter && optimizeLastpoint(itr, container)) {
_triedLastpointRewrite = true;
// If we are able to rewrite the aggregation, give the resulting pipeline a chance to
// perform further optimizations.
return container->begin();
};
- // Attempt to map predicates on bucketed fields to predicates on the control field.
- if (auto nextMatch = dynamic_cast<DocumentSourceMatch*>(std::next(itr)->get());
- nextMatch && !_triedBucketLevelFieldsPredicatesPushdown) {
- _triedBucketLevelFieldsPredicatesPushdown = true;
+ // Attempt to map predicates on bucketed fields to predicates on the control fields.
+ if (auto nextMatch = dynamic_cast<DocumentSourceMatch*>(std::next(itr)->get())) {
- if (auto match = createPredicatesOnBucketLevelField(nextMatch->getMatchExpression())) {
+ // Merge multiple following $match stages.
+ auto itrToMatch = std::next(itr);
+ while (std::next(itrToMatch) != container->end() &&
+ dynamic_cast<DocumentSourceMatch*>(std::next(itrToMatch)->get())) {
+ nextMatch->doOptimizeAt(itrToMatch, container);
+ }
+
+ auto predicates = createPredicatesOnBucketLevelField(nextMatch->getMatchExpression());
+
+ // Try to create a tight bucket predicate to perform bucket-level matching.
+ if (predicates.tightPredicate) {
+ _wholeBucketFilterBson = predicates.tightPredicate->serialize();
+ _wholeBucketFilter =
+ uassertStatusOK(MatchExpressionParser::parse(_wholeBucketFilterBson,
+ pExpCtx,
+ ExtensionsCallbackNoop(),
+ Pipeline::kAllowedMatcherFeatures));
+ _wholeBucketFilter = MatchExpression::optimize(std::move(_wholeBucketFilter));
+ }
+
+ // Push the original event predicate into the unpacking stage.
+ _eventFilterBson = nextMatch->getQuery().getOwned();
+ _eventFilter =
+ uassertStatusOK(MatchExpressionParser::parse(_eventFilterBson,
+ pExpCtx,
+ ExtensionsCallbackNoop(),
+ Pipeline::kAllowedMatcherFeatures));
+ _eventFilter = MatchExpression::optimize(std::move(_eventFilter));
+ _eventFilterDeps = {};
+ match_expression::addDependencies(_eventFilter.get(), &_eventFilterDeps);
+ container->erase(std::next(itr));
+
+ // Create a loose bucket predicate and push it before the unpacking stage.
+ if (predicates.loosePredicate) {
BSONObjBuilder bob;
- match->serialize(&bob);
+ predicates.loosePredicate->serialize(&bob);
container->insert(itr, DocumentSourceMatch::create(bob.obj(), pExpCtx));
// Give other stages a chance to optimize with the new $match.
return std::prev(itr) == container->begin() ? std::prev(itr)
: std::prev(std::prev(itr));
}
+
+ // We have removed a $match after this stage, so we try to optimize this stage again.
+ return itr;
}
// Attempt to push down a $project on the metaField past $_internalUnpackBucket.
- if (!haveComputedMetaField) {
+ if (!_eventFilter && !haveComputedMetaField) {
if (auto [metaProject, deleteRemainder] = extractProjectForPushDown(std::next(itr)->get());
!metaProject.isEmpty()) {
container->insert(itr,
@@ -1194,7 +1315,7 @@ Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimi
// Attempt to extract computed meta projections from subsequent $project, $addFields, or $set
// and push them before the $_internalUnpackBucket.
- if (pushDownComputedMetaProjection(itr, container)) {
+ if (!_eventFilter && pushDownComputedMetaProjection(itr, container)) {
// We've pushed down and removed a stage after this one. Try to optimize the new stage.
return std::prev(itr) == container->begin() ? std::prev(itr) : std::prev(std::prev(itr));
}
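
Taken together, the doOptimizeAt() changes split a trailing $match into up to three pieces: a loose, control-level $match inserted before this stage, an optional tight predicate stored as the whole-bucket filter, and the original predicate absorbed into the stage as the event filter. A rough before/after sketch (the pipeline is hypothetical; the shapes mirror the updated test expectations rather than quoting them):

    // Before optimization:
    //   {$_internalUnpackBucket: {exclude: [], timeField: 'time', metaField: 'myMeta',
    //                             bucketMaxSpanSeconds: 3600}},
    //   {$match: {myMeta: {$gte: 0}, a: {$lte: 4}}}
    //
    // After optimizePipeline(), approximately:
    //   {$match: {$and: [{meta: {$gte: 0}},
    //                    {$or: [{'control.min.a': {$_internalExprLte: 4}},
    //                           {$expr: {$ne: [{$type: ['$control.min.a']},
    //                                          {$type: ['$control.max.a']}]}}]}]}},
    //   {$_internalUnpackBucket: {exclude: [], timeField: 'time', metaField: 'myMeta',
    //                             bucketMaxSpanSeconds: 3600,
    //                             eventFilter: {a: {$lte: 4}}}}
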
diff --git a/src/mongo/db/pipeline/document_source_internal_unpack_bucket.h b/src/mongo/db/pipeline/document_source_internal_unpack_bucket.h
index 4dd93046533..a5dab5462ad 100644
--- a/src/mongo/db/pipeline/document_source_internal_unpack_bucket.h
+++ b/src/mongo/db/pipeline/document_source_internal_unpack_bucket.h
@@ -51,6 +51,8 @@ public:
static constexpr StringData kBucketMaxSpanSeconds = "bucketMaxSpanSeconds"_sd;
static constexpr StringData kIncludeMinTimeAsMetadata = "includeMinTimeAsMetadata"_sd;
static constexpr StringData kIncludeMaxTimeAsMetadata = "includeMaxTimeAsMetadata"_sd;
+ static constexpr StringData kWholeBucketFilter = "wholeBucketFilter"_sd;
+ static constexpr StringData kEventFilter = "eventFilter"_sd;
static boost::intrusive_ptr<DocumentSource> createFromBsonInternal(
BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx);
@@ -62,6 +64,13 @@ public:
int bucketMaxSpanSeconds,
bool assumeNoMixedSchemaData = false);
+ DocumentSourceInternalUnpackBucket(const boost::intrusive_ptr<ExpressionContext>& expCtx,
+ BucketUnpacker bucketUnpacker,
+ int bucketMaxSpanSeconds,
+ const boost::optional<BSONObj>& eventFilterBson,
+ const boost::optional<BSONObj>& wholeBucketFilterBson,
+ bool assumeNoMixedSchemaData = false);
+
const char* getSourceName() const override {
return kStageNameInternal.rawData();
}
@@ -158,7 +167,7 @@ public:
/**
* Convenience wrapper around BucketSpec::createPredicatesOnBucketLevelField().
*/
- std::unique_ptr<MatchExpression> createPredicatesOnBucketLevelField(
+ BucketSpec::BucketPredicate createPredicatesOnBucketLevelField(
const MatchExpression* matchExpr) const;
/**
@@ -243,8 +252,14 @@ public:
GetModPathsReturn getModifiedPaths() const final override;
+ DepsTracker getRestPipelineDependencies(Pipeline::SourceContainer::iterator itr,
+ Pipeline::SourceContainer* container) const;
+
private:
GetNextResult doGetNext() final;
+
+ boost::optional<Document> getNextMatchingMeasure();
+
bool haveComputedMetaField() const;
// If buckets contained a mixed type schema along some path, we have to push down special
@@ -261,9 +276,13 @@ private:
int _bucketMaxCount = 0;
boost::optional<long long> _sampleSize;
- // Used to avoid infinite loops after we step backwards to optimize a $match on bucket level
- // fields, otherwise we may do an infinite number of $match pushdowns.
- bool _triedBucketLevelFieldsPredicatesPushdown = false;
+ // Filters pushed down from the following $match stages.
+ std::unique_ptr<MatchExpression> _eventFilter;
+ BSONObj _eventFilterBson;
+ DepsTracker _eventFilterDeps;
+ std::unique_ptr<MatchExpression> _wholeBucketFilter;
+ BSONObj _wholeBucketFilterBson;
+
bool _optimizedEndOfPipeline = false;
bool _triedInternalizeProject = false;
bool _triedLastpointRewrite = false;
diff --git a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/create_predicates_on_bucket_level_field_test.cpp b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/create_predicates_on_bucket_level_field_test.cpp
index ffe291dcbd0..63d3b4a0b23 100644
--- a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/create_predicates_on_bucket_level_field_test.cpp
+++ b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/create_predicates_on_bucket_level_field_test.cpp
@@ -55,10 +55,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: [ {'control.max.a': {$_internalExprGt: 1}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -76,10 +77,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: [ {'control.max.a': {$_internalExprGte: 1}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -97,10 +99,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: [ {'control.min.a': {$_internalExprLt: 1}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -118,10 +121,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: [ {'control.min.a': {$_internalExprLte: 1}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -139,11 +143,12 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: [ {$and:[{'control.min.a': {$_internalExprLte: 1}},"
"{'control.max.a': {$_internalExprGte: 1}}]},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -161,7 +166,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT(predicate);
+ ASSERT(predicate.loosePredicate);
auto expected = fromjson(
"{$or: ["
" {$or: ["
@@ -185,7 +190,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
" ]}}"
" ]}"
"]}");
- ASSERT_BSONOBJ_EQ(predicate->serialize(true), expected);
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), expected);
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -208,10 +214,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: [ {'control.max.a': {$_internalExprGt: 1}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -234,10 +241,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: [ {'control.max.a': {$_internalExprGte: 1}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -260,10 +268,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: [ {'control.min.a': {$_internalExprLt: 1}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -286,10 +295,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: [ {'control.min.a': {$_internalExprLte: 1}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -312,11 +322,12 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: [ {$and:[{'control.min.a': {$_internalExprLte: 1}},"
"{'control.max.a': {$_internalExprGte: 1}}]},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -334,13 +345,14 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$and: [ {$or: [ {'control.max.b': {$_internalExprGt: 1}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.b\" ]},"
"{$type: [ \"$control.max.b\" ]} ]}} ]},"
"{$or: [ {'control.min.a': {$_internalExprLt: 5}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -358,7 +370,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT(predicate == nullptr);
+ ASSERT(predicate.loosePredicate == nullptr);
+ ASSERT(predicate.tightPredicate == nullptr);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -376,7 +389,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: ["
" {'control.max.b': {$_internalExprGt: 1}},"
" {$expr: {$ne: ["
@@ -384,6 +397,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
" {$type: [ \"$control.max.b\" ]}"
" ]}}"
"]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -402,7 +416,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$and: [ {$or: [ {'control.max.b': {$_internalExprGte: 2}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.b\" ]},"
"{$type: [ \"$control.max.b\" ]} ]}} ]},"
@@ -412,6 +426,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
"{$or: [ {'control.min.a': {$_internalExprLt: 5}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]} ]} ]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -429,8 +444,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT(predicate);
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT(predicate.loosePredicate);
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: ["
" {$or: ["
" {'control.max.b': {$_internalExprGt: 1}},"
@@ -447,6 +462,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
" ]}}"
" ]}"
"]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -464,7 +480,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT(predicate == nullptr);
+ ASSERT(predicate.loosePredicate == nullptr);
+ ASSERT(predicate.tightPredicate == nullptr);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -485,7 +502,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
// When a predicate can't be pushed down, it's the same as pushing down a trivially-true
// predicate. So when any child of an $or can't be pushed down, we could generate something like
// {$or: [ ... {$alwaysTrue: {}}, ... ]}, but then we might as well not push down the whole $or.
- ASSERT(predicate == nullptr);
+ ASSERT(predicate.loosePredicate == nullptr);
+ ASSERT(predicate.tightPredicate == nullptr);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -504,7 +522,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$or: ["
" {$or: ["
" {'control.max.b': {$_internalExprGte: 2}},"
@@ -530,6 +548,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
" ]}"
" ]}"
"]}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -541,19 +560,21 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
ASSERT_EQ(pipeline->getSources().size(), 2U);
pipeline->optimizePipeline();
- ASSERT_EQ(pipeline->getSources().size(), 3U);
+ ASSERT_EQ(pipeline->getSources().size(), 2U);
// To get the optimized $match from the pipeline, we have to serialize with explain.
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner);
- ASSERT_EQ(stages.size(), 3U);
+ ASSERT_EQ(stages.size(), 2U);
ASSERT_BSONOBJ_EQ(stages[0].getDocument().toBson(),
fromjson("{$match: {$or: [ {'control.max.b': {$_internalExprGt: 1}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.b\" ]},"
"{$type: [ \"$control.max.b\" ]} ]}} ]}}"));
- ASSERT_BSONOBJ_EQ(stages[1].getDocument().toBson(), unpackBucketObj);
- ASSERT_BSONOBJ_EQ(stages[2].getDocument().toBson(),
- fromjson("{$match: {$and: [{b: {$gt: 1}}, {a: {$not: {$eq: 5}}}]}}"));
+ ASSERT_BSONOBJ_EQ(
+ stages[1].getDocument().toBson(),
+ fromjson(
+ "{$_internalUnpackBucket: {exclude: [], timeField: 'time', bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { $and: [ { b: { $gt: 1 } }, { a: { $not: { $eq: 5 } } } ] }}}"));
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -566,10 +587,10 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
ASSERT_EQ(pipeline->getSources().size(), 2U);
pipeline->optimizePipeline();
- ASSERT_EQ(pipeline->getSources().size(), 3U);
+ ASSERT_EQ(pipeline->getSources().size(), 2U);
auto stages = pipeline->serializeToBson();
- ASSERT_EQ(stages.size(), 3U);
+ ASSERT_EQ(stages.size(), 2U);
ASSERT_BSONOBJ_EQ(
stages[0],
@@ -582,8 +603,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
"{$or: [ {'control.min.a': {$_internalExprLt: 5}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]} ]}}"));
- ASSERT_BSONOBJ_EQ(stages[1], unpackBucketObj);
- ASSERT_BSONOBJ_EQ(stages[2], matchObj);
+ ASSERT_BSONOBJ_EQ(stages[1],
+ fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", "
+ "bucketMaxSpanSeconds: 3600,"
+ "eventFilter: { $and: [ { b: { $gte: 2 } }, { c: { $gt: 1 } }, { a: "
+ "{ $lt: 5 } } ] } } }"));
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -601,7 +625,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT(predicate == nullptr);
+ ASSERT(predicate.loosePredicate == nullptr);
+ ASSERT(predicate.tightPredicate == nullptr);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -619,7 +644,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT(predicate == nullptr);
+ ASSERT(predicate.loosePredicate == nullptr);
+ ASSERT(predicate.tightPredicate == nullptr);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -637,7 +663,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT(predicate == nullptr);
+ ASSERT(predicate.loosePredicate == nullptr);
+ ASSERT(predicate.tightPredicate == nullptr);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -656,7 +683,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
->createPredicatesOnBucketLevelField(original->getMatchExpression());
// Meta predicates are mapped to the meta field, not the control min/max fields.
- ASSERT_BSONOBJ_EQ(predicate->serialize(true), fromjson("{meta: {$gt: 5}}"));
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{meta: {$gt: 5}}"));
+ ASSERT_BSONOBJ_EQ(predicate.tightPredicate->serialize(true), fromjson("{meta: {$gt: 5}}"));
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -675,7 +703,10 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
->createPredicatesOnBucketLevelField(original->getMatchExpression());
// Meta predicates are mapped to the meta field, not the control min/max fields.
- ASSERT_BSONOBJ_EQ(predicate->serialize(true), fromjson("{'meta.foo': {$gt: 5}}"));
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
+ fromjson("{'meta.foo': {$gt: 5}}"));
+ ASSERT_BSONOBJ_EQ(predicate.tightPredicate->serialize(true),
+ fromjson("{'meta.foo': {$gt: 5}}"));
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
@@ -693,7 +724,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$and: ["
" {$or: ["
" {'control.max.a': {$_internalExprGt: 1}},"
@@ -704,6 +735,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
" ]},"
" {meta: {$eq: 5}}"
"]}"));
+ ASSERT(predicate.tightPredicate == nullptr);
}
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePredicatesOnId) {
@@ -740,7 +772,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePre
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.get());
+ auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.loosePredicate.get());
auto children = andExpr->getChildVector();
ASSERT_EQ(children->size(), 3);
@@ -797,7 +829,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePre
auto predicate =
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.get());
+ auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.loosePredicate.get());
auto children = andExpr->getChildVector();
ASSERT_EQ(children->size(), 3);
@@ -846,7 +878,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePre
auto predicate =
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.get());
+ auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.loosePredicate.get());
auto children = andExpr->getChildVector();
ASSERT_EQ(children->size(), 6);
@@ -908,7 +940,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePre
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.get());
+ auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.loosePredicate.get());
auto children = andExpr->getChildVector();
ASSERT_EQ(children->size(), 3);
@@ -957,7 +989,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePre
auto predicate =
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.get());
+ auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.loosePredicate.get());
auto children = andExpr->getChildVector();
ASSERT_EQ(children->size(), 3);
@@ -1000,7 +1032,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_FALSE(predicate);
+ ASSERT_FALSE(predicate.loosePredicate);
+ ASSERT_FALSE(predicate.tightPredicate);
}
}
{
@@ -1021,7 +1054,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate =
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_FALSE(predicate);
+ ASSERT_FALSE(predicate.loosePredicate);
+ ASSERT_FALSE(predicate.tightPredicate);
}
}
{
@@ -1042,7 +1076,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate =
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_FALSE(predicate);
+ ASSERT_FALSE(predicate.loosePredicate);
+ ASSERT_FALSE(predicate.tightPredicate);
}
}
{
@@ -1065,7 +1100,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_FALSE(predicate);
+ ASSERT_FALSE(predicate.loosePredicate);
+ ASSERT_FALSE(predicate.tightPredicate);
}
}
{
@@ -1086,7 +1122,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate =
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_FALSE(predicate);
+ ASSERT_FALSE(predicate.loosePredicate);
+ ASSERT_FALSE(predicate.tightPredicate);
}
}
}
@@ -1107,10 +1144,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest,
auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get())
->createPredicatesOnBucketLevelField(original->getMatchExpression());
- ASSERT_BSONOBJ_EQ(predicate->serialize(true),
+ ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true),
fromjson("{$_internalBucketGeoWithin: { withinRegion: { $geometry: { type : "
"\"Polygon\" ,coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 "
"] ] ]}},field: \"loc\"}}"));
+ ASSERT_FALSE(predicate.tightPredicate);
}
} // namespace
} // namespace mongo
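
A way to read the loose/tight pair these assertions exercise: the loose predicate is a necessary condition on bucket-level control and meta fields, used to discard buckets before unpacking (survivors are still re-checked per measurement), while the tight predicate, when it exists, is a sufficient condition that marks the whole bucket as matched so the per-measurement event filter can be skipped. A conceptual sketch of that contract using plain callables rather than the real BucketSpec::BucketPredicate type (member names here are illustrative):

    #include <functional>
    #include <optional>

    struct Bucket {};  // stand-in for a time-series bucket document

    // Conceptual shape only; the real return type is BucketSpec::BucketPredicate.
    struct BucketLevelPredicates {
        // Necessary condition on control/meta fields: if false, no measurement in the
        // bucket can match and the bucket is filtered out before unpacking.
        std::function<bool(const Bucket&)> loose;

        // Sufficient condition: if present and true, every measurement in the bucket
        // matches, so the per-measurement event filter can be skipped.
        std::optional<std::function<bool(const Bucket&)>> tight;
    };
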
diff --git a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/optimize_pipeline_test.cpp b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/optimize_pipeline_test.cpp
index 2001cb2cab3..aeb7ac8ec8c 100644
--- a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/optimize_pipeline_test.cpp
+++ b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/optimize_pipeline_test.cpp
@@ -54,7 +54,7 @@ TEST_F(OptimizePipeline, MixedMatchPushedDown) {
// To get the optimized $match from the pipeline, we have to serialize with explain.
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner);
- ASSERT_EQ(3u, stages.size());
+ ASSERT_EQ(2u, stages.size());
// We should push down the $match on the metaField and the predicates on the control field.
// The created $match stages should be added before $_internalUnpackBucket and merged.
@@ -63,8 +63,10 @@ TEST_F(OptimizePipeline, MixedMatchPushedDown) {
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ] } ] } } ] }]}}"),
stages[0].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(unpack, stages[1].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), stages[2].getDocument().toBson());
+ ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", "
+ "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { a: { $lte: 4 } } } }"),
+ stages[1].getDocument().toBson());
}
TEST_F(OptimizePipeline, MetaMatchPushedDown) {
@@ -103,7 +105,7 @@ TEST_F(OptimizePipeline, MixedMatchOr) {
pipeline->optimizePipeline();
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner);
- ASSERT_EQ(3u, stages.size());
+ ASSERT_EQ(2u, stages.size());
auto expected = fromjson(
"{$match: {$and: ["
// Result of pushing down {x: {$lte: 1}}.
@@ -123,8 +125,11 @@ TEST_F(OptimizePipeline, MixedMatchOr) {
" ]}"
"]}}");
ASSERT_BSONOBJ_EQ(expected, stages[0].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(unpack, stages[1].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(match, stages[2].getDocument().toBson());
+ ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"foo\", "
+ "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { $and: [ { x: { $lte: 1 } }, { $or: [ { "
+ "\"myMeta.a\": { $gt: 1 } }, { y: { $lt: 1 } } ] } ] } } }"),
+ stages[1].getDocument().toBson());
}
TEST_F(OptimizePipeline, MixedMatchOnlyMetaMatchPushedDown) {
@@ -142,11 +147,13 @@ TEST_F(OptimizePipeline, MixedMatchOnlyMetaMatchPushedDown) {
// We should push down the $match on the metaField but not the predicate on '$a', which is
// ineligible because of the $type.
auto serialized = pipeline->serializeToBson();
- ASSERT_EQ(3u, serialized.size());
+ ASSERT_EQ(2u, serialized.size());
ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [{meta: {$gte: 0}}, {meta: {$lte: 5}}]}}"),
serialized[0]);
- ASSERT_BSONOBJ_EQ(unpack, serialized[1]);
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$type: [ 2 ]}}}"), serialized[2]);
+ ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", "
+ "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { a: { $type: [ 2 ] } } } }"),
+ serialized[1]);
}
TEST_F(OptimizePipeline, MultipleMatchesPushedDown) {
@@ -164,15 +171,17 @@ TEST_F(OptimizePipeline, MultipleMatchesPushedDown) {
// We should push down both the $match on the metaField and the predicates on the control field.
// The created $match stages should be added before $_internalUnpackBucket and merged.
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner);
- ASSERT_EQ(3u, stages.size());
+ ASSERT_EQ(2u, stages.size());
ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [ {meta: {$gte: 0}},"
"{meta: {$lte: 5}},"
"{$or: [ {'control.min.a': {$_internalExprLte: 4}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]},"
"{$type: [ \"$control.max.a\" ]} ]}} ]} ]}}"),
stages[0].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(unpack, stages[1].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), stages[2].getDocument().toBson());
+ ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", "
+ "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { a: { $lte: 4 } } } }"),
+ stages[1].getDocument().toBson());
}
TEST_F(OptimizePipeline, MultipleMatchesPushedDownWithSort) {
@@ -191,16 +200,18 @@ TEST_F(OptimizePipeline, MultipleMatchesPushedDownWithSort) {
// We should push down both the $match on the metaField and the predicates on the control field.
// The created $match stages should be added before $_internalUnpackBucket and merged.
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner);
- ASSERT_EQ(4u, stages.size());
+ ASSERT_EQ(3u, stages.size());
ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [ { meta: { $gte: 0 } },"
"{meta: { $lte: 5 } },"
"{$or: [ { 'control.min.a': { $_internalExprLte: 4 } },"
"{$expr: { $ne: [ {$type: [ \"$control.min.a\" ] },"
"{$type: [ \"$control.max.a\" ] } ] } } ] }]}}"),
stages[0].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(unpack, stages[1].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), stages[2].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(fromjson("{$sort: {sortKey: {a: 1}}}"), stages[3].getDocument().toBson());
+ ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", "
+ "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { a: { $lte: 4 } } } }"),
+ stages[1].getDocument().toBson());
+ ASSERT_BSONOBJ_EQ(fromjson("{$sort: {sortKey: {a: 1}}}"), stages[2].getDocument().toBson());
}
TEST_F(OptimizePipeline, MetaMatchThenCountPushedDown) {
@@ -261,7 +272,7 @@ TEST_F(OptimizePipeline, SortThenMixedMatchPushedDown) {
// We should push down both the $sort and parts of the $match.
auto serialized = pipeline->serializeToBson();
- ASSERT_EQ(4u, serialized.size());
+ ASSERT_EQ(3u, serialized.size());
auto expected = fromjson(
"{$match: {$and: ["
" {meta: {$eq: 'abc'}},"
@@ -274,8 +285,10 @@ TEST_F(OptimizePipeline, SortThenMixedMatchPushedDown) {
"]}}");
ASSERT_BSONOBJ_EQ(expected, serialized[0]);
ASSERT_BSONOBJ_EQ(fromjson("{$sort: {meta: -1}}"), serialized[1]);
- ASSERT_BSONOBJ_EQ(unpack, serialized[2]);
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$gte: 5}}}"), serialized[3]);
+ ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", "
+ "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { a: { $gte: 5 } } } }"),
+ serialized[2]);
}
TEST_F(OptimizePipeline, MetaMatchThenSortPushedDown) {
@@ -331,18 +344,19 @@ TEST_F(OptimizePipeline, MixedMatchThenProjectPushedDown) {
// We can push down part of the $match and use dependency analysis on the end of the pipeline.
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner);
- ASSERT_EQ(4u, stages.size());
+ ASSERT_EQ(3u, stages.size());
ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [{meta: {$eq: 'abc'}},"
"{$or: [ {'control.min.a': { $_internalExprLte: 4 } },"
"{$expr: { $ne: [ {$type: [ \"$control.min.a\" ] },"
"{$type: [ \"$control.max.a\" ] } ] } } ] } ]}}"),
stages[0].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(fromjson("{$_internalUnpackBucket: { include: ['_id', 'a', 'x'], timeField: "
- "'time', metaField: 'myMeta', bucketMaxSpanSeconds: 3600}}"),
- stages[1].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), stages[2].getDocument().toBson());
+ ASSERT_BSONOBJ_EQ(
+ fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"a\", \"x\" ], timeField: "
+ "\"time\", metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { a: { $lte: 4 } } } }"),
+ stages[1].getDocument().toBson());
ASSERT_BSONOBJ_EQ(fromjson("{$project: {_id: true, x: true}}"),
- stages[3].getDocument().toBson());
+ stages[2].getDocument().toBson());
}
@@ -379,21 +393,21 @@ TEST_F(OptimizePipeline, ProjectThenMixedMatchPushedDown) {
// We should push down part of the $match and do dependency analysis on the rest.
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner);
- ASSERT_EQ(4u, stages.size());
+ ASSERT_EQ(3u, stages.size());
ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [{meta: {$eq: \"abc\"}},"
"{$or: [ {'control.min.a': {$_internalExprLte: 4}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ] },"
"{$type: [ \"$control.max.a\" ]} ]}} ]} ]}}"),
stages[0].getDocument().toBson());
ASSERT_BSONOBJ_EQ(
- fromjson("{$_internalUnpackBucket: { include: ['_id', 'a', 'x', 'myMeta'], timeField: "
- "'time', metaField: 'myMeta', bucketMaxSpanSeconds: 3600}}"),
+ fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"a\", \"x\", \"myMeta\" ], "
+ "timeField: \"time\", metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { a: { $lte: 4 } } } }"),
stages[1].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), stages[2].getDocument().toBson());
const UnorderedFieldsBSONObjComparator kComparator;
ASSERT_EQ(
kComparator.compare(fromjson("{$project: {_id: true, a: true, myMeta: true, x: true}}"),
- stages[3].getDocument().toBson()),
+ stages[2].getDocument().toBson()),
0);
}
@@ -410,7 +424,7 @@ TEST_F(OptimizePipeline, ProjectWithRenameThenMixedMatchPushedDown) {
// We should push down part of the $match and do dependency analysis on the end of the pipeline.
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner);
- ASSERT_EQ(4u, stages.size());
+ ASSERT_EQ(3u, stages.size());
ASSERT_BSONOBJ_EQ(
fromjson("{$match: {$and: [{$or: [ {'control.max.y': {$_internalExprGte: \"abc\"}},"
"{$expr: {$ne: [ {$type: [ \"$control.min.y\" ]},"
@@ -419,13 +433,13 @@ TEST_F(OptimizePipeline, ProjectWithRenameThenMixedMatchPushedDown) {
"{$expr: {$ne: [ {$type: [ \"$control.min.a\" ] },"
"{$type: [ \"$control.max.a\" ]} ]}} ]} ]}}"),
stages[0].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(fromjson("{$_internalUnpackBucket: { include: ['_id', 'a', 'y'], timeField: "
- "'time', metaField: 'myMeta', bucketMaxSpanSeconds: 3600}}"),
- stages[1].getDocument().toBson());
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [{y: {$gte: 'abc'}}, {a: {$lte: 4}}]}}"),
- stages[2].getDocument().toBson());
+ ASSERT_BSONOBJ_EQ(
+ fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"a\", \"y\" ], timeField: "
+ "\"time\", metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { $and: [ { y: { $gte: \"abc\" } }, { a: { $lte: 4 } } ] } } }"),
+ stages[1].getDocument().toBson());
ASSERT_BSONOBJ_EQ(fromjson("{$project: {_id: true, a: true, myMeta: '$y'}}"),
- stages[3].getDocument().toBson());
+ stages[2].getDocument().toBson());
}
TEST_F(OptimizePipeline, ComputedProjectThenMetaMatchPushedDown) {
@@ -466,15 +480,15 @@ TEST_F(OptimizePipeline, ComputedProjectThenMetaMatchNotPushedDown) {
// We should both push down the project and internalize the remaining project, but we can't
// push down the meta match due to the (now invalid) renaming.
auto serialized = pipeline->serializeToBson();
- ASSERT_EQ(3u, serialized.size());
+ ASSERT_EQ(2u, serialized.size());
ASSERT_BSONOBJ_EQ(fromjson("{$addFields: {myMeta: {$sum: ['$meta.a', '$meta.b']}}}"),
serialized[0]);
ASSERT_BSONOBJ_EQ(
- fromjson(
- "{$_internalUnpackBucket: { include: ['_id', 'myMeta'], timeField: 'time', metaField: "
- "'myMeta', bucketMaxSpanSeconds: 3600, computedMetaProjFields: ['myMeta']}}"),
+ fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"myMeta\" ], timeField: "
+ "\"time\", metaField: \"myMeta\", "
+ "bucketMaxSpanSeconds: 3600, computedMetaProjFields: [ \"myMeta\" ], "
+ "eventFilter: { myMeta: { $gte: \"abc\" } } } }"),
serialized[1]);
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {myMeta: {$gte: 'abc'}}}"), serialized[2]);
} // namespace
TEST_F(OptimizePipeline, ComputedProjectThenMatchNotPushedDown) {
@@ -491,13 +505,13 @@ TEST_F(OptimizePipeline, ComputedProjectThenMatchNotPushedDown) {
// We should push down the computed project but not the match, because it depends on the newly
// computed values.
auto serialized = pipeline->serializeToBson();
- ASSERT_EQ(3u, serialized.size());
+ ASSERT_EQ(2u, serialized.size());
ASSERT_BSONOBJ_EQ(fromjson("{$addFields: {y: {$sum: ['$meta.a', '$meta.b']}}}"), serialized[0]);
- ASSERT_BSONOBJ_EQ(
- fromjson("{$_internalUnpackBucket: { include: ['_id', 'y'], timeField: 'time', metaField: "
- "'myMeta', bucketMaxSpanSeconds: 3600, computedMetaProjFields: ['y']}}"),
- serialized[1]);
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {y: {$gt: 'abc'}}}"), serialized[2]);
+ ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"y\" ], "
+ "timeField: \"time\", metaField: \"myMeta\", "
+ "bucketMaxSpanSeconds: 3600, computedMetaProjFields: [ \"y\" ], "
+ "eventFilter: { y: { $gt: \"abc\" } } } }"),
+ serialized[1]);
}
TEST_F(OptimizePipeline, MetaSortThenProjectPushedDown) {
@@ -857,7 +871,7 @@ TEST_F(OptimizePipeline, MatchWithGeoWithinOnMeasurementsPushedDownUsingInternal
pipeline->optimizePipeline();
auto serialized = pipeline->serializeToBson();
- ASSERT_EQ(serialized.size(), 3U);
+ ASSERT_EQ(serialized.size(), 2U);
// $match with $geoWithin on a non-metadata field is pushed down and $_internalBucketGeoWithin
// is used.
@@ -866,12 +880,12 @@ TEST_F(OptimizePipeline, MatchWithGeoWithinOnMeasurementsPushedDownUsingInternal
"\"Polygon\" ,coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 "
"] ] ]}},field: \"loc\"}}}"),
serialized[0]);
- ASSERT_BSONOBJ_EQ(fromjson("{$_internalUnpackBucket: {exclude: [], timeField: "
- "'time', bucketMaxSpanSeconds: 3600}}"),
- serialized[1]);
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {loc: {$geoWithin: {$geometry: {type: \"Polygon\", "
- "coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 ] ] ]}}}}}"),
- serialized[2]);
+ ASSERT_BSONOBJ_EQ(
+ fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", "
+ "bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { loc: { $geoWithin: { $geometry: { type: \"Polygon\", coordinates: "
+ "[ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 ] ] ] } } } } } }"),
+ serialized[1]);
}
TEST_F(OptimizePipeline, MatchWithGeoWithinOnMetaFieldIsPushedDown) {
@@ -913,7 +927,7 @@ TEST_F(OptimizePipeline,
pipeline->optimizePipeline();
auto serialized = pipeline->serializeToBson();
- ASSERT_EQ(serialized.size(), 3U);
+ ASSERT_EQ(serialized.size(), 2U);
// $match with $geoIntersects on a non-metadata field is pushed down and
// $_internalBucketGeoWithin is used.
@@ -922,12 +936,12 @@ TEST_F(OptimizePipeline,
"\"Polygon\" ,coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 "
"] ] ]}},field: \"loc\"}}}"),
serialized[0]);
- ASSERT_BSONOBJ_EQ(fromjson("{$_internalUnpackBucket: {exclude: [], timeField: "
- "'time', bucketMaxSpanSeconds: 3600}}"),
- serialized[1]);
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {loc: {$geoIntersects: {$geometry: {type: \"Polygon\", "
- "coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 ] ] ]}}}}}"),
- serialized[2]);
+ ASSERT_BSONOBJ_EQ(
+ fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", "
+ "bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { loc: { $geoIntersects: { $geometry: { type: \"Polygon\", "
+ "coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 ] ] ] } } } } } }"),
+ serialized[1]);
}
TEST_F(OptimizePipeline, MatchWithGeoIntersectsOnMetaFieldIsPushedDown) {
diff --git a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/split_match_on_meta_and_rename_test.cpp b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/split_match_on_meta_and_rename_test.cpp
index ba4f31adf17..4ce5d558ac4 100644
--- a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/split_match_on_meta_and_rename_test.cpp
+++ b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/split_match_on_meta_and_rename_test.cpp
@@ -56,7 +56,7 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, OptimizeSplitsMatchAndMaps
// predicate on 'control.min.a'. These two created $match stages should be added before
// $_internalUnpackBucket and merged.
auto serialized = pipeline->serializeToBson();
- ASSERT_EQ(3u, serialized.size());
+ ASSERT_EQ(2u, serialized.size());
ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: ["
" {meta: {$gte: 0}},"
" {meta: {$lte: 5}},"
@@ -68,8 +68,13 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, OptimizeSplitsMatchAndMaps
" ]}"
"]}}"),
serialized[0]);
- ASSERT_BSONOBJ_EQ(unpack, serialized[1]);
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), serialized[2]);
+ ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { "
+ "exclude: [], "
+ "timeField: \"foo\", "
+ "metaField: \"myMeta\", "
+ "bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { a: { $lte: 4 } } } }"),
+ serialized[1]);
}
TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, OptimizeMovesMetaMatchBeforeUnpack) {
@@ -94,10 +99,6 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename,
auto unpack = fromjson(
"{$_internalUnpackBucket: { exclude: [], timeField: 'foo', metaField: 'myMeta', "
"bucketMaxSpanSeconds: 3600}}");
- auto unpackExcluded = fromjson(
- "{$_internalUnpackBucket: { include: ['_id', 'data'], timeField: 'foo', metaField: "
- "'myMeta', "
- "bucketMaxSpanSeconds: 3600}}");
auto pipeline = Pipeline::parse(makeVector(unpack,
fromjson("{$project: {data: 1}}"),
fromjson("{$match: {myMeta: {$gte: 0}}}")),
@@ -108,9 +109,11 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename,
// The $match on meta is not moved before $_internalUnpackBucket since the field is excluded.
auto serialized = pipeline->serializeToBson();
- ASSERT_EQ(2u, serialized.size());
- ASSERT_BSONOBJ_EQ(unpackExcluded, serialized[0]);
- ASSERT_BSONOBJ_EQ(fromjson("{$match: {myMeta: {$gte: 0}}}"), serialized[1]);
+ ASSERT_EQ(1u, serialized.size());
+ ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"data\" ], "
+ "timeField: \"foo\", metaField: \"myMeta\", bucketMaxSpanSeconds: "
+ "3600, eventFilter: { myMeta: { $gte: 0 } } } }"),
+ serialized[0]);
}
TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename,
@@ -134,7 +137,7 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename,
// We should fail to split the match because of the $or clause. We should still be able to
// map the predicate on 'x' to a predicate on the control field.
auto serialized = pipeline->serializeToBson();
- ASSERT_EQ(3u, serialized.size());
+ ASSERT_EQ(2u, serialized.size());
auto expected = fromjson(
"{$match: {$and: ["
// Result of pushing down {x: {$lte: 1}}.
@@ -154,8 +157,13 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename,
" ]}"
"]}}");
ASSERT_BSONOBJ_EQ(expected, serialized[0]);
- ASSERT_BSONOBJ_EQ(unpack, serialized[1]);
- ASSERT_BSONOBJ_EQ(match, serialized[2]);
+ ASSERT_BSONOBJ_EQ(
+ fromjson(
+ "{ $_internalUnpackBucket: { "
+ "exclude: [], timeField: \"foo\", metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, "
+ "eventFilter: { $and: [ { x: { $lte: 1 } }, { $or: [ { \"myMeta.a\": { $gt: 1 } }, { "
+ "y: { $lt: 1 } } ] } ] } } }"),
+ serialized[1]);
}
} // namespace
} // namespace mongo
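
One behavioral nuance the excluded-meta test above pins down: when the meta field is projected away, the trailing $match on it still cannot move in front of the unpack stage, but it no longer survives as a separate $match either; it is absorbed as the stage's eventFilter, so the optimized pipeline collapses to a single stage. Condensed from the assertion above (predicates unchanged):

    // Input:     unpack -> {$project: {data: 1}} -> {$match: {myMeta: {$gte: 0}}}
    // Optimized: a single stage
    //   {$_internalUnpackBucket: {include: ['_id', 'data'], timeField: 'foo',
    //                             metaField: 'myMeta', bucketMaxSpanSeconds: 3600,
    //                             eventFilter: {myMeta: {$gte: 0}}}}
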
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index e87a39f85ac..080faffe5c3 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -2414,7 +2414,7 @@ intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::createVarFromString(
ExpressionFieldPath::ExpressionFieldPath(ExpressionContext* const expCtx,
const string& theFieldPath,
Variables::Id variable)
- : Expression(expCtx), _fieldPath(theFieldPath), _variable(variable) {
+ : Expression(expCtx), _fieldPath(theFieldPath, true /*precomputeHashes*/), _variable(variable) {
const auto varName = theFieldPath.substr(0, theFieldPath.find('.'));
tassert(5943201,
std::string{"Variable with $$ROOT's id is not $$CURRENT or $$ROOT as expected, "
diff --git a/src/mongo/db/pipeline/field_path.cpp b/src/mongo/db/pipeline/field_path.cpp
index c860e863512..bdbbbabfee0 100644
--- a/src/mongo/db/pipeline/field_path.cpp
+++ b/src/mongo/db/pipeline/field_path.cpp
@@ -69,10 +69,8 @@ string FieldPath::getFullyQualifiedPath(StringData prefix, StringData suffix) {
return str::stream() << prefix << "." << suffix;
}
-FieldPath::FieldPath(std::string inputPath)
- : _fieldPath(std::move(inputPath)),
- _fieldPathDotPosition{string::npos},
- _fieldHash{kHashUninitialized} {
+FieldPath::FieldPath(std::string inputPath, bool precomputeHashes)
+ : _fieldPath(std::move(inputPath)), _fieldPathDotPosition{string::npos} {
uassert(40352, "FieldPath cannot be constructed with empty string", !_fieldPath.empty());
uassert(40353, "FieldPath must not end with a '.'.", _fieldPath[_fieldPath.size() - 1] != '.');
@@ -81,19 +79,21 @@ FieldPath::FieldPath(std::string inputPath)
size_t startPos = 0;
while (string::npos != (dotPos = _fieldPath.find('.', startPos))) {
_fieldPathDotPosition.push_back(dotPos);
- _fieldHash.push_back(kHashUninitialized);
startPos = dotPos + 1;
}
_fieldPathDotPosition.push_back(_fieldPath.size());
- // Validate the path length and the fields.
+ // Validate the path length and the fields, and precompute their hashes if requested.
const auto pathLength = getPathLength();
uassert(ErrorCodes::Overflow,
"FieldPath is too long",
pathLength <= BSONDepth::getMaxAllowableDepth());
+ _fieldHash.reserve(pathLength);
for (size_t i = 0; i < pathLength; ++i) {
- uassertValidFieldName(getFieldName(i));
+ const auto& fieldName = getFieldName(i);
+ uassertValidFieldName(fieldName);
+ _fieldHash.push_back(precomputeHashes ? FieldNameHasher()(fieldName) : kHashUninitialized);
}
}
diff --git a/src/mongo/db/pipeline/field_path.h b/src/mongo/db/pipeline/field_path.h
index d2ee93734e7..c216bdbc0a4 100644
--- a/src/mongo/db/pipeline/field_path.h
+++ b/src/mongo/db/pipeline/field_path.h
@@ -69,9 +69,11 @@ public:
*
* Field names are validated using uassertValidFieldName().
*/
- /* implicit */ FieldPath(std::string inputPath);
- /* implicit */ FieldPath(StringData inputPath) : FieldPath(inputPath.toString()) {}
- /* implicit */ FieldPath(const char* inputPath) : FieldPath(std::string(inputPath)) {}
+ /* implicit */ FieldPath(std::string inputPath, bool precomputeHashes = false);
+ /* implicit */ FieldPath(StringData inputPath, bool precomputeHashes = false)
+ : FieldPath(inputPath.toString(), precomputeHashes) {}
+ /* implicit */ FieldPath(const char* inputPath, bool precomputeHashes = false)
+ : FieldPath(std::string(inputPath), precomputeHashes) {}
/**
* Returns the number of path elements in the field path.
@@ -117,13 +119,8 @@ public:
*/
HashedFieldName getFieldNameHashed(size_t i) const {
dassert(i < getPathLength());
- const auto begin = _fieldPathDotPosition[i] + 1;
- const auto end = _fieldPathDotPosition[i + 1];
- StringData fieldName{&_fieldPath[begin], end - begin};
- if (_fieldHash[i] == kHashUninitialized) {
- _fieldHash[i] = FieldNameHasher()(fieldName);
- }
- return HashedFieldName{fieldName, _fieldHash[i]};
+ invariant(_fieldHash[i] != kHashUninitialized);
+ return HashedFieldName{getFieldName(i), _fieldHash[i]};
}
/**
@@ -177,9 +174,9 @@ private:
// lookup.
std::vector<size_t> _fieldPathDotPosition;
- // Contains the cached hash value for the field. Will initially be set to 'kHashUninitialized',
- // and only generated when it is first retrieved via 'getFieldNameHashed'.
- mutable std::vector<size_t> _fieldHash;
+ // Contains the hash value for the field names if it was requested when creating this path.
+ // Otherwise all elements are set to 'kHashUninitialized'.
+ std::vector<size_t> _fieldHash;
static constexpr std::size_t kHashUninitialized = std::numeric_limits<std::size_t>::max();
};
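
The field_path.h change above trades the old lazy, mutable hash cache for an opt-in precomputation at construction time: hot callers (such as ExpressionFieldPath earlier in this diff) pass precomputeHashes=true, and getFieldNameHashed() then assumes the hashes already exist. Below is a minimal standalone sketch of that idea, not the real mongo::FieldPath API; the class name, the use of std::hash, and the main() driver are illustrative assumptions.

#include <cassert>
#include <cstddef>
#include <functional>
#include <limits>
#include <string>
#include <utility>
#include <vector>

// Standalone sketch: hashes are either computed once at construction or left
// as a sentinel, and the hashed accessor requires precomputation.
class PathSketch {
public:
    explicit PathSketch(std::vector<std::string> fields, bool precomputeHashes = false)
        : _fields(std::move(fields)) {
        _hashes.reserve(_fields.size());
        for (const auto& f : _fields) {
            _hashes.push_back(precomputeHashes ? std::hash<std::string>{}(f) : kUninitialized);
        }
    }

    // Mirrors getFieldNameHashed(): only valid when hashes were precomputed.
    std::pair<std::string, std::size_t> fieldNameHashed(std::size_t i) const {
        assert(_hashes[i] != kUninitialized);
        return {_fields[i], _hashes[i]};
    }

private:
    static constexpr std::size_t kUninitialized = std::numeric_limits<std::size_t>::max();
    std::vector<std::string> _fields;
    std::vector<std::size_t> _hashes;  // no longer needs to be 'mutable'
};

int main() {
    PathSketch path({"a", "b", "c"}, /*precomputeHashes=*/true);
    auto [name, hash] = path.fieldNameHashed(1);
    return (name == "b" && hash == std::hash<std::string>{}("b")) ? 0 : 1;
}
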
diff --git a/src/mongo/db/pipeline/pipeline_d.cpp b/src/mongo/db/pipeline/pipeline_d.cpp
index ac2173ae50c..ec84917b5c5 100644
--- a/src/mongo/db/pipeline/pipeline_d.cpp
+++ b/src/mongo/db/pipeline/pipeline_d.cpp
@@ -107,7 +107,7 @@ using write_ops::InsertCommandRequest;
namespace {
/**
- * Extracts a prefix of 'DocumentSourceGroup' and 'DocumentSourceLookUp' stages from the given
+ * Finds a prefix of 'DocumentSourceGroup' and 'DocumentSourceLookUp' stages from the given
* pipeline to prepare for pushdown of $group and $lookup into the inner query layer so that it
* can be executed using SBE.
* Group stages are extracted from the pipeline when all of the following conditions are met:
@@ -121,11 +121,11 @@ namespace {
* - The $lookup uses only the 'localField'/'foreignField' syntax (no pipelines).
* - The foreign collection is neither sharded nor a view.
*/
-std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleStagesForPushdown(
+std::vector<std::unique_ptr<InnerPipelineStageInterface>> findSbeCompatibleStagesForPushdown(
const intrusive_ptr<ExpressionContext>& expCtx,
const MultipleCollectionAccessor& collections,
const CanonicalQuery* cq,
- Pipeline* pipeline) {
+ const Pipeline* pipeline) {
// We will eventually use the extracted group stages to populate 'CanonicalQuery::pipeline'
// which requires stages to be wrapped in an interface.
std::vector<std::unique_ptr<InnerPipelineStageInterface>> stagesForPushdown;
@@ -140,7 +140,7 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleSt
return {};
}
- auto&& sources = pipeline->getSources();
+ const auto& sources = pipeline->getSources();
bool isMainCollectionSharded = false;
if (const auto& mainColl = collections.getMainCollection()) {
@@ -158,7 +158,7 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleSt
internalQuerySlotBasedExecutionDisableLookupPushdown.load() || isMainCollectionSharded ||
collections.isAnySecondaryNamespaceAViewOrSharded();
- for (auto itr = sources.begin(); itr != sources.end();) {
+ for (auto itr = sources.begin(); itr != sources.end(); ++itr) {
const bool isLastSource = itr->get() == sources.back().get();
// $group pushdown logic.
@@ -170,7 +170,6 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleSt
if (groupStage->sbeCompatible() && !groupStage->doingMerge()) {
stagesForPushdown.push_back(
std::make_unique<InnerPipelineStageImpl>(groupStage, isLastSource));
- sources.erase(itr++);
continue;
}
break;
@@ -187,7 +186,6 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleSt
if (lookupStage->sbeCompatible()) {
stagesForPushdown.push_back(
std::make_unique<InnerPipelineStageImpl>(lookupStage, isLastSource));
- sources.erase(itr++);
continue;
}
break;
@@ -199,6 +197,21 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleSt
return stagesForPushdown;
}
+/**
+ * Removes the first 'stagesToRemove' stages from the pipeline. This function is meant to be paired
+ * with a call to findSbeCompatibleStagesForPushdown() - the caller must first get the stages to
+ * push down, then remove them.
+ */
+void trimPipelineStages(Pipeline* pipeline, size_t stagesToRemove) {
+ auto& sources = pipeline->getSources();
+ tassert(7087104,
+ "stagesToRemove must be <= number of pipeline sources",
+ stagesToRemove <= sources.size());
+ for (size_t i = 0; i < stagesToRemove; ++i) {
+ sources.erase(sources.begin());
+ }
+}
+
StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> attemptToGetExecutor(
const intrusive_ptr<ExpressionContext>& expCtx,
const MultipleCollectionAccessor& collections,
@@ -296,9 +309,15 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> attemptToGetExe
return getExecutorFind(expCtx->opCtx,
collections,
std::move(cq.getValue()),
- [&](auto* canonicalQuery) {
- canonicalQuery->setPipeline(extractSbeCompatibleStagesForPushdown(
- expCtx, collections, canonicalQuery, pipeline));
+ [&](auto* canonicalQuery, bool attachOnly) {
+ if (attachOnly) {
+ canonicalQuery->setPipeline(findSbeCompatibleStagesForPushdown(
+ expCtx, collections, canonicalQuery, pipeline));
+ } else {
+ // Not attaching - we need to trim the already pushed down
+ // pipeline stages from the pipeline.
+ trimPipelineStages(pipeline, canonicalQuery->pipeline().size());
+ }
},
permitYield,
plannerOpts);
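
The pipeline_d.cpp hunks above split what used to be a single destructive pass into two phases: findSbeCompatibleStagesForPushdown() now only inspects the pipeline (note the const Pipeline* and the removed sources.erase() calls), while trimPipelineStages() removes the pushed-down prefix later, when the getExecutorFind callback runs with attachOnly == false. The following is a minimal standalone sketch of that find-then-trim pattern; the stage names and container type are illustrative, not the real DocumentSource machinery.

#include <cstddef>
#include <iostream>
#include <list>
#include <string>
#include <vector>

// Phase 1: inspect only; mirrors findSbeCompatibleStagesForPushdown() taking a
// const pipeline and never erasing from it.
std::vector<std::string> findPushdownPrefix(const std::list<std::string>& sources) {
    std::vector<std::string> pushdown;
    for (const auto& stage : sources) {
        if (stage != "$group" && stage != "$lookup")
            break;  // only a prefix of eligible stages is considered
        pushdown.push_back(stage);
    }
    return pushdown;
}

// Phase 2: mirrors trimPipelineStages(); removal happens only after the
// pushed-down prefix is final.
void trimPrefix(std::list<std::string>& sources, std::size_t stagesToRemove) {
    for (std::size_t i = 0; i < stagesToRemove && !sources.empty(); ++i)
        sources.pop_front();
}

int main() {
    std::list<std::string> pipeline{"$group", "$lookup", "$project"};
    auto pushed = findPushdownPrefix(pipeline);  // callback invoked with attachOnly == true
    trimPrefix(pipeline, pushed.size());         // callback invoked with attachOnly == false
    std::cout << pipeline.size() << " stage(s) left in the pipeline\n";  // prints 1
    return 0;
}
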
diff --git a/src/mongo/db/pipeline/process_interface/common_mongod_process_interface.cpp b/src/mongo/db/pipeline/process_interface/common_mongod_process_interface.cpp
index 0731249fa29..e8e109aee8f 100644
--- a/src/mongo/db/pipeline/process_interface/common_mongod_process_interface.cpp
+++ b/src/mongo/db/pipeline/process_interface/common_mongod_process_interface.cpp
@@ -422,9 +422,8 @@ CommonMongodProcessInterface::attachCursorSourceToPipelineForLocalRead(Pipeline*
}
boost::optional<AutoGetCollectionForReadCommandMaybeLockFree> autoColl;
- const NamespaceStringOrUUID nsOrUUID = expCtx->uuid
- ? NamespaceStringOrUUID{expCtx->ns.db().toString(), *expCtx->uuid}
- : expCtx->ns;
+ const NamespaceStringOrUUID nsOrUUID =
+ expCtx->uuid ? NamespaceStringOrUUID{expCtx->ns.dbName(), *expCtx->uuid} : expCtx->ns;
// Reparse 'pipeline' to discover whether there are secondary namespaces that we need to lock
// when constructing our query executor.
diff --git a/src/mongo/db/prepare_conflict_tracker.cpp b/src/mongo/db/prepare_conflict_tracker.cpp
index e8e031a1405..ac9c1b495fc 100644
--- a/src/mongo/db/prepare_conflict_tracker.cpp
+++ b/src/mongo/db/prepare_conflict_tracker.cpp
@@ -52,7 +52,7 @@ void PrepareConflictTracker::endPrepareConflict(OperationContext* opCtx) {
auto curTick = tickSource->getTicks();
auto curConflictDuration =
- tickSource->spanTo<Microseconds>(_prepareConflictStartTime, curTick);
+ tickSource->ticksTo<Microseconds>(curTick - _prepareConflictStartTime);
_prepareConflictDuration.store(_prepareConflictDuration.load() + curConflictDuration);
_prepareConflictStartTime = 0;
diff --git a/src/mongo/db/query/canonical_query.h b/src/mongo/db/query/canonical_query.h
index e939c248a9d..fe0599a1b6a 100644
--- a/src/mongo/db/query/canonical_query.h
+++ b/src/mongo/db/query/canonical_query.h
@@ -233,6 +233,14 @@ public:
return _sbeCompatible;
}
+ void setUseCqfIfEligible(bool useCqfIfEligible) {
+ _useCqfIfEligible = useCqfIfEligible;
+ }
+
+ bool useCqfIfEligible() const {
+ return _useCqfIfEligible;
+ }
+
bool isParameterized() const {
return !_inputParamIdToExpressionMap.empty();
}
@@ -318,6 +326,16 @@ private:
// True if this query can be executed by the SBE.
bool _sbeCompatible = false;
+ // If true, indicates that we should use CQF if this query is eligible (see the
+ // isEligibleForBonsai() function for eligibility requirements).
+ // If false, indicates that we shouldn't use CQF even if this query is eligible. This is used to
+ // prevent hybrid classic and CQF plans in the following cases:
+ // 1. A pipeline that is not eligible for CQF but has an eligible prefix pushed down to find.
+ // 2. A subpipeline pushed down to find as part of a $lookup or $graphLookup.
+ // The default value of false ensures that only code paths (the find command) that opt in are
+ // able to use CQF.
+ bool _useCqfIfEligible = false;
+
// True if this query must produce a RecordId output in addition to the BSON objects that
// constitute the result set of the query. Any generated query solution must not discard record
// ids, even if the optimizer detects that they are not going to be consumed downstream.
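
The new '_useCqfIfEligible' flag above, together with the extra check added to isEligibleForBonsai() in cqf_command_utils.cpp later in this diff, acts as an explicit opt-in gate for the CQF/Bonsai optimizer. Here is a standalone sketch of that gate; CanonicalQuerySketch and isEligibleForBonsaiSketch are simplified stand-ins, not the real mongo classes.

// Standalone sketch of the CQF opt-in gate.
struct CanonicalQuerySketch {
    bool useCqfIfEligible = false;  // defaults to false: only opted-in paths reach CQF
};

// Mirrors the early return added to isEligibleForBonsai(); all other
// eligibility checks are omitted here.
bool isEligibleForBonsaiSketch(const CanonicalQuerySketch& cq, bool commandOptionsEligible) {
    if (!commandOptionsEligible || !cq.useCqfIfEligible) {
        return false;
    }
    return true;
}

int main() {
    CanonicalQuerySketch fromFindCommand;
    fromFindCommand.useCqfIfEligible = true;     // the find command opts in
    CanonicalQuerySketch fromLookupSubpipeline;  // pushed-down subpipelines keep the default
    bool ok = isEligibleForBonsaiSketch(fromFindCommand, true) &&
        !isEligibleForBonsaiSketch(fromLookupSubpipeline, true);
    return ok ? 0 : 1;
}
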
diff --git a/src/mongo/db/query/ce/ce_heuristic_test.cpp b/src/mongo/db/query/ce/ce_heuristic_test.cpp
index d1549a88daa..1d5cec6b8fc 100644
--- a/src/mongo/db/query/ce/ce_heuristic_test.cpp
+++ b/src/mongo/db/query/ce/ce_heuristic_test.cpp
@@ -31,7 +31,6 @@
#include "mongo/db/query/ce/ce_test_utils.h"
#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
-#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
#include "mongo/db/query/optimizer/cascades/logical_props_derivation.h"
#include "mongo/db/query/optimizer/cascades/memo.h"
#include "mongo/db/query/optimizer/defs.h"
diff --git a/src/mongo/db/query/ce/ce_test_utils.cpp b/src/mongo/db/query/ce/ce_test_utils.cpp
index 349baadd549..32d881e308b 100644
--- a/src/mongo/db/query/ce/ce_test_utils.cpp
+++ b/src/mongo/db/query/ce/ce_test_utils.cpp
@@ -32,12 +32,12 @@
#include "mongo/db/query/ce/ce_test_utils.h"
#include "mongo/db/pipeline/abt/utils.h"
-#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
#include "mongo/db/query/optimizer/explain.h"
#include "mongo/db/query/optimizer/metadata_factory.h"
#include "mongo/db/query/optimizer/opt_phase_manager.h"
#include "mongo/db/query/optimizer/rewrites/const_eval.h"
#include "mongo/db/query/optimizer/utils/unit_test_pipeline_utils.h"
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
#include "mongo/db/query/sbe_stage_builder_helpers.h"
#include "mongo/unittest/unittest.h"
@@ -81,7 +81,7 @@ optimizer::CEType CETester::getCE(ABT& abt, std::function<bool(const ABT&)> node
false /*requireRID*/,
_metadata,
getCETransport(),
- std::make_unique<DefaultCosting>(),
+ makeCosting(),
defaultConvertPathToInterval,
ConstEval::constFold,
DebugInfo::kDefaultForTests,
diff --git a/src/mongo/db/query/count_command.idl b/src/mongo/db/query/count_command.idl
index d511cd879fb..0ddb5887b0d 100644
--- a/src/mongo/db/query/count_command.idl
+++ b/src/mongo/db/query/count_command.idl
@@ -130,3 +130,8 @@ commands:
description: "Indicates whether the operation is a mirrored read"
type: optionalBool
stability: unstable
+ sampleId:
+ description: "The unique sample id for the operation if it has been chosen for sampling."
+ type: uuid
+ optional: true
+ stability: unstable
diff --git a/src/mongo/db/query/cqf_command_utils.cpp b/src/mongo/db/query/cqf_command_utils.cpp
index 8863a22adde..98e3ec34a13 100644
--- a/src/mongo/db/query/cqf_command_utils.cpp
+++ b/src/mongo/db/query/cqf_command_utils.cpp
@@ -737,7 +737,7 @@ bool isEligibleForBonsai(const CanonicalQuery& cq,
!request.getReadOnce() && !request.getShowRecordId() && !request.getTerm();
// Early return to avoid unnecessary work of walking the input expression.
- if (!commandOptionsEligible) {
+ if (!commandOptionsEligible || !cq.useCqfIfEligible()) {
return false;
}
diff --git a/src/mongo/db/query/cqf_get_executor.cpp b/src/mongo/db/query/cqf_get_executor.cpp
index 3508553b591..cc1c42bc3af 100644
--- a/src/mongo/db/query/cqf_get_executor.cpp
+++ b/src/mongo/db/query/cqf_get_executor.cpp
@@ -56,9 +56,12 @@
#include "mongo/db/query/yield_policy_callbacks_impl.h"
#include "mongo/logv2/log.h"
#include "mongo/logv2/log_attr.h"
+#include "mongo/util/fail_point.h"
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery
+MONGO_FAIL_POINT_DEFINE(failConstructingBonsaiExecutor);
+
namespace mongo {
using namespace optimizer;
using cost_model::CostEstimator;
@@ -638,6 +641,9 @@ std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> getSBEExecutorViaCascadesOp
const boost::optional<BSONObj>& indexHint,
std::unique_ptr<Pipeline, PipelineDeleter> pipeline,
std::unique_ptr<CanonicalQuery> canonicalQuery) {
+ if (MONGO_unlikely(failConstructingBonsaiExecutor.shouldFail())) {
+ uasserted(620340, "attempting to use CQF while it is disabled");
+ }
// Ensure that either pipeline or canonicalQuery is set.
tassert(624070,
"getSBEExecutorViaCascadesOptimizer expects exactly one of the following to be set: "
diff --git a/src/mongo/db/query/distinct_command.idl b/src/mongo/db/query/distinct_command.idl
index a632ef06ffa..7dbb2624125 100644
--- a/src/mongo/db/query/distinct_command.idl
+++ b/src/mongo/db/query/distinct_command.idl
@@ -45,7 +45,7 @@ commands:
description: "The field path for which to return distinct values."
type: string
query:
- description: "Optional query that filters the documents from which to retrieve the
+ description: "Optional query that filters the documents from which to retrieve the
distinct values."
type: object
optional: true
@@ -56,3 +56,7 @@ commands:
mirrored:
description: "Indicates whether the operation is a mirrored read"
type: optionalBool
+ sampleId:
+ description: "The unique sample id for the operation if it has been chosen for sampling."
+ type: uuid
+ optional: true
diff --git a/src/mongo/db/query/find_command.idl b/src/mongo/db/query/find_command.idl
index 44e7b02ceb9..c7a157dc810 100644
--- a/src/mongo/db/query/find_command.idl
+++ b/src/mongo/db/query/find_command.idl
@@ -243,4 +243,9 @@ commands:
description: "Indicates whether the operation is a mirrored read"
type: optionalBool
stability: unstable
+ sampleId:
+ description: "The unique sample id for the operation if it has been chosen for sampling."
+ type: uuid
+ optional: true
+ stability: unstable
diff --git a/src/mongo/db/query/fle/encrypted_predicate.h b/src/mongo/db/query/fle/encrypted_predicate.h
index 5fcac966dbf..fc78aa28014 100644
--- a/src/mongo/db/query/fle/encrypted_predicate.h
+++ b/src/mongo/db/query/fle/encrypted_predicate.h
@@ -40,6 +40,7 @@
#include "mongo/db/matcher/expression_leaf.h"
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/query/fle/query_rewriter_interface.h"
+#include "mongo/stdx/unordered_map.h"
/**
* This file contains an abstract class that describes rewrites on agg Expressions and
@@ -194,33 +195,37 @@ private:
* are keyed on the dynamic type for the Expression subclass.
*/
-using ExpressionToRewriteMap = stdx::unordered_map<
- std::type_index,
- std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>>;
+using ExpressionRewriteFunction =
+ std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>;
+using ExpressionToRewriteMap =
+ stdx::unordered_map<std::type_index, std::vector<ExpressionRewriteFunction>>;
extern ExpressionToRewriteMap aggPredicateRewriteMap;
-using MatchTypeToRewriteMap = stdx::unordered_map<
- MatchExpression::MatchType,
- std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>>;
+using MatchRewriteFunction =
+ std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>;
+using MatchTypeToRewriteMap =
+ stdx::unordered_map<MatchExpression::MatchType, std::vector<MatchRewriteFunction>>;
extern MatchTypeToRewriteMap matchPredicateRewriteMap;
/**
* Register an agg rewrite if a condition is true at startup time.
*/
-#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \
- MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className)(InitializerContext*) { \
- \
- invariant(aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()); \
- aggPredicateRewriteMap[typeid(className)] = [&](auto* rewriter, auto* expr) { \
- if (isEnabledExpr) { \
- return rewriteClass{rewriter}.rewrite(expr); \
- } else { \
- return std::unique_ptr<Expression>(nullptr); \
- } \
- }; \
- }
+#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \
+ MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className##_##rewriteClass) \
+ (InitializerContext*) { \
+ if (aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()) { \
+ aggPredicateRewriteMap[typeid(className)] = std::vector<ExpressionRewriteFunction>(); \
+ } \
+ aggPredicateRewriteMap[typeid(className)].push_back([](auto* rewriter, auto* expr) { \
+ if (isEnabledExpr) { \
+ return rewriteClass{rewriter}.rewrite(expr); \
+ } else { \
+ return std::unique_ptr<Expression>(nullptr); \
+ } \
+ }); \
+ };
/**
* Register an agg rewrite unconditionally.
@@ -239,17 +244,21 @@ extern MatchTypeToRewriteMap matchPredicateRewriteMap;
* Register a MatchExpression rewrite if a condition is true at startup time.
*/
#define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, isEnabledExpr) \
- MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType)(InitializerContext*) { \
- \
- invariant(matchPredicateRewriteMap.find(MatchExpression::matchType) == \
- matchPredicateRewriteMap.end()); \
- matchPredicateRewriteMap[MatchExpression::matchType] = [&](auto* rewriter, auto* expr) { \
- if (isEnabledExpr) { \
- return rewriteClass{rewriter}.rewrite(expr); \
- } else { \
- return std::unique_ptr<MatchExpression>(nullptr); \
- } \
- }; \
+ MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType##_##rewriteClass) \
+ (InitializerContext*) { \
+ if (matchPredicateRewriteMap.find(MatchExpression::matchType) == \
+ matchPredicateRewriteMap.end()) { \
+ matchPredicateRewriteMap[MatchExpression::matchType] = \
+ std::vector<MatchRewriteFunction>(); \
+ } \
+ matchPredicateRewriteMap[MatchExpression::matchType].push_back( \
+ [](auto* rewriter, auto* expr) { \
+ if (isEnabledExpr) { \
+ return rewriteClass{rewriter}.rewrite(expr); \
+ } else { \
+ return std::unique_ptr<MatchExpression>(nullptr); \
+ } \
+ }); \
};
/**
* Register a MatchExpression rewrite unconditionally.
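
The registry change above (and the matching loops in query_rewriter.cpp just below) move from a single rewrite per expression or match type to a vector of rewrites per key, tried in registration order until one returns a non-null result. Here is a minimal standalone sketch of that "first non-null rewrite wins" lookup; the string-based node and key types are simplifications, not the real Expression/MatchExpression hierarchy.

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

using Node = std::string;
using Rewrite = std::function<std::unique_ptr<Node>(const Node&)>;
using RewriteMap = std::unordered_map<std::string, std::vector<Rewrite>>;

// Try every rewrite registered for 'key'; only one rewrite may apply, so stop
// at the first one that returns a non-null result.
std::unique_ptr<Node> applyFirstMatching(const RewriteMap& map,
                                         const std::string& key,
                                         const Node& node) {
    auto it = map.find(key);
    if (it == map.end())
        return nullptr;
    for (const auto& rewrite : it->second) {
        if (auto result = rewrite(node))
            return result;
    }
    return nullptr;
}

int main() {
    RewriteMap map;
    map["EQ"].push_back([](const Node& n) -> std::unique_ptr<Node> {
        return n == "encrypt" ? std::make_unique<Node>("$gt") : nullptr;
    });
    map["EQ"].push_back([](const Node& n) -> std::unique_ptr<Node> {
        return n == "foo" ? std::make_unique<Node>("$lt") : nullptr;
    });
    auto a = applyFirstMatching(map, "EQ", "foo");    // handled by the second rewrite
    auto b = applyFirstMatching(map, "EQ", "plain");  // no rewrite applies
    return (a && *a == "$lt" && !b) ? 0 : 1;
}
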
diff --git a/src/mongo/db/query/fle/query_rewriter.cpp b/src/mongo/db/query/fle/query_rewriter.cpp
index 441f436ec00..da4fb05ecc7 100644
--- a/src/mongo/db/query/fle/query_rewriter.cpp
+++ b/src/mongo/db/query/fle/query_rewriter.cpp
@@ -40,12 +40,15 @@ public:
: queryRewriter(queryRewriter), exprRewrites(exprRewrites){};
std::unique_ptr<Expression> postVisit(Expression* exp) {
- if (auto rewrite = exprRewrites.find(typeid(*exp)); rewrite != exprRewrites.end()) {
- auto expr = rewrite->second(queryRewriter, exp);
- if (expr != nullptr) {
- didRewrite = true;
+ if (auto rewriteEntry = exprRewrites.find(typeid(*exp));
+ rewriteEntry != exprRewrites.end()) {
+ for (auto& rewrite : rewriteEntry->second) {
+ auto expr = rewrite(queryRewriter, exp);
+ if (expr != nullptr) {
+ didRewrite = true;
+ return expr;
+ }
}
- return expr;
}
return nullptr;
}
@@ -109,13 +112,17 @@ std::unique_ptr<MatchExpression> QueryRewriter::_rewrite(MatchExpression* expr)
return nullptr;
}
default: {
- if (auto rewrite = _matchRewrites.find(expr->matchType());
- rewrite != _matchRewrites.end()) {
- auto rewritten = rewrite->second(this, expr);
- if (rewritten != nullptr) {
- _rewroteLastExpression = true;
+ if (auto rewriteEntry = _matchRewrites.find(expr->matchType());
+ rewriteEntry != _matchRewrites.end()) {
+ for (auto& rewrite : rewriteEntry->second) {
+ auto rewritten = rewrite(this, expr);
+ // Only one rewrite can be applied to an expression, so return as soon as a
+ // rewrite returns something other than nullptr.
+ if (rewritten != nullptr) {
+ _rewroteLastExpression = true;
+ return rewritten;
+ }
}
- return rewritten;
}
return nullptr;
}
diff --git a/src/mongo/db/query/fle/query_rewriter_test.cpp b/src/mongo/db/query/fle/query_rewriter_test.cpp
index d779f6be1cc..057bae0dd2b 100644
--- a/src/mongo/db/query/fle/query_rewriter_test.cpp
+++ b/src/mongo/db/query/fle/query_rewriter_test.cpp
@@ -71,7 +71,7 @@ protected:
if (!elt.isABSONObj()) {
return false;
}
- return elt.Obj().firstElementFieldNameStringData() == "encrypt"_sd;
+ return elt.Obj().hasField("encrypt"_sd);
}
bool isPayload(const Value& v) const override {
if (!v.isObject()) {
@@ -97,9 +97,110 @@ protected:
};
std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override {
+ auto eqMatch = dynamic_cast<ExpressionCompare*>(expr);
+ invariant(eqMatch);
+ // Only operate over equality comparisons.
+ if (eqMatch->getOp() != ExpressionCompare::EQ) {
+ return nullptr;
+ }
+ auto payload = dynamic_cast<ExpressionConstant*>(eqMatch->getOperandList()[1].get());
+ // If the comparison doesn't hold a constant, then don't rewrite.
+ if (!payload) {
+ return nullptr;
+ }
+
+ // If the constant is not considered a payload, then don't rewrite.
+ if (!isPayload(payload->getValue())) {
+ return nullptr;
+ }
+ auto cmp = std::make_unique<ExpressionCompare>(eqMatch->getExpressionContext(),
+ ExpressionCompare::GT);
+ cmp->addOperand(eqMatch->getOperandList()[0]);
+ cmp->addOperand(
+ ExpressionConstant::create(eqMatch->getExpressionContext(),
+ payload->getValue().getDocument().getField("encrypt")));
+ return cmp;
+ }
+
+ std::unique_ptr<MatchExpression> rewriteToRuntimeComparison(
+ MatchExpression* expr) const override {
+ return nullptr;
+ }
+
+ std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override {
return nullptr;
}
+private:
+ // This method is not used in mock implementations of the EncryptedPredicate since isPayload(),
+ // which normally calls encryptedBinDataType(), is overridden to look for plain objects rather
+ // than BinData. Since this method is pure virtual on the superclass and needs to be
+ // implemented, it is set to kPlaceholder (0).
+ EncryptedBinDataType encryptedBinDataType() const override {
+ return EncryptedBinDataType::kPlaceholder;
+ }
+};
+
+// A second mock rewrite which turns comparisons against documents with the key "foo" into $lt
+// operations. We need two different rewrites registered on the same operator to verify that all
+// of the rewrites are iterated through.
+class OtherMockPredicateRewriter : public fle::EncryptedPredicate {
+public:
+ OtherMockPredicateRewriter(const fle::QueryRewriterInterface* rewriter)
+ : EncryptedPredicate(rewriter) {}
+
+protected:
+ bool isPayload(const BSONElement& elt) const override {
+ if (!elt.isABSONObj()) {
+ return false;
+ }
+ return elt.Obj().hasField("foo"_sd);
+ }
+ bool isPayload(const Value& v) const override {
+ if (!v.isObject()) {
+ return false;
+ }
+ return !v.getDocument().getField("foo").missing();
+ }
+
+ std::vector<PrfBlock> generateTags(fle::BSONValue payload) const override {
+ return {};
+ };
+
+ // Encrypted values will be rewritten from $eq to $lt. This is an arbitrary decision just to
+ // make sure that the rewrite works properly.
+ std::unique_ptr<MatchExpression> rewriteToTagDisjunction(MatchExpression* expr) const override {
+ invariant(expr->matchType() == MatchExpression::EQ);
+ auto eqMatch = static_cast<EqualityMatchExpression*>(expr);
+ if (!isPayload(eqMatch->getData())) {
+ return nullptr;
+ }
+ return std::make_unique<LTMatchExpression>(eqMatch->path(),
+ eqMatch->getData().Obj().firstElement());
+ };
+
+ std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override {
+ auto eqMatch = dynamic_cast<ExpressionCompare*>(expr);
+ invariant(eqMatch);
+ if (eqMatch->getOp() != ExpressionCompare::EQ) {
+ return nullptr;
+ }
+ auto payload = dynamic_cast<ExpressionConstant*>(eqMatch->getOperandList()[1].get());
+ if (!payload) {
+ return nullptr;
+ }
+
+ if (!isPayload(payload->getValue())) {
+ return nullptr;
+ }
+ auto cmp = std::make_unique<ExpressionCompare>(eqMatch->getExpressionContext(),
+ ExpressionCompare::LT);
+ cmp->addOperand(eqMatch->getOperandList()[0]);
+ cmp->addOperand(ExpressionConstant::create(
+ eqMatch->getExpressionContext(), payload->getValue().getDocument().getField("foo")));
+ return cmp;
+ }
+
std::unique_ptr<MatchExpression> rewriteToRuntimeComparison(
MatchExpression* expr) const override {
return nullptr;
@@ -111,7 +212,7 @@ protected:
private:
EncryptedBinDataType encryptedBinDataType() const override {
- return EncryptedBinDataType::kPlaceholder; // return the 0 type. this isn't used anywhere.
+ return EncryptedBinDataType::kPlaceholder;
}
};
@@ -119,8 +220,17 @@ void setMockRewriteMaps(fle::MatchTypeToRewriteMap& match,
fle::ExpressionToRewriteMap& agg,
fle::TagMap& tags,
std::set<StringData>& encryptedFields) {
- match[MatchExpression::EQ] = [&](auto* rewriter, auto* expr) {
- return MockPredicateRewriter{rewriter}.rewrite(expr);
+ match[MatchExpression::EQ] = {
+ [&](auto* rewriter, auto* expr) { return MockPredicateRewriter{rewriter}.rewrite(expr); },
+ [&](auto* rewriter, auto* expr) {
+ return OtherMockPredicateRewriter{rewriter}.rewrite(expr);
+ },
+ };
+ agg[typeid(ExpressionCompare)] = {
+ [&](auto* rewriter, auto* expr) { return MockPredicateRewriter{rewriter}.rewrite(expr); },
+ [&](auto* rewriter, auto* expr) {
+ return OtherMockPredicateRewriter{rewriter}.rewrite(expr);
+ },
};
}
@@ -137,9 +247,17 @@ public:
return res ? res.value() : obj;
}
+ BSONObj rewriteAggExpressionForTest(const BSONObj& obj) {
+ auto expr = Expression::parseExpression(&_expCtx, obj, _expCtx.variablesParseState);
+ auto result = rewriteExpression(expr.get());
+ return result ? result->serialize(false).getDocument().toBson()
+ : expr->serialize(false).getDocument().toBson();
+ }
+
private:
fle::TagMap _tags;
std::set<StringData> _encryptedFields;
+ ExpressionContextForTest _expCtx;
};
class FLEServerRewriteTest : public unittest::Test {
@@ -167,22 +285,53 @@ protected:
ASSERT_MATCH_EXPRESSION_REWRITE(input, expected); \
}
+#define ASSERT_AGG_EXPRESSION_REWRITE(input, expected) \
+ auto actual = _mock->rewriteAggExpressionForTest(fromjson(input)); \
+ ASSERT_BSONOBJ_EQ(actual, fromjson(expected));
+
+#define TEST_FLE_REWRITE_AGG(name, input, expected) \
+ TEST_F(FLEServerRewriteTest, name##_AggExpression) { \
+ ASSERT_AGG_EXPRESSION_REWRITE(input, expected); \
+ }
+
TEST_FLE_REWRITE_MATCH(TopLevel_DottedPath,
"{'user.ssn': {$eq: {encrypt: 2}}}",
"{'user.ssn': {$gt: 2}}");
+TEST_FLE_REWRITE_AGG(TopLevel_DottedPath,
+ "{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}",
+ "{$gt: ['$user.ssn', {$const: 2}]}");
+
TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_BothEncrypted,
"{$and: [{ssn: {encrypt: 2}}, {age: {encrypt: 4}}]}",
"{$and: [{ssn: {$gt: 2}}, {age: {$gt: 4}}]}");
+TEST_FLE_REWRITE_AGG(
+ TopLevel_Conjunction_BothEncrypted,
+ "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {encrypt: "
+ "4}}]}]}",
+ "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$gt: ['$age', {$const: 4}]}]}");
+
TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncrypted,
"{$and: [{ssn: {encrypt: 2}}, {age: {plain: 4}}]}",
"{$and: [{ssn: {$gt: 2}}, {age: {$eq: {plain: 4}}}]}");
+TEST_FLE_REWRITE_AGG(
+ TopLevel_Conjunction_PartlyEncrypted,
+ "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {plain: 4}}]}]}",
+ "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$eq: ['$age', {$const: {plain: 4}}]}]}");
+
TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncryptedWithUnregisteredOperator,
"{$and: [{ssn: {encrypt: 2}}, {age: {$lt: {encrypt: 4}}}]}",
"{$and: [{ssn: {$gt: 2}}, {age: {$lt: {encrypt: 4}}}]}");
+TEST_FLE_REWRITE_AGG(
+ TopLevel_Conjunction_PartlyEncryptedWithUnregisteredOperator,
+ "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$lt: ['$age', {$const: {encrypt: "
+ "4}}]}]}",
+ "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$lt: ['$age', {$const: {encrypt: "
+ "4}}]}]}");
+
TEST_FLE_REWRITE_MATCH(TopLevel_Encrypted_Nested_Unecrypted,
"{$and: [{ssn: {encrypt: 2}}, {user: {region: 'US'}}]}",
"{$and: [{ssn: {$gt: 2}}, {user: {$eq: {region: 'US'}}}]}");
@@ -191,6 +340,10 @@ TEST_FLE_REWRITE_MATCH(TopLevel_Not,
"{ssn: {$not: {$eq: {encrypt: 5}}}}",
"{ssn: {$not: {$gt: 5}}}");
+TEST_FLE_REWRITE_AGG(TopLevel_Not,
+ "{$not: [{$eq: ['$ssn', {$const: {encrypt: 2}}]}]}",
+ "{$not: [{$gt: ['$ssn', {$const: 2}]}]}")
+
TEST_FLE_REWRITE_MATCH(TopLevel_Neq, "{ssn: {$ne: {encrypt: 5}}}", "{ssn: {$not: {$gt: 5}}}}");
TEST_FLE_REWRITE_MATCH(
@@ -198,6 +351,12 @@ TEST_FLE_REWRITE_MATCH(
"{$and: [{$and: [{ssn: {encrypt: 2}}, {other: 'field'}]}, {otherSsn: {encrypt: 3}}]}",
"{$and: [{$and: [{ssn: {$gt: 2}}, {other: {$eq: 'field'}}]}, {otherSsn: {$gt: 3}}]}");
+TEST_FLE_REWRITE_AGG(NestedConjunction,
+ "{$and: [{$and: [{$eq: ['$ssn', {$const: {encrypt: 2}}]},{$eq: ['$other', "
+ "'field']}]},{$eq: ['$age',{$const: {encrypt: 4}}]}]}",
+ "{$and: [{$and: [{$gt: ['$ssn', {$const: 2}]},{$eq: ['$other', "
+ "{$const: 'field'}]}]},{$gt: ['$age',{$const: 4}]}]}");
+
TEST_FLE_REWRITE_MATCH(TopLevel_Nor,
"{$nor: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}",
"{$nor: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}");
@@ -205,5 +364,30 @@ TEST_FLE_REWRITE_MATCH(TopLevel_Nor,
TEST_FLE_REWRITE_MATCH(TopLevel_Or,
"{$or: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}",
"{$or: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}");
+
+TEST_FLE_REWRITE_AGG(
+ TopLevel_Or,
+ "{$or: [{$eq: ['$ssn', {$const: {encrypt: 2}}]}, {$eq: ['$ssn', {$const: {encrypt: 4}}]}]}",
+ "{$or: [{$gt: ['$ssn', {$const: 2}]}, {$gt: ['$ssn', {$const: 4}]}]}")
+
+
+// Test that the rewriter will work from any rewrite registered to an expression. The test rewriter
+// has two rewrites registered on $eq.
+
+TEST_FLE_REWRITE_MATCH(OtherRewrite_Basic, "{'ssn': {$eq: {foo: 2}}}", "{'ssn': {$lt: 2}}");
+
+TEST_FLE_REWRITE_AGG(OtherRewrite_Basic,
+ "{$eq: ['$user.ssn', {$const: {foo: 2}}]}",
+ "{$lt: ['$user.ssn', {$const: 2}]}");
+
+TEST_FLE_REWRITE_MATCH(OtherRewrite_Conjunction_BothEncrypted,
+ "{$and: [{ssn: {encrypt: 2}}, {age: {foo: 4}}]}",
+ "{$and: [{ssn: {$gt: 2}}, {age: {$lt: 4}}]}");
+
+TEST_FLE_REWRITE_AGG(
+ OtherRewrite_Conjunction_BothEncrypted,
+ "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {foo: "
+ "4}}]}]}",
+ "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$lt: ['$age', {$const: 4}]}]}");
} // namespace
} // namespace mongo
diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp
index 8c083b07dc1..dcd71777782 100644
--- a/src/mongo/db/query/fle/range_predicate.cpp
+++ b/src/mongo/db/query/fle/range_predicate.cpp
@@ -55,6 +55,77 @@ REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween,
RangePredicate,
gFeatureFlagFLE2Range);
+REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionCompare,
+ RangePredicate,
+ gFeatureFlagFLE2Range);
+
+namespace {
+// Validate the range operator passed in and return the fieldpath and payload for the rewrite. If
+// the passed-in expression is a comparison with $eq, $ne, or $cmp, none of which represent a range
+// predicate, then return null to the caller so that the rewrite can return null.
+std::pair<boost::intrusive_ptr<Expression>, Value> validateRangeOp(Expression* expr) {
+ auto children = [&]() {
+ if (auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr)) {
+ return betweenExpr->getChildren();
+ } else {
+ auto cmpExpr = dynamic_cast<ExpressionCompare*>(expr);
+ tassert(6720901,
+ "Range rewrite should only be called with $between or comparison operator.",
+ cmpExpr);
+ switch (cmpExpr->getOp()) {
+ case ExpressionCompare::GT:
+ case ExpressionCompare::GTE:
+ case ExpressionCompare::LT:
+ case ExpressionCompare::LTE:
+ return cmpExpr->getChildren();
+
+ case ExpressionCompare::EQ:
+ case ExpressionCompare::NE:
+ case ExpressionCompare::CMP:
+ return std::vector<boost::intrusive_ptr<Expression>>();
+ }
+ }
+ return std::vector<boost::intrusive_ptr<Expression>>();
+ }();
+ if (children.empty()) {
+ return {nullptr, Value()};
+ }
+ // Both ExpressionBetween and ExpressionCompare have a fixed arity of 2.
+ auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get());
+ uassert(6720903, "first argument should be a fieldpath", fieldpath);
+ auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get());
+ uassert(6720904, "second argument should be a constant", secondArg);
+ auto payload = secondArg->getValue();
+ return {children[0], payload};
+}
+} // namespace
+
+std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
+ StringData path, ParsedFindRangePayload payload) const {
+ auto* expCtx = _rewriter->getExpressionContext();
+ return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState),
+ payload);
+}
+
+std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
+ boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const {
+ tassert(7030501,
+ "$internalFleBetween can only be generated from a non-stub payload.",
+ !payload.isStub());
+ auto cm = payload.maxCounter;
+ ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken);
+ std::vector<ConstDataRange> edcTokens;
+ std::transform(std::make_move_iterator(payload.edges.value().begin()),
+ std::make_move_iterator(payload.edges.value().end()),
+ std::back_inserter(edcTokens),
+ [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); });
+
+ auto* expCtx = _rewriter->getExpressionContext();
+ return std::make_unique<ExpressionInternalFLEBetween>(
+ expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens));
+}
+
std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const {
auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload);
std::vector<PrfBlock> tags;
@@ -99,56 +170,23 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction(
return makeTagDisjunction(toBSONArray(generateTags(payload)));
}
-std::pair<boost::intrusive_ptr<Expression>, Value> validateBetween(Expression* expr) {
- auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr);
- tassert(6720901, "Range rewrite should only be called with $between operator.", betweenExpr);
- auto children = betweenExpr->getChildren();
- uassert(6720902, "$between should have two children.", children.size() == 2);
-
- auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get());
- uassert(6720903, "first argument should be a fieldpath", fieldpath);
- auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get());
- uassert(6720904, "second argument should be a constant", secondArg);
- auto payload = secondArg->getValue();
- return {children[0], payload};
-}
-
std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const {
- auto [_, payload] = validateBetween(expr);
+ auto [fieldpath, payload] = validateRangeOp(expr);
+ if (!fieldpath) {
+ return nullptr;
+ }
if (!isPayload(payload)) {
return nullptr;
}
+ if (isStub(std::ref(payload))) {
+ return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true));
+ }
+
auto tags = toValues(generateTags(std::ref(payload)));
return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags));
}
-std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
- StringData path, ParsedFindRangePayload payload) const {
- auto* expCtx = _rewriter->getExpressionContext();
- return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString(
- expCtx, path.toString(), expCtx->variablesParseState),
- payload);
-}
-
-std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
- boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const {
- tassert(7030501,
- "$internalFleBetween can only be generated from a non-stub payload.",
- !payload.isStub());
- auto cm = payload.maxCounter;
- ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken);
- std::vector<ConstDataRange> edcTokens;
- std::transform(std::make_move_iterator(payload.edges.value().begin()),
- std::make_move_iterator(payload.edges.value().end()),
- std::back_inserter(edcTokens),
- [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); });
-
- auto* expCtx = _rewriter->getExpressionContext();
- return std::make_unique<ExpressionInternalFLEBetween>(
- expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens));
-}
-
std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison(
MatchExpression* expr) const {
BSONElement ffp;
@@ -179,11 +217,17 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison(
}
std::unique_ptr<Expression> RangePredicate::rewriteToRuntimeComparison(Expression* expr) const {
- auto [fieldpath, ffp] = validateBetween(expr);
+ auto [fieldpath, ffp] = validateRangeOp(expr);
+ if (!fieldpath) {
+ return nullptr;
+ }
if (!isPayload(ffp)) {
return nullptr;
}
auto payload = parseFindPayload<ParsedFindRangePayload>(ffp);
+ if (payload.isStub()) {
+ return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true));
+ }
return fleBetweenFromPayload(fieldpath, payload);
}
} // namespace mongo::fle
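
validateRangeOp() above accepts both $between and the range-style $gt/$gte/$lt/$lte comparisons, and signals "not a range predicate" for $eq/$ne/$cmp by returning a null fieldpath so the rewrite bails out (the UnsupportedComparisonOps test further down exercises this). A standalone sketch of that operator gate follows; the CmpOp enum here is illustrative, not the real ExpressionCompare::CmpOp.

#include <optional>

enum class CmpOp { EQ, NE, CMP, GT, GTE, LT, LTE };

// Range-style comparisons proceed; $eq/$ne/$cmp fall through, and the caller
// treats the empty result as "return nullptr from the rewrite".
std::optional<CmpOp> rangeOpOrNone(CmpOp op) {
    switch (op) {
        case CmpOp::GT:
        case CmpOp::GTE:
        case CmpOp::LT:
        case CmpOp::LTE:
            return op;
        case CmpOp::EQ:
        case CmpOp::NE:
        case CmpOp::CMP:
            return std::nullopt;
    }
    return std::nullopt;
}

int main() {
    return (rangeOpOrNone(CmpOp::GTE) && !rangeOpOrNone(CmpOp::EQ)) ? 0 : 1;
}
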
diff --git a/src/mongo/db/query/fle/range_predicate.h b/src/mongo/db/query/fle/range_predicate.h
index bc2a47fa173..5fbd7484ad0 100644
--- a/src/mongo/db/query/fle/range_predicate.h
+++ b/src/mongo/db/query/fle/range_predicate.h
@@ -55,6 +55,12 @@ protected:
return parsedPayload.isStub();
}
+ virtual bool isStub(Value elt) const {
+ auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(elt);
+ return parsedPayload.isStub();
+ }
+
+
private:
EncryptedBinDataType encryptedBinDataType() const override {
return EncryptedBinDataType::kFLE2FindRangePayload;
diff --git a/src/mongo/db/query/fle/range_predicate_test.cpp b/src/mongo/db/query/fle/range_predicate_test.cpp
index f81723da5eb..61110c19041 100644
--- a/src/mongo/db/query/fle/range_predicate_test.cpp
+++ b/src/mongo/db/query/fle/range_predicate_test.cpp
@@ -68,7 +68,11 @@ protected:
return isStubPayload;
}
- std::vector<PrfBlock> generateTags(BSONValue payload) const {
+ bool isStub(Value elt) const override {
+ return isStubPayload;
+ }
+
+ std::vector<PrfBlock> generateTags(BSONValue payload) const override {
return stdx::visit(
OverloadedVisitor{[&](BSONElement p) {
if (p.isABSONObj()) {
@@ -126,8 +130,6 @@ TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_NoStub) {
TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_Stub) {
RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
- std::vector<PrfBlock> allTags = {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}};
-
auto expCtx = make_intrusive<ExpressionContextForTest>();
std::vector<StringData> operators = {"$between", "$gt", "$gte", "$lte", "$lt"};
@@ -159,29 +161,98 @@ TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_Stub) {
}
}
+TEST_F(RangePredicateRewriteTest, AggRangeRewrite_Stub) {
+ RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
+
+ {
+ auto input = fromjson(str::stream() << "{$between: [\"$age\", {$literal: [1, 2, 3]}]}");
+ auto inputExpr =
+ ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = ExpressionConstant::create(&_expCtx, Value(true));
+
+ _predicate.isStubPayload = true;
+ auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT(actual);
+ ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
+ expected->serialize(false).getDocument().toBson());
+ }
+
+ auto ops = {"$gt", "$lt", "$gte", "$lte"};
+ for (auto& op : ops) {
+ auto input = fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}");
+ auto inputExpr =
+ ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = ExpressionConstant::create(&_expCtx, Value(true));
+
+ _predicate.isStubPayload = true;
+ auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT(actual);
+ ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
+ expected->serialize(false).getDocument().toBson());
+ }
+}
+
TEST_F(RangePredicateRewriteTest, AggRangeRewrite) {
- auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})");
- auto inputExpr =
- ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+ {
+ auto op = "$between";
+ auto input = fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}");
+ auto inputExpr =
+ ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}}));
- auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}}));
+ auto actual = _predicate.rewrite(inputExpr.get());
- auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
+ expected->serialize(false).getDocument().toBson());
+ }
+ {
+ auto ops = {"$gt", "$lt", "$gte", "$lte"};
+ for (auto& op : ops) {
+ auto input =
+ fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}");
+ auto inputExpr =
+ ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}}));
- ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
- expected->serialize(false).getDocument().toBson());
+ auto actual = _predicate.rewrite(inputExpr.get());
+
+ ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
+ expected->serialize(false).getDocument().toBson());
+ }
+ }
}
TEST_F(RangePredicateRewriteTest, AggRangeRewriteNoOp) {
- auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})");
- auto inputExpr =
- ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+ {
+ auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})");
+ auto inputExpr =
+ ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
- auto expected = inputExpr;
+ auto expected = inputExpr;
- _predicate.payloadValid = false;
- auto actual = _predicate.rewrite(inputExpr.get());
- ASSERT(actual == nullptr);
+ _predicate.payloadValid = false;
+ auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT(actual == nullptr);
+ }
+ {
+ auto ops = {"$gt", "$lt", "$gte", "$lte"};
+ for (auto& op : ops) {
+ auto input =
+ fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}");
+ auto inputExpr =
+ ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = inputExpr;
+
+ _predicate.payloadValid = false;
+ auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT(actual == nullptr);
+ }
+ }
}
BSONObj generateFFP(StringData path, int lb, int ub, int min, int max) {
@@ -219,6 +290,17 @@ std::unique_ptr<Expression> generateBetweenWithFFP(ExpressionContext* expCtx,
return std::make_unique<ExpressionBetween>(expCtx, std::move(children));
}
+std::unique_ptr<Expression> generateBetweenWithFFP(
+ ExpressionContext* expCtx, ExpressionCompare::CmpOp op, StringData path, int lb, int ub) {
+ auto ffp = Value(generateFFP(path, lb, ub, 0, 255).firstElement());
+ auto ffpExpr = make_intrusive<ExpressionConstant>(expCtx, ffp);
+ auto fieldpath = ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState);
+ // Contains the hash values for the field names, if requested when creating this path.
+ std::move(ffpExpr)};
+ return std::make_unique<ExpressionCompare>(expCtx, op, std::move(children));
+}
+
TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) {
_mock.setForceEncryptedCollScanForTest();
auto expected = fromjson(R"({
@@ -273,9 +355,6 @@ TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) {
TEST_F(RangePredicateRewriteTest, CollScanRewriteAgg) {
_mock.setForceEncryptedCollScanForTest();
- auto input = generateBetweenWithFFP(&_expCtx, "age", 23, 35);
- auto result = _predicate.rewrite(input.get());
- ASSERT(result);
auto expected = fromjson(R"({
"$_internalFleBetween": {
"field": "$age",
@@ -309,7 +388,40 @@ TEST_F(RangePredicateRewriteTest, CollScanRewriteAgg) {
}
}
})");
- ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected);
+ {
+ auto input = generateBetweenWithFFP(&_expCtx, "age", 23, 35);
+ auto result = _predicate.rewrite(input.get());
+ ASSERT(result);
+ ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected);
+ }
+ {
+ auto ops = {ExpressionCompare::GT,
+ ExpressionCompare::GTE,
+ ExpressionCompare::LT,
+ ExpressionCompare::LTE};
+ for (auto& op : ops) {
+ auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35);
+ auto result = _predicate.rewrite(input.get());
+ ASSERT(result);
+ ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected);
+ }
+ }
+}
+
+
+TEST_F(RangePredicateRewriteTest, UnsupportedComparisonOps) {
+ auto ops = {ExpressionCompare::CMP, ExpressionCompare::EQ, ExpressionCompare::NE};
+ for (auto& op : ops) {
+ auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35);
+ auto result = _predicate.rewrite(input.get());
+ ASSERT(result == nullptr);
+ }
+ _mock.setForceEncryptedCollScanForTest();
+ for (auto& op : ops) {
+ auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35);
+ auto result = _predicate.rewrite(input.get());
+ ASSERT(result == nullptr);
+ }
}
}; // namespace
diff --git a/src/mongo/db/query/get_executor.cpp b/src/mongo/db/query/get_executor.cpp
index 0824723b4d6..9e5d55a9d97 100644
--- a/src/mongo/db/query/get_executor.cpp
+++ b/src/mongo/db/query/get_executor.cpp
@@ -530,7 +530,9 @@ private:
* - A vector of PlanStages, representing the roots of the constructed execution trees (in the
* case when the query has multiple solutions, we may construct an execution tree for each
* solution and pick the best plan after multi-planning). Elements of this vector can never be
- * null. The size of this vector must always match the size of 'querySolutions' vector.
+ * null. This vector will either be empty or have the same size as the 'querySolutions' vector. It
+ * will be empty in circumstances where we only construct query solutions and delay building the
+ * execution trees, which is any time we are not using a cached plan.
* - A root node of the extension plan. The plan can be combined with a solution to create a
* larger plan after the winning solution is found. Can be null, meaning "no extension".
* - An optional decisionWorks value, which is populated when a solution was reconstructed from
@@ -546,9 +548,11 @@ public:
using PlanStageVector =
std::vector<std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData>>;
- void emplace(std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData> root,
- std::unique_ptr<QuerySolution> solution) {
- _roots.push_back(std::move(root));
+ void emplace(std::unique_ptr<QuerySolution> solution) {
+ // Only allow solutions to be added; execution trees will be generated later.
+ tassert(7087100,
+ "expected execution trees to be generated after query solutions",
+ _roots.empty());
_solutions.push_back(std::move(solution));
}
@@ -560,11 +564,11 @@ public:
std::string getPlanSummary() const {
// We can report plan summary only if this result contains a single solution.
- invariant(_roots.size() == 1);
- invariant(_solutions.size() == 1);
- invariant(_roots[0].first);
+ tassert(7087101, "expected exactly one solution", _solutions.size() == 1);
+ tassert(7087102, "expected at most one execution tree", _roots.size() <= 1);
+ // We only need the query solution to build the explain summary.
auto explainer = plan_explainer_factory::make(
- _roots[0].first.get(), &_roots[0].second, _solutions[0].get());
+ nullptr /* root */, nullptr /* data */, _solutions[0].get());
return explainer->getPlanSummary();
}
@@ -572,6 +576,14 @@ public:
return std::make_pair(std::move(_roots), std::move(_solutions));
}
+ const QuerySolutionVector& solutions() const {
+ return _solutions;
+ }
+
+ const PlanStageVector& roots() const {
+ return _roots;
+ }
+
boost::optional<size_t> decisionWorks() const {
return _decisionWorks;
}
@@ -621,8 +633,14 @@ private:
* normal stage builder process.
* * We have a QuerySolutionNode tree (or multiple query solution trees), but must execute some
* custom logic in order to build the final execution tree.
+ *
+ * When the 'DeferExecutionTreeGeneration' template parameter is true, the helper bypasses the
+ * final step of building the execution tree and returns only the query solution(s).
*/
-template <typename KeyType, typename PlanStageType, typename ResultType>
+template <typename KeyType,
+ typename PlanStageType,
+ typename ResultType,
+ bool DeferExecutionTreeGeneration>
class PrepareExecutionHelper {
public:
PrepareExecutionHelper(OperationContext* opCtx,
@@ -650,11 +668,8 @@ public:
auto solution = std::make_unique<QuerySolution>();
solution->setRoot(std::make_unique<EofNode>());
-
- auto root = buildExecutableTree(*solution);
-
auto result = makeResult();
- result->emplace(std::move(root), std::move(solution));
+ addSolutionToResult(result.get(), std::move(solution));
return std::move(result);
}
@@ -679,12 +694,6 @@ public:
}
auto planCacheKey = buildPlanCacheKey();
- // Fill in some opDebug information, unless it has already been filled by an outer pipeline.
- OpDebug& opDebug = CurOp::get(_opCtx)->debug();
- if (!opDebug.queryHash) {
- opDebug.queryHash = planCacheKey.queryHash();
- }
-
if (auto result = buildCachedPlan(planCacheKey)) {
return {std::move(result)};
}
@@ -715,8 +724,7 @@ public:
for (size_t i = 0; i < solutions.size(); ++i) {
if (turnIxscanIntoCount(solutions[i].get())) {
auto result = makeResult();
- auto root = buildExecutableTree(*solutions[i]);
- result->emplace(std::move(root), std::move(solutions[i]));
+ addSolutionToResult(result.get(), std::move(solutions[i]));
LOGV2_DEBUG(20925,
2,
@@ -731,9 +739,8 @@ public:
if (1 == solutions.size()) {
// Only one possible plan. Build the stages from the solution.
auto result = makeResult();
- auto root = buildExecutableTree(*solutions[0]);
solutions[0]->indexFilterApplied = _plannerParams.indexFiltersApplied;
- result->emplace(std::move(root), std::move(solutions[0]));
+ addSolutionToResult(result.get(), std::move(solutions[0]));
LOGV2_DEBUG(20926,
2,
@@ -748,6 +755,8 @@ public:
}
protected:
+ static constexpr bool ShouldDeferExecutionTreeGeneration = DeferExecutionTreeGeneration;
+
/**
* Creates a result instance to be returned to the caller holding the result of the
* prepare() call.
@@ -757,6 +766,19 @@ protected:
}
/**
+ * Adds the query solution to the result object, additionally building the corresponding
+ * execution tree unless 'DeferExecutionTreeGeneration' is turned on.
+ */
+ void addSolutionToResult(ResultType* result, std::unique_ptr<QuerySolution> solution) {
+ if constexpr (!DeferExecutionTreeGeneration) {
+ auto root = buildExecutableTree(*solution);
+ result->emplace(std::move(root), std::move(solution));
+ } else {
+ result->emplace(std::move(solution));
+ }
+ }
+
+ /**
* Fills out planner parameters if not already filled.
*/
void initializePlannerParamsIfNeeded() {
@@ -820,7 +842,8 @@ protected:
class ClassicPrepareExecutionHelper final
: public PrepareExecutionHelper<PlanCacheKey,
std::unique_ptr<PlanStage>,
- ClassicPrepareExecutionResult> {
+ ClassicPrepareExecutionResult,
+ false /* DeferExecutionTreeGeneration */> {
public:
ClassicPrepareExecutionHelper(OperationContext* opCtx,
const CollectionPtr& collection,
@@ -923,6 +946,12 @@ protected:
const PlanCacheKey& planCacheKey) final {
initializePlannerParamsIfNeeded();
+ // Fill in some opDebug information, unless it has already been filled by an outer pipeline.
+ OpDebug& opDebug = CurOp::get(_opCtx)->debug();
+ if (!opDebug.queryHash) {
+ opDebug.queryHash = planCacheKey.queryHash();
+ }
+
// Before consulting the plan cache, check if we should short-circuit and construct a
// find-by-_id plan.
std::unique_ptr<ClassicPrepareExecutionResult> result = buildIdHackPlan();
@@ -1017,7 +1046,8 @@ class SlotBasedPrepareExecutionHelper final
: public PrepareExecutionHelper<
sbe::PlanCacheKey,
std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData>,
- SlotBasedPrepareExecutionResult> {
+ SlotBasedPrepareExecutionResult,
+ true /* DeferExecutionTreeGeneration */> {
public:
using PrepareExecutionHelper::PrepareExecutionHelper;
@@ -1038,6 +1068,9 @@ public:
std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData> buildExecutableTree(
const QuerySolution& solution) const final {
+ tassert(7087105,
+ "template indicates execution tree generation deferred",
+ !ShouldDeferExecutionTreeGeneration);
return stage_builder::buildSlotBasedExecutableTree(
_opCtx, _collections, *_cq, solution, _yieldPolicy);
}
@@ -1053,11 +1086,6 @@ protected:
if (!feature_flags::gFeatureFlagSbeFull.isEnabledAndIgnoreFCV()) {
return buildCachedPlanFromClassicCache();
} else {
- OpDebug& opDebug = CurOp::get(_opCtx)->debug();
- if (!opDebug.planCacheKey) {
- opDebug.planCacheKey = planCacheKey.planCacheKeyHash();
- }
-
auto&& planCache = sbe::getPlanCache(_opCtx);
auto cacheEntry = planCache.getCacheEntryIfActive(planCacheKey);
if (!cacheEntry) {
@@ -1089,10 +1117,6 @@ protected:
std::unique_ptr<SlotBasedPrepareExecutionResult> buildCachedPlanFromClassicCache() {
const auto& mainColl = getMainCollection();
auto planCacheKey = plan_cache_key_factory::make<PlanCacheKey>(*_cq, mainColl);
- OpDebug& opDebug = CurOp::get(_opCtx)->debug();
- if (!opDebug.planCacheKey) {
- opDebug.planCacheKey = planCacheKey.planCacheKeyHash();
- }
// Try to look up a cached solution for the query.
if (auto cs = CollectionQueryInfo::get(mainColl).getPlanCache()->getCacheEntryIfActive(
planCacheKey)) {
@@ -1109,9 +1133,7 @@ protected:
}
auto result = makeResult();
- auto&& execTree = buildExecutableTree(*querySolution);
-
- result->emplace(std::move(execTree), std::move(querySolution));
+ addSolutionToResult(result.get(), std::move(querySolution));
result->setDecisionWorks(cs->decisionWorks);
result->setRecoveredFromPlanCache(true);
return result;
@@ -1134,8 +1156,7 @@ protected:
for (size_t ix = 0; ix < solutions.size(); ++ix) {
solutions[ix]->indexFilterApplied = _plannerParams.indexFiltersApplied;
- auto execTree = buildExecutableTree(*solutions[ix]);
- result->emplace(std::move(execTree), std::move(solutions[ix]));
+ addSolutionToResult(result.get(), std::move(solutions[ix]));
}
return result;
}
@@ -1256,29 +1277,40 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getSlotBasedExe
OperationContext* opCtx,
const MultipleCollectionAccessor& collections,
std::unique_ptr<CanonicalQuery> cq,
- PlanYieldPolicy::YieldPolicy requestedYieldPolicy,
- const QueryPlannerParams& plannerParams) {
- // Mark that this query uses the SBE engine, unless this has already been set.
+ std::unique_ptr<PlanYieldPolicySBE> yieldPolicy,
+ const QueryPlannerParams& plannerParams,
+ std::unique_ptr<SlotBasedPrepareExecutionResult> planningResult) {
+ // Now that we know what executor we are going to use, fill in some opDebug information, unless
+ // it has already been filled by an outer pipeline.
OpDebug& opDebug = CurOp::get(opCtx)->debug();
if (!opDebug.classicEngineUsed) {
opDebug.classicEngineUsed = false;
}
-
- const auto mainColl = &collections.getMainCollection();
+ if (collections.getMainCollection()) {
+ auto planCacheKey = plan_cache_key_factory::make(*cq, collections);
+ if (!opDebug.queryHash) {
+ opDebug.queryHash = planCacheKey.queryHash();
+ }
+ if (!opDebug.planCacheKey && shouldCacheQuery(*cq)) {
+ opDebug.planCacheKey = planCacheKey.planCacheKeyHash();
+ }
+ }
// Analyze the provided query and build the list of candidate plans for it.
auto nss = cq->nss();
- auto yieldPolicy = makeSbeYieldPolicy(opCtx, requestedYieldPolicy, mainColl, nss);
- SlotBasedPrepareExecutionHelper helper{
- opCtx, collections, cq.get(), yieldPolicy.get(), plannerParams.options};
- auto planningResultWithStatus = helper.prepare();
- if (!planningResultWithStatus.isOK()) {
- return planningResultWithStatus.getStatus();
- }
- auto&& planningResult = planningResultWithStatus.getValue();
auto&& [roots, solutions] = planningResult->extractResultData();
+ invariant(roots.empty() || roots.size() == solutions.size());
+ if (roots.empty()) {
+ // We might have execution trees already if we pulled the plan from the cache. If not, we
+ // need to generate one for each solution.
+ for (const auto& solution : solutions) {
+ roots.emplace_back(stage_builder::buildSlotBasedExecutableTree(
+ opCtx, collections, *cq, *solution, yieldPolicy.get()));
+ }
+ }
+
// When query requires sub-planning, we may not get any executable plans.
const auto planStageData = roots.empty()
? boost::none
@@ -1318,7 +1350,8 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getSlotBasedExe
*cq,
std::move(solutions[0]),
fillOutSecondaryCollectionsInformation(opCtx, collections, cq.get()));
- roots[0] = helper.buildExecutableTree(*(solutions[0]));
+ roots[0] = stage_builder::buildSlotBasedExecutableTree(
+ opCtx, collections, *cq, *(solutions[0]), yieldPolicy.get());
}
plan_cache_util::updatePlanCache(opCtx, collections, *cq, *solutions[0], *root, data);
@@ -1338,13 +1371,99 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getSlotBasedExe
std::move(yieldPolicy),
planningResult->isRecoveredFromPlanCache());
}
+
+/**
+ * Checks if the result of query planning is SBE compatible.
+ */
+bool isPlanSbeCompatible(const SlotBasedPrepareExecutionResult& planningResult) {
+ const auto& solutions = planningResult.solutions();
+ if (solutions.empty()) {
+ // Query needs subplanning: its plans are generated later, so we don't have access to them yet.
+ invariant(planningResult.needsSubplanning());
+ }
+
+ // For now, any query solution that reaches this point is considered SBE compatible.
+
+ return true;
+}
+
+/**
+ * Attempts to create a slot-based executor for the query, if the query plan is eligible for SBE
+ * execution. This function has three possible return values:
+ *
+ * 1. A plan executor. This is the expected result in the majority of cases: the query is SBE
+ * eligible and no errors are encountered while creating the executor.
+ * 2. A non-OK status, returned when an error is encountered during executor creation.
+ * 3. The canonical query. Ownership of the 'canonicalQuery' argument is returned to the caller
+ * when the query plan is not eligible for SBE execution and no error has occurred.
+ */
+StatusWith<stdx::variant<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>,
+ std::unique_ptr<CanonicalQuery>>>
+attemptToGetSlotBasedExecutor(
+ OperationContext* opCtx,
+ const MultipleCollectionAccessor& collections,
+ std::unique_ptr<CanonicalQuery> canonicalQuery,
+ std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages,
+ PlanYieldPolicy::YieldPolicy yieldPolicy,
+ const QueryPlannerParams& plannerParams) {
+ if (extractAndAttachPipelineStages) {
+ // Push down applicable pipeline stages and attach them to the query, but don't remove them
+ // from the high-level pipeline object until we know for sure we will execute with SBE.
+ extractAndAttachPipelineStages(canonicalQuery.get(), true /* attachOnly */);
+ }
+
+ // Use SBE if we find any $group/$lookup stages eligible for execution in SBE or if SBE
+ // is fully enabled. Otherwise, fall back to the classic engine.
+ if (canonicalQuery->pipeline().empty() &&
+ !feature_flags::gFeatureFlagSbeFull.isEnabledAndIgnoreFCV()) {
+ canonicalQuery->setSbeCompatible(false);
+ } else {
+ // Construct the SBE query solution - this will be our final decision stage to determine
+ // whether SBE should be used.
+ auto sbeYieldPolicy = makeSbeYieldPolicy(
+ opCtx, yieldPolicy, &collections.getMainCollection(), canonicalQuery->nss());
+
+ SlotBasedPrepareExecutionHelper helper{
+ opCtx, collections, canonicalQuery.get(), sbeYieldPolicy.get(), plannerParams.options};
+ auto planningResultWithStatus = helper.prepare();
+ if (planningResultWithStatus.isOK() &&
+ isPlanSbeCompatible(*planningResultWithStatus.getValue())) {
+ if (extractAndAttachPipelineStages) {
+ // We know now that we will use SBE, so we need to remove the pushed-down stages
+ // from the original pipeline object.
+ extractAndAttachPipelineStages(canonicalQuery.get(), false /* attachOnly */);
+ }
+ auto statusWithExecutor =
+ getSlotBasedExecutor(opCtx,
+ collections,
+ std::move(canonicalQuery),
+ std::move(sbeYieldPolicy),
+ plannerParams,
+ std::move(planningResultWithStatus.getValue()));
+ if (statusWithExecutor.isOK()) {
+ return std::move(statusWithExecutor.getValue());
+ } else {
+ return statusWithExecutor.getStatus();
+ }
+ }
+
+ // Query plan was not SBE compatible - reset any fields that may have been modified, and
+ // fall back to the classic engine.
+ canonicalQuery->setSbeCompatible(false);
+ canonicalQuery->setPipeline({});
+ }
+
+ // Return ownership of the canonical query to the caller.
+ return std::move(canonicalQuery);
+}
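The ownership hand-back in attemptToGetSlotBasedExecutor() follows a common variant-return pattern; a minimal self-contained sketch of that pattern is shown below, using toy Executor/Query types and std::variant in place of the real PlanExecutor, CanonicalQuery, and stdx::variant.

    #include <iostream>
    #include <memory>
    #include <string>
    #include <variant>

    // Toy stand-ins: an executor and the query it was built from.
    struct Executor { std::string plan; };
    struct Query { std::string text; };

    // Either hand back a ready executor, or return ownership of the query so the
    // caller can fall back to another engine without re-parsing it.
    std::variant<std::unique_ptr<Executor>, std::unique_ptr<Query>> tryBuildFastExecutor(
        std::unique_ptr<Query> query, bool eligible) {
        if (eligible) {
            return std::make_unique<Executor>(Executor{"fast:" + query->text});
        }
        return query;  // not eligible: give the query back, this is not an error
    }

    int main() {
        auto result = tryBuildFastExecutor(std::make_unique<Query>(Query{"find {}"}), false);
        if (std::holds_alternative<std::unique_ptr<Executor>>(result)) {
            std::cout << "using " << std::get<std::unique_ptr<Executor>>(result)->plan << "\n";
        } else {
            // Fallback path: reclaim the query and hand it to the other engine.
            auto query = std::move(std::get<std::unique_ptr<Query>>(result));
            std::cout << "falling back with query: " << query->text << "\n";
        }
        return 0;
    }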
+
} // namespace
StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor(
OperationContext* opCtx,
const MultipleCollectionAccessor& collections,
std::unique_ptr<CanonicalQuery> canonicalQuery,
- std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages,
+ std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages,
PlanYieldPolicy::YieldPolicy yieldPolicy,
const QueryPlannerParams& plannerParams) {
invariant(canonicalQuery);
@@ -1358,17 +1477,27 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor(
// Use SBE if 'canonicalQuery' is SBE compatible.
if (!canonicalQuery->getForceClassicEngine() && canonicalQuery->isSbeCompatible()) {
- if (extractAndAttachPipelineStages) {
- extractAndAttachPipelineStages(canonicalQuery.get());
- }
- // Use SBE if we find any $group/$lookup stages eligible for execution in SBE or if SBE
- // is fully enabled. Otherwise, fallback to the classic engine.
- if (canonicalQuery->pipeline().empty() &&
- !feature_flags::gFeatureFlagSbeFull.isEnabledAndIgnoreFCV()) {
- canonicalQuery->setSbeCompatible(false);
+ auto statusWithExecutor = attemptToGetSlotBasedExecutor(opCtx,
+ collections,
+ std::move(canonicalQuery),
+ extractAndAttachPipelineStages,
+ yieldPolicy,
+ plannerParams);
+ if (!statusWithExecutor.isOK()) {
+ return statusWithExecutor.getStatus();
+ }
+ auto& maybeExecutor = statusWithExecutor.getValue();
+ if (stdx::holds_alternative<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>>(
+ maybeExecutor)) {
+ return std::move(
+ stdx::get<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>>(maybeExecutor));
} else {
- return getSlotBasedExecutor(
- opCtx, collections, std::move(canonicalQuery), yieldPolicy, plannerParams);
+ // The query is not eligible for SBE execution - reclaim the canonical query and fall
+ // back to the classic engine.
+ tassert(7087103,
+ "return value must contain canonical query if not executor",
+ stdx::holds_alternative<std::unique_ptr<CanonicalQuery>>(maybeExecutor));
+ canonicalQuery = std::move(stdx::get<std::unique_ptr<CanonicalQuery>>(maybeExecutor));
}
}
@@ -1380,7 +1509,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor(
OperationContext* opCtx,
const CollectionPtr* collection,
std::unique_ptr<CanonicalQuery> canonicalQuery,
- std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages,
+ std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages,
PlanYieldPolicy::YieldPolicy yieldPolicy,
size_t plannerOptions) {
MultipleCollectionAccessor multi{collection};
@@ -1400,7 +1529,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorFind
OperationContext* opCtx,
const MultipleCollectionAccessor& collections,
std::unique_ptr<CanonicalQuery> canonicalQuery,
- std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages,
+ std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages,
bool permitYield,
QueryPlannerParams plannerParams) {
auto yieldPolicy = (permitYield && !opCtx->inMultiDocumentTransaction())
@@ -1423,7 +1552,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorFind
OperationContext* opCtx,
const CollectionPtr* coll,
std::unique_ptr<CanonicalQuery> canonicalQuery,
- std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages,
+ std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages,
bool permitYield,
size_t plannerOptions) {
MultipleCollectionAccessor multi{*coll};
@@ -1612,6 +1741,10 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorDele
// This is the regular path for when we have a CanonicalQuery.
std::unique_ptr<CanonicalQuery> cq(parsedDelete->releaseParsedQuery());
+ uassert(ErrorCodes::InternalErrorNotSupported,
+ "delete command is not eligible for bonsai",
+ !isEligibleForBonsai(*cq, opCtx, collection));
+
// Transfer the explain verbosity level into the expression context.
cq->getExpCtx()->explain = verbosity;
@@ -1804,6 +1937,10 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorUpda
// This is the regular path for when we have a CanonicalQuery.
std::unique_ptr<CanonicalQuery> cq(parsedUpdate->releaseParsedQuery());
+ uassert(ErrorCodes::InternalErrorNotSupported,
+ "update command is not eligible for bonsai",
+ !isEligibleForBonsai(*cq, opCtx, collection));
+
std::unique_ptr<projection_ast::Projection> projection;
if (!request->getProj().isEmpty()) {
invariant(request->shouldReturnAnyDocs());
@@ -2129,6 +2266,10 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorCoun
const auto skip = request.getSkip().value_or(0);
const auto limit = request.getLimit().value_or(0);
+ uassert(ErrorCodes::InternalErrorNotSupported,
+ "count command is not eligible for bonsai",
+ !isEligibleForBonsai(*cq, opCtx, collection));
+
if (!collection) {
// Treat collections that do not exist as empty collections. Note that the explain reporting
// machinery always assumes that the root stage for a count operation is a CountStage, so in
@@ -2609,6 +2750,11 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorDist
? PlanYieldPolicy::YieldPolicy::INTERRUPT_ONLY
: PlanYieldPolicy::YieldPolicy::YIELD_AUTO;
+ // Assert that the query is not eligible for bonsai.
+ uassert(ErrorCodes::InternalErrorNotSupported,
+ "distinct command is not eligible for bonsai",
+ !isEligibleForBonsai(*parsedDistinct->getQuery(), opCtx, collection));
+
if (!collection) {
// Treat collections that do not exist as empty collections.
return plan_executor_factory::make(parsedDistinct->releaseQuery(),
diff --git a/src/mongo/db/query/get_executor.h b/src/mongo/db/query/get_executor.h
index 75ab5475f2f..a1e24952b76 100644
--- a/src/mongo/db/query/get_executor.h
+++ b/src/mongo/db/query/get_executor.h
@@ -155,7 +155,9 @@ bool shouldWaitForOplogVisibility(OperationContext* opCtx,
* If the caller provides a 'extractAndAttachPipelineStages' function and the query is eligible for
* pushdown into the find layer this function will be invoked to extract pipeline stages and
* attach them to the provided 'CanonicalQuery'. This function should capture the Pipeline that
- * stages should be extracted from.
+ * stages should be extracted from. If the boolean 'attachOnly' argument is true, it will only find
+ * and attach the applicable stages to the query. If it is false, it will remove the extracted
+ * stages from the pipeline.
*
* Note that the first overload takes a 'MultipleCollectionAccessor' and can construct a
* PlanExecutor over multiple collections, while the second overload takes a single 'CollectionPtr'
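A minimal sketch of a callback that honors the new 'attachOnly' flag described above; the Pipeline/Query types and the eligibility logic here are simplified assumptions, not the real pipeline push-down machinery.

    #include <functional>
    #include <iostream>
    #include <string>
    #include <vector>

    // Toy stand-ins for the query and the higher-level pipeline.
    struct Query { std::vector<std::string> pushedDownStages; };
    struct Pipeline { std::vector<std::string> stages{"$group", "$sort"}; };

    int main() {
        Pipeline pipeline;

        // In the spirit of 'extractAndAttachPipelineStages': on the first call
        // (attachOnly == true) it only copies eligible stages onto the query; on the
        // second call (attachOnly == false) it actually removes them from the pipeline.
        std::function<void(Query*, bool)> extractAndAttach =
            [&pipeline](Query* q, bool attachOnly) {
                if (attachOnly) {
                    q->pushedDownStages = pipeline.stages;  // attach; pipeline left untouched
                } else {
                    pipeline.stages.clear();  // commit the pushdown: remove from the pipeline
                }
            };

        Query query;
        extractAndAttach(&query, true);   // tentative: we may still fall back to classic
        extractAndAttach(&query, false);  // decision made: SBE will run the pushed-down stages

        std::cout << "pipeline now has " << pipeline.stages.size() << " stages\n";
        return 0;
    }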
@@ -165,7 +167,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor(
OperationContext* opCtx,
const MultipleCollectionAccessor& collections,
std::unique_ptr<CanonicalQuery> canonicalQuery,
- std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages,
+ std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages,
PlanYieldPolicy::YieldPolicy yieldPolicy,
const QueryPlannerParams& plannerOptions);
@@ -173,7 +175,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor(
OperationContext* opCtx,
const CollectionPtr* collection,
std::unique_ptr<CanonicalQuery> canonicalQuery,
- std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages,
+ std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages,
PlanYieldPolicy::YieldPolicy yieldPolicy,
size_t plannerOptions = 0);
@@ -190,7 +192,9 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor(
* If the caller provides a 'extractAndAttachPipelineStages' function and the query is eligible for
* pushdown into the find layer this function will be invoked to extract pipeline stages and
* attach them to the provided 'CanonicalQuery'. This function should capture the Pipeline that
- * stages should be extracted from.
+ * stages should be extracted from. If the boolean 'attachOnly' argument is true, it will only find
+ * and attach the applicable stages to the query. If it is false, it will remove the extracted
+ * stages from the pipeline.
*
* Note that the first overload takes a 'MultipleCollectionAccessor' and can construct a
* PlanExecutor over multiple collections, while the second overload takes a single
@@ -200,7 +204,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorFind
OperationContext* opCtx,
const MultipleCollectionAccessor& collections,
std::unique_ptr<CanonicalQuery> canonicalQuery,
- std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages,
+ std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages,
bool permitYield = false,
QueryPlannerParams plannerOptions = QueryPlannerParams{});
@@ -208,7 +212,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorFind
OperationContext* opCtx,
const CollectionPtr* collection,
std::unique_ptr<CanonicalQuery> canonicalQuery,
- std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages,
+ std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages,
bool permitYield = false,
size_t plannerOptions = QueryPlannerParams::DEFAULT);
diff --git a/src/mongo/db/query/optimizer/SConscript b/src/mongo/db/query/optimizer/SConscript
index a128b8eecc1..70ce6442e3a 100644
--- a/src/mongo/db/query/optimizer/SConscript
+++ b/src/mongo/db/query/optimizer/SConscript
@@ -91,17 +91,6 @@ env.Library(
],
)
-# Default costing module.
-env.Library(
- target="optimizer_default_costing",
- source=[
- "cascades/cost_derivation.cpp",
- ],
- LIBDEPS=[
- "optimizer_memo",
- ],
-)
-
# Main optimizer target.
env.Library(
target="optimizer",
@@ -112,7 +101,6 @@ env.Library(
LIBDEPS=[
"optimizer_cascades",
"optimizer_default_ce",
- "optimizer_default_costing",
"optimizer_rewrites",
],
)
@@ -126,6 +114,7 @@ env.Library(
LIBDEPS=[
# We do not depend on the entire pipeline target.
"$BUILD_DIR/mongo/db/pipeline/abt_utils",
+ "$BUILD_DIR/mongo/db/query/cost_model/query_cost_model",
"$BUILD_DIR/mongo/db/query/optimizer/optimizer",
"$BUILD_DIR/mongo/unittest/unittest",
],
diff --git a/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp b/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp
deleted file mode 100644
index f58380b50b3..00000000000
--- a/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp
+++ /dev/null
@@ -1,446 +0,0 @@
-/**
- * Copyright (C) 2022-present MongoDB, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the Server Side Public License, version 1,
- * as published by MongoDB, Inc.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * Server Side Public License for more details.
- *
- * You should have received a copy of the Server Side Public License
- * along with this program. If not, see
- * <http://www.mongodb.com/licensing/server-side-public-license>.
- *
- * As a special exception, the copyright holders give permission to link the
- * code of portions of this program with the OpenSSL library under certain
- * conditions as described in each individual source file and distribute
- * linked combinations including the program with the OpenSSL library. You
- * must comply with the Server Side Public License in all respects for
- * all of the code used other than as permitted herein. If you modify file(s)
- * with this exception, you may extend this exception to your version of the
- * file(s), but you are not obligated to do so. If you do not wish to do so,
- * delete this exception statement from your version. If you delete this
- * exception statement from all source files in the program, then also delete
- * it in the license file.
- */
-
-#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
-#include "mongo/db/query/optimizer/defs.h"
-
-namespace mongo::optimizer::cascades {
-
-using namespace properties;
-
-struct CostAndCEInternal {
- CostAndCEInternal(double cost, CEType ce) : _cost(cost), _ce(ce) {
- uassert(8423334, "Invalid cost.", !std::isnan(cost) && cost >= 0.0);
- uassert(8423332, "Invalid cardinality", std::isfinite(ce) && ce >= 0.0);
- }
- double _cost;
- CEType _ce;
-};
-
-class CostDerivation {
- // These cost should reflect estimated aggregated execution time in milliseconds.
- static constexpr double ms = 1.0e-3;
-
- // Startup cost of an operator. This is the minimal cost of an operator since it is
- // present even if it doesn't process any input.
- // TODO: calibrate the cost individually for each operator
- static constexpr double kStartupCost = 0.000001;
-
- // TODO: collection scan should depend on the width of the doc.
- // TODO: the actual measured cost is (0.4 * ms), however we increase it here because currently
- // it is not possible to estimate the cost of a collection scan vs a full index scan.
- static constexpr double kScanIncrementalCost = 0.6 * ms;
-
- // TODO: cost(N fields) ~ (0.55 + 0.025 * N)
- static constexpr double kIndexScanIncrementalCost = 0.5 * ms;
-
- // TODO: cost(N fields) ~ 0.7 + 0.19 * N
- static constexpr double kSeekCost = 2.0 * ms;
-
- // TODO: take the expression into account.
- // cost(N conditions) = 0.2 + N * ???
- static constexpr double kFilterIncrementalCost = 0.2 * ms;
- // TODO: the cost of projection depends on number of fields: cost(N fields) ~ 0.1 + 0.2 * N
- static constexpr double kEvalIncrementalCost = 2.0 * ms;
-
- // TODO: cost(N fields) ~ 0.04 + 0.03*(N^2)
- static constexpr double kGroupByIncrementalCost = 0.07 * ms;
- static constexpr double kUnwindIncrementalCost = 0.03 * ms; // TODO: not yet calibrated
- // TODO: not yet calibrated, should be at least as expensive as a filter
- static constexpr double kBinaryJoinIncrementalCost = 0.2 * ms;
- static constexpr double kHashJoinIncrementalCost = 0.05 * ms; // TODO: not yet calibrated
- static constexpr double kMergeJoinIncrementalCost = 0.02 * ms; // TODO: not yet calibrated
-
- static constexpr double kUniqueIncrementalCost = 0.7 * ms;
-
- // TODO: implement collation cost that depends on number and size of sorted fields
- // Based on a mix of int and str(64) fields:
- // 1 sort field: sort_cost(N) = 1.0/10 * N * log(N)
- // 5 sort fields: sort_cost(N) = 2.5/10 * N * log(N)
- // 10 sort fields: sort_cost(N) = 3.0/10 * N * log(N)
- // field_cost_coeff(F) ~ 0.75 + 0.2 * F
- static constexpr double kCollationIncrementalCost = 2.5 * ms; // 5 fields avg
- static constexpr double kCollationWithLimitIncrementalCost =
- 1.0 * ms; // TODO: not yet calibrated
-
- static constexpr double kUnionIncrementalCost = 0.02 * ms;
-
- static constexpr double kExchangeIncrementalCost = 0.1 * ms; // TODO: not yet calibrated
-
-public:
- CostAndCEInternal operator()(const ABT& /*n*/, const PhysicalScanNode& /*node*/) {
- // Default estimate for scan.
- const double collectionScanCost =
- kStartupCost + kScanIncrementalCost * _cardinalityEstimate;
- return {collectionScanCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const CoScanNode& /*node*/) {
- // Assumed to be free.
- return {kStartupCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const IndexScanNode& node) {
- const double indexScanCost =
- kStartupCost + kIndexScanIncrementalCost * _cardinalityEstimate;
- return {indexScanCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const SeekNode& /*node*/) {
- // SeekNode should deliver one result via cardinality estimate override.
- // TODO: consider using node.getProjectionMap()._fieldProjections.size() to make the cost
- // dependent on the size of the projection
- const double seekCost = kStartupCost + kSeekCost * _cardinalityEstimate;
- return {seekCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const MemoLogicalDelegatorNode& node) {
- const LogicalProps& childLogicalProps = _memo.getLogicalProps(node.getGroupId());
- // Notice that unlike all physical nodes, this logical node takes it cardinality directly
- // from the memo group logical property, ignoring _cardinalityEstimate.
- CEType baseCE = getPropertyConst<CardinalityEstimate>(childLogicalProps).getEstimate();
-
- if (hasProperty<IndexingRequirement>(_physProps)) {
- const auto& indexingReq = getPropertyConst<IndexingRequirement>(_physProps);
- if (indexingReq.getIndexReqTarget() == IndexReqTarget::Seek) {
- // If we are performing a seek, normalize against the scan group cardinality.
- const GroupIdType scanGroupId =
- getPropertyConst<IndexingAvailability>(childLogicalProps).getScanGroupId();
- if (scanGroupId == node.getGroupId()) {
- baseCE = 1.0;
- } else {
- const CEType scanGroupCE =
- getPropertyConst<CardinalityEstimate>(_memo.getLogicalProps(scanGroupId))
- .getEstimate();
- if (scanGroupCE > 0.0) {
- baseCE /= scanGroupCE;
- }
- }
- }
- }
-
- return {0.0, getAdjustedCE(baseCE, _physProps)};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const MemoPhysicalDelegatorNode& /*node*/) {
- uasserted(6624040, "Should not be costing physical delegator nodes.");
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const FilterNode& node) {
- CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
- double filterCost = childResult._cost;
- if (!isTrivialExpr<EvalFilter>(node.getFilter())) {
- // Non-trivial filter.
- filterCost += kStartupCost + kFilterIncrementalCost * childResult._ce;
- }
- return {filterCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const EvaluationNode& node) {
- CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
- double evalCost = childResult._cost;
- if (!isTrivialExpr<EvalPath>(node.getProjection())) {
- // Non-trivial projection.
- evalCost += kStartupCost + kEvalIncrementalCost * _cardinalityEstimate;
- }
- return {evalCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const BinaryJoinNode& node) {
- CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0);
- CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1);
- const double joinCost = kStartupCost +
- kBinaryJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) +
- leftChildResult._cost + rightChildResult._cost;
- return {joinCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const HashJoinNode& node) {
- CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0);
- CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1);
-
- // TODO: distinguish build side and probe side.
- const double hashJoinCost = kStartupCost +
- kHashJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) +
- leftChildResult._cost + rightChildResult._cost;
- return {hashJoinCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const MergeJoinNode& node) {
- CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0);
- CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1);
-
- const double mergeJoinCost = kStartupCost +
- kMergeJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) +
- leftChildResult._cost + rightChildResult._cost;
-
- return {mergeJoinCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const UnionNode& node) {
- const ABTVector& children = node.nodes();
- // UnionNode with one child is optimized away before lowering, therefore
- // its cost is the cost of its child.
- if (children.size() == 1) {
- CostAndCEInternal childResult = deriveChild(children[0], 0);
- return {childResult._cost, _cardinalityEstimate};
- }
-
- double totalCost = kStartupCost;
- // The cost is the sum of the costs of its children and the cost to union each child.
- for (size_t childIdx = 0; childIdx < children.size(); childIdx++) {
- CostAndCEInternal childResult = deriveChild(children[childIdx], childIdx);
- const double childCost =
- childResult._cost + (childIdx > 0 ? kUnionIncrementalCost * childResult._ce : 0);
- totalCost += childCost;
- }
- return {totalCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const GroupByNode& node) {
- CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
- double groupByCost = kStartupCost;
-
- // TODO: for now pretend global group by is free.
- if (node.getType() == GroupNodeType::Global) {
- groupByCost += childResult._cost;
- } else {
- // TODO: consider RepetitionEstimate since this is a stateful operation.
- groupByCost += kGroupByIncrementalCost * childResult._ce + childResult._cost;
- }
- return {groupByCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const UnwindNode& node) {
- CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
- // Unwind probably depends mostly on its output size.
- const double unwindCost = kUnwindIncrementalCost * _cardinalityEstimate + childResult._cost;
- return {unwindCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const UniqueNode& node) {
- CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
- const double uniqueCost =
- kStartupCost + kUniqueIncrementalCost * childResult._ce + childResult._cost;
- return {uniqueCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const CollationNode& node) {
- CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
- // TODO: consider RepetitionEstimate since this is a stateful operation.
-
- double logFactor = childResult._ce;
- double incrConst = kCollationIncrementalCost;
- if (hasProperty<LimitSkipRequirement>(_physProps)) {
- if (auto limit = getPropertyConst<LimitSkipRequirement>(_physProps).getAbsoluteLimit();
- limit < logFactor) {
- logFactor = limit;
- incrConst = kCollationWithLimitIncrementalCost;
- }
- }
-
- // Notice that log2(x) < 0 for any x < 1, and log2(1) = 0. Generally it makes sense that
- // there is no cost to sort 1 document, so the only cost left is the startup cost.
- const double sortCost = kStartupCost + childResult._cost +
- ((logFactor <= 1.0)
- ? 0.0
- // TODO: The cost formula below is based on 1 field, mix of int and str. Instead we
- // have to take into account the number and size of sorted fields.
- : incrConst * childResult._ce * std::log2(logFactor));
- return {sortCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const LimitSkipNode& node) {
- // Assumed to be free.
- CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
- const double limitCost = kStartupCost + childResult._cost;
- return {limitCost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const ExchangeNode& node) {
- CostAndCEInternal childResult = deriveChild(node.getChild(), 0);
- double localCost = kStartupCost + kExchangeIncrementalCost * _cardinalityEstimate;
-
- switch (node.getProperty().getDistributionAndProjections()._type) {
- case DistributionType::Replicated:
- localCost *= 2.0;
- break;
-
- case DistributionType::HashPartitioning:
- case DistributionType::RangePartitioning:
- localCost *= 1.1;
- break;
-
- default:
- break;
- }
-
- return {localCost + childResult._cost, _cardinalityEstimate};
- }
-
- CostAndCEInternal operator()(const ABT& /*n*/, const RootNode& node) {
- return deriveChild(node.getChild(), 0);
- }
-
- /**
- * Other ABT types.
- */
- template <typename T, typename... Ts>
- CostAndCEInternal operator()(const ABT& /*n*/, const T& /*node*/, Ts&&...) {
- static_assert(!canBePhysicalNode<T>(), "Physical node must implement its cost derivation.");
- return {0.0, 0.0};
- }
-
- static CostAndCEInternal derive(const Metadata& metadata,
- const Memo& memo,
- const PhysProps& physProps,
- const ABT::reference_type physNodeRef,
- const ChildPropsType& childProps,
- const NodeCEMap& nodeCEMap) {
- CostAndCEInternal result =
- deriveInternal(metadata, memo, physProps, physNodeRef, childProps, nodeCEMap);
-
- switch (getPropertyConst<DistributionRequirement>(physProps)
- .getDistributionAndProjections()
- ._type) {
- case DistributionType::Centralized:
- case DistributionType::Replicated:
- break;
-
- case DistributionType::RoundRobin:
- case DistributionType::HashPartitioning:
- case DistributionType::RangePartitioning:
- case DistributionType::UnknownPartitioning:
- result._cost /= metadata._numberOfPartitions;
- break;
-
- default:
- MONGO_UNREACHABLE;
- }
-
- return result;
- }
-
-private:
- CostDerivation(const Metadata& metadata,
- const Memo& memo,
- const CEType ce,
- const PhysProps& physProps,
- const ChildPropsType& childProps,
- const NodeCEMap& nodeCEMap)
- : _metadata(metadata),
- _memo(memo),
- _physProps(physProps),
- _cardinalityEstimate(getAdjustedCE(ce, _physProps)),
- _childProps(childProps),
- _nodeCEMap(nodeCEMap) {}
-
- template <class T>
- static bool isTrivialExpr(const ABT& n) {
- if (n.is<Constant>() || n.is<Variable>()) {
- return true;
- }
- if (const auto* ptr = n.cast<T>(); ptr != nullptr &&
- ptr->getPath().template is<PathIdentity>() && isTrivialExpr<T>(ptr->getInput())) {
- return true;
- }
- return false;
- }
-
- static CostAndCEInternal deriveInternal(const Metadata& metadata,
- const Memo& memo,
- const PhysProps& physProps,
- const ABT::reference_type physNodeRef,
- const ChildPropsType& childProps,
- const NodeCEMap& nodeCEMap) {
- auto it = nodeCEMap.find(physNodeRef.cast<Node>());
- bool found = (it != nodeCEMap.cend());
- uassert(8423330,
- "Only MemoLogicalDelegatorNode can be missing from nodeCEMap.",
- found || physNodeRef.is<MemoLogicalDelegatorNode>());
- const CEType ce = (found ? it->second : 0.0);
-
- CostDerivation instance(metadata, memo, ce, physProps, childProps, nodeCEMap);
- CostAndCEInternal costCEestimates = physNodeRef.visit(instance);
- return costCEestimates;
- }
-
- CostAndCEInternal deriveChild(const ABT& child, const size_t childIndex) {
- PhysProps physProps = _childProps.empty() ? _physProps : _childProps.at(childIndex).second;
- return deriveInternal(_metadata, _memo, physProps, child.ref(), {}, _nodeCEMap);
- }
-
- static CEType getAdjustedCE(CEType baseCE, const PhysProps& physProps) {
- CEType result = baseCE;
-
- // First: correct for un-enforced limit.
- if (hasProperty<LimitSkipRequirement>(physProps)) {
- const auto limit = getPropertyConst<LimitSkipRequirement>(physProps).getAbsoluteLimit();
- if (result > limit) {
- result = limit;
- }
- }
-
- // Second: correct for enforced limit.
- if (hasProperty<LimitEstimate>(physProps)) {
- const auto limit = getPropertyConst<LimitEstimate>(physProps).getEstimate();
- if (result > limit) {
- result = limit;
- }
- }
-
- // Third: correct for repetition.
- if (hasProperty<RepetitionEstimate>(physProps)) {
- result *= getPropertyConst<RepetitionEstimate>(physProps).getEstimate();
- }
-
- return result;
- }
-
- // We don't own this.
- const Metadata& _metadata;
- const Memo& _memo;
- const PhysProps& _physProps;
- const CEType _cardinalityEstimate;
- const ChildPropsType& _childProps;
- const NodeCEMap& _nodeCEMap;
-};
-
-CostAndCE DefaultCosting::deriveCost(const Metadata& metadata,
- const Memo& memo,
- const PhysProps& physProps,
- const ABT::reference_type physNodeRef,
- const ChildPropsType& childProps,
- const NodeCEMap& nodeCEMap) const {
- const CostAndCEInternal result =
- CostDerivation::derive(metadata, memo, physProps, physNodeRef, childProps, nodeCEMap);
- return {CostType::fromDouble(result._cost), result._ce};
-}
-
-} // namespace mongo::optimizer::cascades
diff --git a/src/mongo/db/query/optimizer/cascades/implementers.cpp b/src/mongo/db/query/optimizer/cascades/implementers.cpp
index 6039ac3cb43..835045a2611 100644
--- a/src/mongo/db/query/optimizer/cascades/implementers.cpp
+++ b/src/mongo/db/query/optimizer/cascades/implementers.cpp
@@ -189,8 +189,7 @@ public:
if (node.getArraySize() == 0) {
nodeCEMap.emplace(physNode.cast<Node>(), 0.0);
- physNode =
- make<LimitSkipNode>(properties::LimitSkipRequirement{0, 0}, std::move(physNode));
+ physNode = make<LimitSkipNode>(LimitSkipRequirement{0, 0}, std::move(physNode));
nodeCEMap.emplace(physNode.cast<Node>(), 0.0);
for (const ProjectionName& boundProjName : node.binder().names()) {
@@ -208,8 +207,7 @@ public:
} else {
nodeCEMap.emplace(physNode.cast<Node>(), 1.0);
- physNode =
- make<LimitSkipNode>(properties::LimitSkipRequirement{1, 0}, std::move(physNode));
+ physNode = make<LimitSkipNode>(LimitSkipRequirement{1, 0}, std::move(physNode));
nodeCEMap.emplace(physNode.cast<Node>(), 1.0);
const ProjectionName valueScanProj = _prefixId.getNextId("valueScan");
@@ -710,7 +708,21 @@ public:
return;
}
const bool isIndex = indexReqTarget == IndexReqTarget::Index;
- if (isIndex && (!node.hasLeftIntervals() || !node.hasRightIntervals())) {
+
+ const GroupIdType leftGroupId =
+ node.getLeftChild().cast<MemoLogicalDelegatorNode>()->getGroupId();
+ const GroupIdType rightGroupId =
+ node.getRightChild().cast<MemoLogicalDelegatorNode>()->getGroupId();
+
+ const LogicalProps& leftLogicalProps = _memo.getLogicalProps(leftGroupId);
+ const LogicalProps& rightLogicalProps = _memo.getLogicalProps(rightGroupId);
+
+ const bool hasProperIntervalLeft =
+ getPropertyConst<IndexingAvailability>(leftLogicalProps).hasProperInterval();
+ const bool hasProperIntervalRight =
+ getPropertyConst<IndexingAvailability>(rightLogicalProps).hasProperInterval();
+
+ if (isIndex && (!hasProperIntervalLeft || !hasProperIntervalRight)) {
// We need to have proper intervals on both sides.
return;
}
@@ -729,14 +741,6 @@ public:
}
}
- const GroupIdType leftGroupId =
- node.getLeftChild().cast<MemoLogicalDelegatorNode>()->getGroupId();
- const GroupIdType rightGroupId =
- node.getRightChild().cast<MemoLogicalDelegatorNode>()->getGroupId();
-
- const LogicalProps& leftLogicalProps = _memo.getLogicalProps(leftGroupId);
- const LogicalProps& rightLogicalProps = _memo.getLogicalProps(rightGroupId);
-
const CEType intersectedCE =
getPropertyConst<CardinalityEstimate>(_logicalProps).getEstimate();
const CEType leftCE = getPropertyConst<CardinalityEstimate>(leftLogicalProps).getEstimate();
@@ -1660,9 +1664,9 @@ void addImplementers(const Metadata& metadata,
const QueryHints& hints,
const RIDProjectionsMap& ridProjections,
PrefixId& prefixId,
- const properties::PhysProps& physProps,
+ const PhysProps& physProps,
PhysQueueAndImplPos& queue,
- const properties::LogicalProps& logicalProps,
+ const LogicalProps& logicalProps,
const OrderPreservingABTSet& logicalNodes,
const PathToIntervalFn& pathToInterval) {
ImplementationVisitor visitor(metadata,
diff --git a/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp
index c2ef5529e79..29d4ccc22ec 100644
--- a/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp
+++ b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp
@@ -222,8 +222,7 @@ public:
indexingAvailability.setEqPredsOnly(computeEqPredsOnly(node.getReqMap()));
}
- auto& satisfiedPartialIndexes =
- getProperty<IndexingAvailability>(result).getSatisfiedPartialIndexes();
+ auto& satisfiedPartialIndexes = indexingAvailability.getSatisfiedPartialIndexes();
for (const auto& [indexDefName, indexDef] : scanDef.getIndexDefs()) {
if (!indexDef.getPartialReqMap().empty()) {
auto intersection = node.getReqMap();
@@ -237,6 +236,8 @@ public:
}
}
+ indexingAvailability.setHasProperInterval(hasProperIntervals(node.getReqMap()));
+
return maybeUpdateNodePropsMap(node, std::move(result));
}
diff --git a/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp b/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp
index 421d03f73a1..0c9a923b2ed 100644
--- a/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp
+++ b/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp
@@ -941,8 +941,6 @@ struct SplitRequirementsResult {
PartialSchemaRequirements _rightReqs;
bool _hasFieldCoverage = true;
- bool _hasLeftIntervals = false;
- bool _hasRightIntervals = false;
};
/**
@@ -981,7 +979,6 @@ static SplitRequirementsResult splitRequirements(
size_t index = 0;
for (const auto& [key, req] : reqMap) {
- const bool fullyOpenInterval = isFullyOpen.at(index);
if (((1ull << index) & mask) != 0) {
bool addedToLeft = false;
@@ -1005,7 +1002,7 @@ static SplitRequirementsResult splitRequirements(
// We cannot return index values if our interval can possibly contain Null. Instead,
// we remove the output binding for the left side, and return the value from the
// right (seek) side.
- if (!fullyOpenInterval) {
+ if (!isFullyOpen.at(index)) {
addRequirement(
leftReqs, key, boost::none /*boundProjectionName*/, req.getIntervals());
addedToLeft = true;
@@ -1018,9 +1015,6 @@ static SplitRequirementsResult splitRequirements(
}
if (addedToLeft) {
- if (!fullyOpenInterval) {
- result._hasLeftIntervals = true;
- }
if (indexFieldPrefixMapForScanDef) {
if (auto pathPtr = key._path.cast<PathGet>(); pathPtr != nullptr &&
indexFieldPrefixMapForScanDef->count(pathPtr->name()) == 0) {
@@ -1033,9 +1027,6 @@ static SplitRequirementsResult splitRequirements(
}
} else if (isIndex || !req.getIsPerfOnly()) {
addRequirement(rightReqs, key, req.getBoundProjectionName(), req.getIntervals());
- if (!fullyOpenInterval) {
- result._hasRightIntervals = true;
- }
}
index++;
}
@@ -1141,10 +1132,13 @@ struct ExploreConvert<SargableNode> {
continue;
}
- if (isIndex && (!splitResult._hasLeftIntervals || !splitResult._hasRightIntervals)) {
- // Reject. Must have at least one proper interval on either side.
+ // Reject. Must have at least one proper interval on each side.
+ if (isIndex &&
+ (!hasProperIntervals(splitResult._leftReqs) ||
+ !hasProperIntervals(splitResult._rightReqs))) {
continue;
}
+
if (!splitResult._hasFieldCoverage) {
// Reject rewrite. No suitable indexes.
continue;
@@ -1196,11 +1190,8 @@ struct ExploreConvert<SargableNode> {
isIndex ? IndexReqTarget::Index : IndexReqTarget::Seek,
scanDelegator);
- ABT newRoot = make<RIDIntersectNode>(scanProjectionName,
- splitResult._hasLeftIntervals,
- splitResult._hasRightIntervals,
- std::move(leftChild),
- std::move(rightChild));
+ ABT newRoot = make<RIDIntersectNode>(
+ scanProjectionName, std::move(leftChild), std::move(rightChild));
const auto& result = ctx.addNode(newRoot, false /*substitute*/);
for (const MemoLogicalNodeId nodeId : result.second) {
@@ -1290,14 +1281,27 @@ void reorderAgainstRIDIntersectNode(ABT::reference_type aboveNode,
}
const RIDIntersectNode& node = *belowNode.cast<RIDIntersectNode>();
- if (node.hasLeftIntervals() && hasLeftRef) {
+ const GroupIdType groupIdLeft =
+ node.getLeftChild().cast<MemoLogicalDelegatorNode>()->getGroupId();
+ const bool hasProperIntervalLeft =
+ properties::getPropertyConst<properties::IndexingAvailability>(
+ ctx.getMemo().getLogicalProps(groupIdLeft))
+ .hasProperInterval();
+ if (hasProperIntervalLeft && hasLeftRef) {
defaultReorder<AboveNode,
RIDIntersectNode,
DefaultChildAccessor,
LeftChildAccessor,
false /*substitute*/>(aboveNode, belowNode, ctx);
}
- if (node.hasRightIntervals() && hasRightRef) {
+
+ const GroupIdType groupIdRight =
+ node.getRightChild().cast<MemoLogicalDelegatorNode>()->getGroupId();
+ const bool hasProperIntervalRight =
+ properties::getPropertyConst<properties::IndexingAvailability>(
+ ctx.getMemo().getLogicalProps(groupIdRight))
+ .hasProperInterval();
+ if (hasProperIntervalRight && hasRightRef) {
defaultReorder<AboveNode,
RIDIntersectNode,
DefaultChildAccessor,
diff --git a/src/mongo/db/query/optimizer/cascades/memo_defs.h b/src/mongo/db/query/optimizer/cascades/memo_defs.h
index b667df2816b..f0e476ab819 100644
--- a/src/mongo/db/query/optimizer/cascades/memo_defs.h
+++ b/src/mongo/db/query/optimizer/cascades/memo_defs.h
@@ -59,6 +59,10 @@ public:
OrderPreservingABTSet(const OrderPreservingABTSet&) = delete;
OrderPreservingABTSet(OrderPreservingABTSet&&) = default;
+ OrderPreservingABTSet& operator=(const OrderPreservingABTSet&) = delete;
+ OrderPreservingABTSet& operator=(OrderPreservingABTSet&&) = default;
+
+
ABT::reference_type at(size_t index) const;
std::pair<size_t, bool> emplace_back(ABT node);
std::pair<size_t, bool> find(ABT::reference_type node) const;
diff --git a/src/mongo/db/query/optimizer/explain.cpp b/src/mongo/db/query/optimizer/explain.cpp
index ba5aecdbd73..5e9df5df3bb 100644
--- a/src/mongo/db/query/optimizer/explain.cpp
+++ b/src/mongo/db/query/optimizer/explain.cpp
@@ -496,7 +496,7 @@ private:
void addValue(sbe::value::TypeTags tag, sbe::value::Value val, const bool append = false) {
if (!_initialized) {
_initialized = true;
- _canAppend = !_nextFieldName.empty();
+ _canAppend = _nextFieldName.has_value();
if (_canAppend) {
std::tie(_tag, _val) = sbe::value::makeNewObject();
} else {
@@ -512,7 +512,7 @@ private:
}
if (append) {
- uassert(6624073, "Field name is not empty", _nextFieldName.empty());
+ uassert(6624073, "Field name is not set", !_nextFieldName.has_value());
uassert(6624349,
"Other printer does not contain Object",
tag == sbe::value::TypeTags::Object);
@@ -523,19 +523,19 @@ private:
addField(obj->field(i), fieldTag, fieldVal);
}
} else {
- addField(_nextFieldName, tag, val);
- _nextFieldName.clear();
+ tassert(6751700, "Missing field name to serialize", _nextFieldName);
+ addField(*_nextFieldName, tag, val);
+ _nextFieldName = boost::none;
}
}
void addField(const std::string& fieldName, sbe::value::TypeTags tag, sbe::value::Value val) {
- uassert(6624074, "Field name is empty", !fieldName.empty());
uassert(6624075, "Duplicate field name", _fieldNameSet.insert(fieldName).second);
sbe::value::getObjectView(_val)->push_back(fieldName, tag, val);
}
void reset() {
- _nextFieldName.clear();
+ _nextFieldName = boost::none;
_initialized = false;
_canAppend = false;
_tag = sbe::value::TypeTags::Nothing;
@@ -543,7 +543,8 @@ private:
_fieldNameSet.clear();
}
- std::string _nextFieldName;
+ // Cannot assume empty means non-existent, so use optional<>.
+ boost::optional<std::string> _nextFieldName;
bool _initialized;
bool _canAppend;
sbe::value::TypeTags _tag;
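The reason for switching _nextFieldName to optional<> can be seen in a small standalone sketch (std::optional is used here; boost::optional behaves the same for this purpose): with optional<>, an empty field name and a field name that has never been set become distinguishable states.

    #include <iostream>
    #include <optional>
    #include <string>

    int main() {
        std::optional<std::string> nextFieldName;  // not set yet

        nextFieldName = "";  // explicitly set to the empty string
        std::cout << std::boolalpha
                  << "has value: " << nextFieldName.has_value()      // true
                  << ", empty: " << nextFieldName->empty() << "\n";  // true

        nextFieldName.reset();  // back to "not set"
        std::cout << "has value: " << nextFieldName.has_value() << "\n";  // false
        return 0;
    }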
@@ -1276,8 +1277,6 @@ public:
printer.separator(" [")
.fieldName("scanProjectionName", ExplainVersion::V3)
.print(node.getScanProjectionName());
- printBooleanFlag(printer, "hasLeftIntervals", node.hasLeftIntervals());
- printBooleanFlag(printer, "hasRightIntervals", node.hasRightIntervals());
printer.separator("]");
nodeCEPropsPrint(printer, n, node);
@@ -1766,6 +1765,7 @@ public:
.fieldName("scanDefName")
.print(prop.getScanDefName());
printBooleanFlag(printer, "eqPredsOnly", prop.getEqPredsOnly());
+ printBooleanFlag(printer, "hasProperInterval", prop.hasProperInterval());
printer.separator("]");
if (!prop.getSatisfiedPartialIndexes().empty()) {
@@ -2263,7 +2263,7 @@ public:
ExplainPrinter printer("PathGet");
printer.separator(" [")
.fieldName("path", ExplainVersion::V3)
- .print(path.name())
+ .print(path.name().empty() ? "<empty>" : path.name())
.separator("]")
.setChildCount(1)
.fieldName("input", ExplainVersion::V3)
@@ -2542,8 +2542,12 @@ static void printBSONstr(PrinterType& printer,
}
}
-std::string ExplainGenerator::printBSON(const sbe::value::TypeTags tag,
- const sbe::value::Value val) {
+std::string ExplainGenerator::explainBSONStr(const ABT& node,
+ bool displayProperties,
+ const cascades::MemoExplainInterface* memoInterface,
+ const NodeToGroupPropsMap& nodeMap) {
+ const auto [tag, val] = explainBSON(node, displayProperties, memoInterface, nodeMap);
+ sbe::value::ValueGuard vg(tag, val);
ExplainPrinterImpl<ExplainVersion::V2> printer;
printBSONstr(printer, tag, val);
return printer.str();
diff --git a/src/mongo/db/query/optimizer/explain.h b/src/mongo/db/query/optimizer/explain.h
index ad62dd54126..19e9221f8dd 100644
--- a/src/mongo/db/query/optimizer/explain.h
+++ b/src/mongo/db/query/optimizer/explain.h
@@ -94,7 +94,10 @@ public:
const cascades::MemoExplainInterface* memoInterface = nullptr,
const NodeToGroupPropsMap& nodeMap = {});
- static std::string printBSON(sbe::value::TypeTags tag, sbe::value::Value val);
+ static std::string explainBSONStr(const ABT& node,
+ bool displayProperties = false,
+ const cascades::MemoExplainInterface* memoInterface = nullptr,
+ const NodeToGroupPropsMap& nodeMap = {});
static std::string explainLogicalProps(const std::string& description,
const properties::LogicalProps& props);
diff --git a/src/mongo/db/query/optimizer/interval_intersection_test.cpp b/src/mongo/db/query/optimizer/interval_intersection_test.cpp
index 119a1b98fea..cc8127d49b7 100644
--- a/src/mongo/db/query/optimizer/interval_intersection_test.cpp
+++ b/src/mongo/db/query/optimizer/interval_intersection_test.cpp
@@ -56,12 +56,12 @@ std::string optimizedQueryPlan(const std::string& query,
ABT translated =
translatePipeline(metadata, "[{$match: " + query + "}]", scanDefName, prefixId);
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase,
- OptPhase::MemoExplorationPhase,
- OptPhase::MemoImplementationPhase},
- prefixId,
- metadata,
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase,
+ OptPhase::MemoExplorationPhase,
+ OptPhase::MemoImplementationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
ABT optimized = translated;
phaseManager.getHints()._disableScan = true;
diff --git a/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp b/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp
index 2434e811de1..03980eb5db4 100644
--- a/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp
+++ b/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp
@@ -74,10 +74,10 @@ TEST(LogicalRewriter, RootNodeMerge) {
" Source []\n",
rootNode);
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
ABT rewritten = std::move(rootNode);
phaseManager.optimize(rewritten);
@@ -288,10 +288,10 @@ TEST(LogicalRewriter, FilterProjectRewrite) {
" Source []\n",
rootNode);
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
ABT latest = std::move(rootNode);
phaseManager.optimize(latest);
@@ -399,10 +399,10 @@ TEST(LogicalRewriter, FilterProjectComplexRewrite) {
" Source []\n",
rootNode);
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
ABT latest = std::move(rootNode);
phaseManager.optimize(latest);
@@ -477,10 +477,10 @@ TEST(LogicalRewriter, FilterProjectGroupRewrite) {
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"c"}},
std::move(filterANode));
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
ABT latest = std::move(rootNode);
phaseManager.optimize(latest);
@@ -547,10 +547,10 @@ TEST(LogicalRewriter, FilterProjectUnwindRewrite) {
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}},
std::move(filterBNode));
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
ABT latest = std::move(rootNode);
phaseManager.optimize(latest);
@@ -618,10 +618,10 @@ TEST(LogicalRewriter, FilterProjectExchangeRewrite) {
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}},
std::move(filterANode));
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
ABT latest = std::move(rootNode);
phaseManager.optimize(latest);
@@ -690,10 +690,10 @@ TEST(LogicalRewriter, UnwindCollationRewrite) {
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}},
std::move(unwindNode));
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
ABT latest = std::move(rootNode);
phaseManager.optimize(latest);
@@ -802,11 +802,11 @@ TEST(LogicalRewriter, FilterUnionReorderSingleProjection) {
" Source []\n",
latest);
- OptPhaseManager phaseManager(
- {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase},
- prefixId,
- {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager =
+ makePhaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase},
+ prefixId,
+ {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
phaseManager.optimize(latest);
ASSERT_EXPLAIN_V2(
@@ -966,11 +966,11 @@ TEST(LogicalRewriter, MultipleFilterUnionReorder) {
" Source []\n",
latest);
- OptPhaseManager phaseManager(
- {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase},
- prefixId,
- {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager =
+ makePhaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase},
+ prefixId,
+ {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
phaseManager.optimize(latest);
ASSERT_EXPLAIN_V2(
@@ -1070,12 +1070,12 @@ TEST(LogicalRewriter, FilterUnionUnionPushdown) {
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"ptest"}},
std::move(filter));
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase},
- prefixId,
- {{{"test1", createScanDef({}, {})},
- {"test2", createScanDef({}, {})},
- {"test3", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test1", createScanDef({}, {})},
+ {"test2", createScanDef({}, {})},
+ {"test3", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
ABT latest = std::move(rootNode);
ASSERT_EXPLAIN_V2(
@@ -1216,10 +1216,11 @@ TEST(LogicalRewriter, UnionPreservesCommonLogicalProps) {
// Run the reordering rewrite such that the scan produces a hash partition.
PrefixId prefixId;
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase},
- prefixId,
- metadata,
- DebugInfo::kDefaultForTests);
+ auto phaseManager =
+ makePhaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase},
+ prefixId,
+ metadata,
+ DebugInfo::kDefaultForTests);
ABT optimized = rootNode;
phaseManager.optimize(optimized);
@@ -1432,10 +1433,11 @@ TEST(LogicalRewriter, SargableCE) {
PrefixId prefixId;
ABT rootNode = sargableCETestSetup();
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager =
+ makePhaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
ABT latest = std::move(rootNode);
phaseManager.optimize(latest);
@@ -1474,7 +1476,8 @@ TEST(LogicalRewriter, SargableCE) {
" | | projections: \n"
" | | ptest\n"
" | | indexingAvailability: \n"
- " | | [groupId: 0, scanProjection: ptest, scanDefName: test, eqPredsOnly]\n"
+ " | | [groupId: 0, scanProjection: ptest, scanDefName: test, eqPredsOnly, "
+ "hasProperInterval]\n"
" | | collectionAvailability: \n"
" | | test\n"
" | | distributionAvailability: \n"
@@ -1508,7 +1511,8 @@ TEST(LogicalRewriter, SargableCE) {
" | | projections: \n"
" | | ptest\n"
" | | indexingAvailability: \n"
- " | | [groupId: 0, scanProjection: ptest, scanDefName: test, eqPredsOnly]\n"
+ " | | [groupId: 0, scanProjection: ptest, scanDefName: test, eqPredsOnly, "
+ "hasProperInterval]\n"
" | | collectionAvailability: \n"
" | | test\n"
" | | distributionAvailability: \n"
@@ -1540,10 +1544,10 @@ TEST(LogicalRewriter, RemoveNoopFilter) {
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"ptest"}},
std::move(filterANode));
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo::kDefaultForTests);
ABT latest = std::move(rootNode);
phaseManager.optimize(latest);
diff --git a/src/mongo/db/query/optimizer/node.cpp b/src/mongo/db/query/optimizer/node.cpp
index d9b0086fb1e..1eed39980b1 100644
--- a/src/mongo/db/query/optimizer/node.cpp
+++ b/src/mongo/db/query/optimizer/node.cpp
@@ -290,15 +290,9 @@ bool EvaluationNode::operator==(const EvaluationNode& other) const {
getChild() == other.getChild();
}
-RIDIntersectNode::RIDIntersectNode(ProjectionName scanProjectionName,
- const bool hasLeftIntervals,
- const bool hasRightIntervals,
- ABT leftChild,
- ABT rightChild)
+RIDIntersectNode::RIDIntersectNode(ProjectionName scanProjectionName, ABT leftChild, ABT rightChild)
: Base(std::move(leftChild), std::move(rightChild)),
- _scanProjectionName(std::move(scanProjectionName)),
- _hasLeftIntervals(hasLeftIntervals),
- _hasRightIntervals(hasRightIntervals) {
+ _scanProjectionName(std::move(scanProjectionName)) {
assertNodeSort(getLeftChild());
assertNodeSort(getRightChild());
}
@@ -321,23 +315,13 @@ ABT& RIDIntersectNode::getRightChild() {
bool RIDIntersectNode::operator==(const RIDIntersectNode& other) const {
return _scanProjectionName == other._scanProjectionName &&
- _hasLeftIntervals == other._hasLeftIntervals &&
- _hasRightIntervals == other._hasRightIntervals && getLeftChild() == other.getLeftChild() &&
- getRightChild() == other.getRightChild();
+ getLeftChild() == other.getLeftChild() && getRightChild() == other.getRightChild();
}
const ProjectionName& RIDIntersectNode::getScanProjectionName() const {
return _scanProjectionName;
}
-bool RIDIntersectNode::hasLeftIntervals() const {
- return _hasLeftIntervals;
-}
-
-bool RIDIntersectNode::hasRightIntervals() const {
- return _hasRightIntervals;
-}
-
static ProjectionNameVector createSargableBindings(const PartialSchemaRequirements& reqMap) {
ProjectionNameVector result;
for (const auto& entry : reqMap) {
diff --git a/src/mongo/db/query/optimizer/node.h b/src/mongo/db/query/optimizer/node.h
index 155f1d69f29..880d2fd5480 100644
--- a/src/mongo/db/query/optimizer/node.h
+++ b/src/mongo/db/query/optimizer/node.h
@@ -389,11 +389,7 @@ class RIDIntersectNode final : public Operator<2>, public ExclusivelyLogicalNode
using Base = Operator<2>;
public:
- RIDIntersectNode(ProjectionName scanProjectionName,
- bool hasLeftIntervals,
- bool hasRightIntervals,
- ABT leftChild,
- ABT rightChild);
+ RIDIntersectNode(ProjectionName scanProjectionName, ABT leftChild, ABT rightChild);
bool operator==(const RIDIntersectNode& other) const;
@@ -405,15 +401,8 @@ public:
const ProjectionName& getScanProjectionName() const;
- bool hasLeftIntervals() const;
- bool hasRightIntervals() const;
-
private:
const ProjectionName _scanProjectionName;
-
- // If true left and right children have at least one proper interval (not fully open).
- const bool _hasLeftIntervals;
- const bool _hasRightIntervals;
};
/**
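The node.cpp and node.h hunks above drop the hasLeftIntervals/hasRightIntervals flags, so a RIDIntersectNode is now built from just the scan projection and its two children. A minimal sketch of a call site under the reduced signature (the projection name and the leftChild/rightChild subtrees are placeholders, not taken from this patch):

    // Hypothetical construction; only the scan projection and the two child plans remain.
    ABT ridIntersect = make<RIDIntersectNode>(
        ProjectionName{"root"}, std::move(leftChild), std::move(rightChild));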
diff --git a/src/mongo/db/query/optimizer/opt_phase_manager.cpp b/src/mongo/db/query/optimizer/opt_phase_manager.cpp
index 4c4fd0deed3..a00e30aa499 100644
--- a/src/mongo/db/query/optimizer/opt_phase_manager.cpp
+++ b/src/mongo/db/query/optimizer/opt_phase_manager.cpp
@@ -30,7 +30,6 @@
#include "mongo/db/query/optimizer/opt_phase_manager.h"
#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
-#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
#include "mongo/db/query/optimizer/cascades/logical_props_derivation.h"
#include "mongo/db/query/optimizer/rewrites/const_eval.h"
#include "mongo/db/query/optimizer/rewrites/path.h"
@@ -49,22 +48,6 @@ OptPhaseManager::PhaseSet OptPhaseManager::_allRewrites = {OptPhase::ConstEvalPr
OptPhaseManager::OptPhaseManager(OptPhaseManager::PhaseSet phaseSet,
PrefixId& prefixId,
- Metadata metadata,
- DebugInfo debugInfo,
- QueryHints queryHints)
- : OptPhaseManager(std::move(phaseSet),
- prefixId,
- false /*requireRID*/,
- std::move(metadata),
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
- std::move(debugInfo),
- std::move(queryHints)) {}
-
-OptPhaseManager::OptPhaseManager(OptPhaseManager::PhaseSet phaseSet,
- PrefixId& prefixId,
const bool requireRID,
Metadata metadata,
std::unique_ptr<CEInterface> ceDerivation,
diff --git a/src/mongo/db/query/optimizer/opt_phase_manager.h b/src/mongo/db/query/optimizer/opt_phase_manager.h
index 0a94ffb9761..81a849a3461 100644
--- a/src/mongo/db/query/optimizer/opt_phase_manager.h
+++ b/src/mongo/db/query/optimizer/opt_phase_manager.h
@@ -77,11 +77,6 @@ public:
OptPhaseManager(PhaseSet phaseSet,
PrefixId& prefixId,
- Metadata metadata,
- DebugInfo debugInfo,
- QueryHints queryHints = {});
- OptPhaseManager(PhaseSet phaseSet,
- PrefixId& prefixId,
bool requireRID,
Metadata metadata,
std::unique_ptr<CEInterface> ceDerivation,
diff --git a/src/mongo/db/query/optimizer/optimizer_failure_test.cpp b/src/mongo/db/query/optimizer/optimizer_failure_test.cpp
index bf4715548aa..6ab998dcf13 100644
--- a/src/mongo/db/query/optimizer/optimizer_failure_test.cpp
+++ b/src/mongo/db/query/optimizer/optimizer_failure_test.cpp
@@ -28,7 +28,6 @@
*/
#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
-#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
#include "mongo/db/query/optimizer/metadata_factory.h"
#include "mongo/db/query/optimizer/node.h"
#include "mongo/db/query/optimizer/opt_phase_manager.h"
@@ -53,11 +52,11 @@ DEATH_TEST_REGEX(Optimizer, HitIterationLimitInrunStructuralPhases, "Tripwire as
ABT evalNode = make<EvaluationNode>("evalProj1", Constant::int64(5), std::move(scanNode));
- OptPhaseManager phaseManager(
- {OptPhase::PathFuse, OptPhase::ConstEvalPre},
- prefixId,
- {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}},
- DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0));
+ auto phaseManager =
+ makePhaseManager({OptPhase::PathFuse, OptPhase::ConstEvalPre},
+ prefixId,
+ {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}},
+ DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0));
ASSERT_THROWS_CODE(phaseManager.optimize(evalNode), DBException, 6808700);
}
@@ -81,10 +80,10 @@ DEATH_TEST_REGEX(Optimizer,
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(filterNode));
- OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0));
+ auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0));
ASSERT_THROWS_CODE(phaseManager.optimize(rootNode), DBException, 6808702);
}
@@ -108,10 +107,10 @@ DEATH_TEST_REGEX(Optimizer,
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(filterNode));
- OptPhaseManager phaseManager({OptPhase::MemoExplorationPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0));
+ auto phaseManager = makePhaseManager({OptPhase::MemoExplorationPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0));
ASSERT_THROWS_CODE(phaseManager.optimize(rootNode), DBException, 6808702);
}
@@ -133,10 +132,10 @@ DEATH_TEST_REGEX(Optimizer, BadGroupID, "Tripwire assertion.*6808704") {
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(filterNode));
- OptPhaseManager phaseManager({OptPhase::MemoImplementationPhase},
- prefixId,
- {{{"test", createScanDef({}, {})}}},
- DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0));
+ auto phaseManager = makePhaseManager({OptPhase::MemoImplementationPhase},
+ prefixId,
+ {{{"test", createScanDef({}, {})}}},
+ DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0));
ASSERT_THROWS_CODE(phaseManager.optimize(rootNode), DBException, 6808704);
}
@@ -161,13 +160,13 @@ DEATH_TEST_REGEX(Optimizer, EnvHasFreeVariables, "Tripwire assertion.*6808711")
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"p3"}}, std::move(filter2Node));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
{{{"test", createScanDef({}, {})}}},
- {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+ DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, DebugInfo::kIterationLimitForTests));
ASSERT_THROWS_CODE(phaseManager.optimize(rootNode), DBException, 6808711);
}
@@ -206,12 +205,11 @@ DEATH_TEST_REGEX(Optimizer, FailedToRetrieveRID, "Tripwire assertion.*6808705")
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManagerRequireRID(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- true /*requireRID*/,
{{{"c1",
createScanDef(
{},
@@ -224,11 +222,7 @@ DEATH_TEST_REGEX(Optimizer, FailedToRetrieveRID, "Tripwire assertion.*6808705")
ConstEval::constFold,
{DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("b"))})}},
5 /*numberOfPartitions*/},
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
- {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
+ DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, DebugInfo::kIterationLimitForTests));
ASSERT_THROWS_CODE(phaseManager.optimize(rootNode), DBException, 6808705);
}
diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp
index d52b84470a1..423a6a399af 100644
--- a/src/mongo/db/query/optimizer/optimizer_test.cpp
+++ b/src/mongo/db/query/optimizer/optimizer_test.cpp
@@ -29,18 +29,45 @@
#include "mongo/db/query/optimizer/explain.h"
#include "mongo/db/query/optimizer/node.h"
-#include "mongo/db/query/optimizer/opt_phase_manager.h"
#include "mongo/db/query/optimizer/reference_tracker.h"
#include "mongo/db/query/optimizer/rewrites/const_eval.h"
#include "mongo/db/query/optimizer/syntax/syntax.h"
-#include "mongo/db/query/optimizer/syntax/syntax_fwd_declare.h"
#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
#include "mongo/db/query/optimizer/utils/utils.h"
#include "mongo/unittest/unittest.h"
+
namespace mongo::optimizer {
namespace {
+TEST(Optimizer, AutoUpdateExplain) {
+ ABT tree = make<BinaryOp>(Operations::Add,
+ Constant::int64(1),
+ make<Variable>("very very very very very very very very very very "
+ "very very long variable name with \"quotes\""));
+
+ /**
+ * To exercise the auto-updating behavior:
+ * 1. Change the flag "kAutoUpdateOnFailure" to "true".
+ * 2. Induce a failure: change something in the expected output.
+ * 3. Recompile and run the test binary as normal.
+     * 4. Observe that, after the run, the test file is updated with the correct output.
+ */
+ ASSERT_EXPLAIN_V2_AUTO( // NOLINT (test auto-update)
+ "BinaryOp [Add]\n"
+ "| Variable [very very very very very very very very very very very very long variable "
+ "name with \"quotes\"]\n"
+ "Const [1]\n",
+ tree);
+
+ // Test for short constant. It should not be inlined. The nolint comment on the string constant
+ // itself is auto-generated.
+ ABT tree1 = make<Variable>("short name");
+ ASSERT_EXPLAIN_V2_AUTO( // NOLINT (test auto-update)
+ "Variable [short name]\n", // NOLINT (test auto-update)
+ tree1);
+}
+
Constant* constEval(ABT& tree) {
auto env = VariableEnvironment::build(tree);
ConstEval evaluator{env};
@@ -746,13 +773,14 @@ TEST(Explain, ExplainV2Compact) {
TEST(Explain, ExplainBsonForConstant) {
ABT cNode = Constant::int64(3);
- auto [tag, val] = ExplainGenerator::explainBSON(cNode);
- sbe::value::ValueGuard vg(tag, val);
- ASSERT_EQ(
- "{\n nodeType: \"Const\", \n"
+
+ ASSERT_EXPLAIN_BSON(
+ "{\n"
+ " nodeType: \"Const\", \n"
" tag: \"NumberInt64\", \n"
- " value: 3\n}\n",
- ExplainGenerator::printBSON(tag, val));
+ " value: 3\n"
+ "}\n",
+ cNode);
}
} // namespace
diff --git a/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp
index 1079e45acfa..789bcdeea81 100644
--- a/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp
+++ b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp
@@ -30,7 +30,6 @@
#include "mongo/db/pipeline/abt/utils.h"
#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
#include "mongo/db/query/optimizer/cascades/ce_hinted.h"
-#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
#include "mongo/db/query/optimizer/cascades/rewriter_rules.h"
#include "mongo/db/query/optimizer/explain.h"
#include "mongo/db/query/optimizer/metadata_factory.h"
@@ -66,7 +65,7 @@ TEST(PhysRewriter, PhysicalRewriterBasic) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"p2"}}, std::move(filter2Node));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -121,7 +120,7 @@ TEST(PhysRewriter, PhysicalRewriterBasic) {
"| | p1\n"
"| | p2\n"
"| | indexingAvailability: \n"
- "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n"
+ "| | [groupId: 0, scanProjection: p1, scanDefName: test, hasProperInterval]\n"
"| | collectionAvailability: \n"
"| | test\n"
"| | distributionAvailability: \n"
@@ -147,7 +146,7 @@ TEST(PhysRewriter, PhysicalRewriterBasic) {
"| | p1\n"
"| | p2\n"
"| | indexingAvailability: \n"
- "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n"
+ "| | [groupId: 0, scanProjection: p1, scanDefName: test, hasProperInterval]\n"
"| | collectionAvailability: \n"
"| | test\n"
"| | distributionAvailability: \n"
@@ -271,7 +270,7 @@ TEST(PhysRewriter, GroupBy) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"c"}}, std::move(filterANode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -341,7 +340,7 @@ TEST(PhysRewriter, GroupBy1) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pb"}}, std::move(groupByNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -413,7 +412,7 @@ TEST(PhysRewriter, Unwind) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"a", "b"}},
std::move(filterBNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -491,7 +490,7 @@ TEST(PhysRewriter, DuplicateFilter) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -552,7 +551,7 @@ TEST(PhysRewriter, FilterCollation) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pb"}}, std::move(limitSkipNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -608,7 +607,7 @@ TEST(PhysRewriter, EvalCollation) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -663,7 +662,7 @@ TEST(PhysRewriter, FilterEvalCollation) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -719,7 +718,7 @@ TEST(PhysRewriter, FilterIndexing) {
{
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase},
prefixId,
{{{"c1",
@@ -739,7 +738,7 @@ TEST(PhysRewriter, FilterIndexing) {
"| | root\n"
"| RefBlock: \n"
"| Variable [root]\n"
- "RIDIntersect [root, hasLeftIntervals]\n"
+ "RIDIntersect [root]\n"
"| Scan [c1]\n"
"| BindBlock:\n"
"| [root]\n"
@@ -762,7 +761,7 @@ TEST(PhysRewriter, FilterIndexing) {
{
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -804,7 +803,7 @@ TEST(PhysRewriter, FilterIndexing) {
{
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -868,7 +867,7 @@ TEST(PhysRewriter, FilterIndexing1) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"p1"}}, std::move(filterNode));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -933,7 +932,7 @@ TEST(PhysRewriter, FilterIndexing2) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -1013,7 +1012,7 @@ TEST(PhysRewriter, FilterIndexing2NonSarg) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -1131,7 +1130,7 @@ TEST(PhysRewriter, FilterIndexing3) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -1186,7 +1185,7 @@ TEST(PhysRewriter, FilterIndexing3MultiKey) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -1280,7 +1279,7 @@ TEST(PhysRewriter, FilterIndexing4) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterDNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -1388,7 +1387,7 @@ TEST(PhysRewriter, FilterIndexing5) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pb"}},
std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -1478,7 +1477,7 @@ TEST(PhysRewriter, FilterIndexing6) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pb"}},
std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -1541,7 +1540,7 @@ TEST(PhysRewriter, FilterIndexingStress) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(result));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -1641,7 +1640,7 @@ TEST(PhysRewriter, FilterIndexingVariable) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -1737,7 +1736,7 @@ TEST(PhysRewriter, FilterIndexingMaxKey) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -1805,7 +1804,7 @@ TEST(PhysRewriter, SargableProjectionRenames) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(evalNode2));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase},
prefixId,
{{{"c1", createScanDef({}, {})}}},
@@ -1867,7 +1866,7 @@ TEST(PhysRewriter, SargableAcquireProjection) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(evalNode));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase},
prefixId,
{{{"c1", createScanDef({}, {})}}},
@@ -1933,17 +1932,13 @@ TEST(PhysRewriter, FilterReorder) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(result));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1", createScanDef({}, {})}}},
std::make_unique<HintedCE>(std::move(hints)),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = std::move(rootNode);
@@ -2028,21 +2023,17 @@ TEST(PhysRewriter, CoveredScan) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef(
{},
{{"index1",
makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}})}}},
std::make_unique<HintedCE>(std::move(hints)),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = std::move(rootNode);
@@ -2100,7 +2091,7 @@ TEST(PhysRewriter, EvalIndexing) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode));
{
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -2133,7 +2124,7 @@ TEST(PhysRewriter, EvalIndexing) {
{
// Index and collation node have incompatible ops.
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -2192,7 +2183,7 @@ TEST(PhysRewriter, EvalIndexing1) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -2262,7 +2253,7 @@ TEST(PhysRewriter, EvalIndexing2) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa2"}},
std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::ConstEvalPre,
OptPhase::PathFuse,
OptPhase::MemoSubstitutionPhase,
@@ -2347,12 +2338,11 @@ TEST(PhysRewriter, MultiKeyIndex) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef(
{},
@@ -2360,9 +2350,6 @@ TEST(PhysRewriter, MultiKeyIndex) {
{"index2",
makeIndexDefinition("b", CollationOp::Descending, false /*isMultiKey*/)}})}}},
std::make_unique<HintedCE>(std::move(hints)),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
{
@@ -2571,7 +2558,7 @@ TEST(PhysRewriter, CompoundIndex1) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterDNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -2656,7 +2643,7 @@ TEST(PhysRewriter, CompoundIndex2) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -2743,7 +2730,7 @@ TEST(PhysRewriter, CompoundIndex3) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -2830,12 +2817,11 @@ TEST(PhysRewriter, CompoundIndex4Negative) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterBNode));
// Create the following indexes: {a:1, c:1, {name: 'index1'}}, and {b:1, d:1, {name: 'index2'}}
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef(
{},
@@ -2848,9 +2834,6 @@ TEST(PhysRewriter, CompoundIndex4Negative) {
{makeNonMultikeyIndexPath("d"), CollationOp::Ascending}},
false /*isMultiKey*/}}})}}},
std::make_unique<HintedCE>(std::move(hints)),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -2905,7 +2888,7 @@ TEST(PhysRewriter, IndexBoundsIntersect) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -2980,7 +2963,7 @@ TEST(PhysRewriter, IndexBoundsIntersect1) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -3048,7 +3031,7 @@ TEST(PhysRewriter, IndexBoundsIntersect2) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -3122,7 +3105,7 @@ TEST(PhysRewriter, IndexBoundsIntersect3) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -3201,7 +3184,7 @@ TEST(PhysRewriter, IndexResidualReq) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -3229,7 +3212,7 @@ TEST(PhysRewriter, IndexResidualReq) {
"| | pa\n"
"| | root\n"
"| | indexingAvailability: \n"
- "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n"
+ "| | [groupId: 0, scanProjection: root, scanDefName: c1, hasProperInterval]\n"
"| | collectionAvailability: \n"
"| | c1\n"
"| | distributionAvailability: \n"
@@ -3257,7 +3240,7 @@ TEST(PhysRewriter, IndexResidualReq) {
"| | pa\n"
"| | root\n"
"| | indexingAvailability: \n"
- "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n"
+ "| | [groupId: 0, scanProjection: root, scanDefName: c1, hasProperInterval]\n"
"| | collectionAvailability: \n"
"| | c1\n"
"| | distributionAvailability: \n"
@@ -3321,7 +3304,7 @@ TEST(PhysRewriter, IndexResidualReq1) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(collationNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -3404,7 +3387,7 @@ TEST(PhysRewriter, IndexResidualReq2) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterBNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -3483,18 +3466,13 @@ TEST(PhysRewriter, ElemMatchIndex) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef({}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}})}}},
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- defaultConvertPathToInterval,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -3567,22 +3545,17 @@ TEST(PhysRewriter, ElemMatchIndex1) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef({},
{{"index1",
makeCompositeIndexDefinition(
{{"b", CollationOp::Ascending, true /*isMultiKey*/},
{"a", CollationOp::Ascending, true /*isMultiKey*/}})}})}}},
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- defaultConvertPathToInterval,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -3650,21 +3623,16 @@ TEST(PhysRewriter, ElemMatchIndexNoArrays) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef(
{},
{{"index1",
makeIndexDefinition("a", CollationOp::Ascending, false /*multiKey*/)}})}}},
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- defaultConvertPathToInterval,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -3720,22 +3688,17 @@ TEST(PhysRewriter, ObjectElemMatchResidual) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef({},
{{"index1",
makeCompositeIndexDefinition(
{{"b", CollationOp::Ascending, true /*isMultiKey*/},
{"a", CollationOp::Ascending, true /*isMultiKey*/}})}})}}},
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- defaultConvertPathToInterval,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -3840,12 +3803,11 @@ TEST(PhysRewriter, ObjectElemMatchBounds) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef(
{},
@@ -3855,10 +3817,6 @@ TEST(PhysRewriter, ObjectElemMatchBounds) {
true /*isMultiKey*/}}}
)}}},
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- defaultConvertPathToInterval,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -3928,21 +3886,16 @@ TEST(PhysRewriter, NestedElemMatch) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"coll1",
createScanDef(
{},
{{"index1",
makeIndexDefinition("a", CollationOp::Ascending, true /*isMultiKey*/)}})}}},
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- defaultConvertPathToInterval,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -4049,12 +4002,11 @@ TEST(PhysRewriter, PathObj) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
PrefixId prefixId;
- OptPhaseManager phaseManager{
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef({},
{{"index1",
@@ -4062,13 +4014,8 @@ TEST(PhysRewriter, PathObj) {
{{"a", CollationOp::Ascending, false /*isMultiKey*/},
{"b", CollationOp::Ascending, true /*isMultiKey*/}})}})}}},
std::make_unique<HintedCE>(std::move(hints)),
- std::make_unique<DefaultCosting>(),
- // The path-to-interval callback is important for this test.
- // We want to confirm PathObj becomes an interval.
- defaultConvertPathToInterval,
- ConstEval::constFold,
DebugInfo{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests},
- {} /*hints*/};
+ {} /*hints*/);
ABT optimized = rootNode;
phaseManager.optimize(optimized);
@@ -4139,7 +4086,7 @@ TEST(PhysRewriter, ArrayConstantIndex) {
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(filterNode2));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -4244,7 +4191,7 @@ TEST(PhysRewriter, ArrayConstantNoIndex) {
ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(filterNode2));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -4306,7 +4253,7 @@ TEST(PhysRewriter, ParallelScan) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -4368,7 +4315,7 @@ TEST(PhysRewriter, HashPartitioning) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -4452,12 +4399,11 @@ TEST(PhysRewriter, IndexPartitioning) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef(
{},
@@ -4471,9 +4417,6 @@ TEST(PhysRewriter, IndexPartitioning) {
{DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("b"))})}},
5 /*numberOfPartitions*/},
std::make_unique<HintedCE>(std::move(hints)),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -4573,12 +4516,11 @@ TEST(PhysRewriter, IndexPartitioning1) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef(
{},
@@ -4598,9 +4540,6 @@ TEST(PhysRewriter, IndexPartitioning1) {
{DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("c"))})}},
5 /*numberOfPartitions*/},
std::make_unique<HintedCE>(std::move(hints)),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -4653,7 +4592,7 @@ TEST(PhysRewriter, LocalGlobalAgg) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pc"}},
std::move(groupByNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -4731,7 +4670,7 @@ TEST(PhysRewriter, LocalGlobalAgg1) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -4786,7 +4725,7 @@ TEST(PhysRewriter, LocalLimitSkip) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(limitSkipNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -4920,7 +4859,7 @@ TEST(PhysRewriter, CollationLimit) {
ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}},
std::move(limitSkipNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -5061,7 +5000,7 @@ TEST(PhysRewriter, PartialIndex1) {
ASSERT_TRUE(conversionResult);
ASSERT_FALSE(conversionResult->_retainPredicate);
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -5142,7 +5081,7 @@ TEST(PhysRewriter, PartialIndex2) {
ASSERT_TRUE(conversionResult);
ASSERT_FALSE(conversionResult->_retainPredicate);
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -5222,7 +5161,7 @@ TEST(PhysRewriter, PartialIndexReject) {
ASSERT_TRUE(conversionResult);
ASSERT_FALSE(conversionResult->_retainPredicate);
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -5287,17 +5226,12 @@ TEST(PhysRewriter, RequireRID) {
ABT rootNode =
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManagerRequireRID(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- true /*requireRID*/,
{{{"c1", createScanDef({}, {})}}},
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -5342,17 +5276,12 @@ TEST(PhysRewriter, RequireRID1) {
std::move(filterNode));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManagerRequireRID(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- true /*requireRID*/,
{{{"c1", createScanDef({}, {})}}},
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -5410,7 +5339,7 @@ TEST(PhysRewriter, UnionRewrite) {
std::move(unionNode));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -5479,7 +5408,7 @@ TEST(PhysRewriter, JoinRewrite) {
std::move(joinNode));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -5554,7 +5483,7 @@ TEST(PhysRewriter, JoinRewrite1) {
std::move(joinNode));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -5609,7 +5538,7 @@ TEST(PhysRewriter, RootInterval) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -5663,7 +5592,7 @@ TEST(PhysRewriter, EqMemberSargable) {
{
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase},
prefixId,
{{{"c1",
@@ -5711,7 +5640,7 @@ TEST(PhysRewriter, EqMemberSargable) {
{
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -5811,7 +5740,7 @@ TEST(PhysRewriter, IndexSubfieldCovered) {
std::move(filterNode3));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
@@ -5895,12 +5824,11 @@ TEST(PhysRewriter, PerfOnlyPreds1) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode2));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef(
{},
@@ -5909,9 +5837,6 @@ TEST(PhysRewriter, PerfOnlyPreds1) {
{"a", CollationOp::Ascending, false /*isMultiKey*/}},
false /*isMultiKey*/)}})}}},
std::make_unique<HintedCE>(std::move(hints)),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
@@ -5988,12 +5913,11 @@ TEST(PhysRewriter, PerfOnlyPreds2) {
make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode2));
PrefixId prefixId;
- OptPhaseManager phaseManager(
+ auto phaseManager = makePhaseManager(
{OptPhase::MemoSubstitutionPhase,
OptPhase::MemoExplorationPhase,
OptPhase::MemoImplementationPhase},
prefixId,
- false /*requireRID*/,
{{{"c1",
createScanDef(
{},
@@ -6001,15 +5925,12 @@ TEST(PhysRewriter, PerfOnlyPreds2) {
{"index2",
makeIndexDefinition("b", CollationOp::Ascending, false /*isMultiKey*/)}})}}},
std::make_unique<HintedCE>(std::move(hints)),
- std::make_unique<DefaultCosting>(),
- {} /*pathToInterval*/,
- ConstEval::constFold,
{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests});
ABT optimized = rootNode;
phaseManager.getHints()._disableYieldingTolerantPlans = false;
phaseManager.optimize(optimized);
- ASSERT_BETWEEN(10, 15, phaseManager.getMemo().getStats()._physPlanExplorationCount);
+ ASSERT_BETWEEN(10, 17, phaseManager.getMemo().getStats()._physPlanExplorationCount);
// Demonstrate an intersection plan, with predicates repeated on the Seek side.
ASSERT_EXPLAIN_V2Compact(
@@ -6042,16 +5963,16 @@ TEST(PhysRewriter, PerfOnlyPreds2) {
"| Variable [rid_0]\n"
"MergeJoin []\n"
"| | | Condition\n"
- "| | | rid_0 = rid_3\n"
+ "| | | rid_0 = rid_5\n"
"| | Collation\n"
"| | Ascending\n"
"| Union []\n"
"| | BindBlock:\n"
- "| | [rid_3]\n"
+ "| | [rid_5]\n"
"| | Source []\n"
"| Evaluation []\n"
"| | BindBlock:\n"
- "| | [rid_3]\n"
+ "| | [rid_5]\n"
"| | Variable [rid_0]\n"
"| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index2, interval: {[Const "
"[2], Const [2]]}]\n"
diff --git a/src/mongo/db/query/optimizer/props.cpp b/src/mongo/db/query/optimizer/props.cpp
index 854882292be..c0079a78ed2 100644
--- a/src/mongo/db/query/optimizer/props.cpp
+++ b/src/mongo/db/query/optimizer/props.cpp
@@ -282,17 +282,20 @@ IndexingAvailability::IndexingAvailability(GroupIdType scanGroupId,
ProjectionName scanProjection,
std::string scanDefName,
const bool eqPredsOnly,
+ const bool hasProperInterval,
opt::unordered_set<std::string> satisfiedPartialIndexes)
: _scanGroupId(scanGroupId),
_scanProjection(std::move(scanProjection)),
_scanDefName(std::move(scanDefName)),
_eqPredsOnly(eqPredsOnly),
- _satisfiedPartialIndexes(std::move(satisfiedPartialIndexes)) {}
+ _satisfiedPartialIndexes(std::move(satisfiedPartialIndexes)),
+ _hasProperInterval(hasProperInterval) {}
bool IndexingAvailability::operator==(const IndexingAvailability& other) const {
return _scanGroupId == other._scanGroupId && _scanProjection == other._scanProjection &&
_scanDefName == other._scanDefName && _eqPredsOnly == other._eqPredsOnly &&
- _satisfiedPartialIndexes == other._satisfiedPartialIndexes;
+ _satisfiedPartialIndexes == other._satisfiedPartialIndexes &&
+ _hasProperInterval == other._hasProperInterval;
}
GroupIdType IndexingAvailability::getScanGroupId() const {
@@ -327,6 +330,14 @@ void IndexingAvailability::setEqPredsOnly(const bool value) {
_eqPredsOnly = value;
}
+bool IndexingAvailability::hasProperInterval() const {
+ return _hasProperInterval;
+}
+
+void IndexingAvailability::setHasProperInterval(const bool hasProperInterval) {
+ _hasProperInterval = hasProperInterval;
+}
+
CollectionAvailability::CollectionAvailability(opt::unordered_set<std::string> scanDefSet)
: _scanDefSet(std::move(scanDefSet)) {}
diff --git a/src/mongo/db/query/optimizer/props.h b/src/mongo/db/query/optimizer/props.h
index adfb0f82847..e7ac16f227d 100644
--- a/src/mongo/db/query/optimizer/props.h
+++ b/src/mongo/db/query/optimizer/props.h
@@ -399,6 +399,7 @@ public:
ProjectionName scanProjection,
std::string scanDefName,
bool eqPredsOnly,
+ bool hasProperInterval,
opt::unordered_set<std::string> satisfiedPartialIndexes);
bool operator==(const IndexingAvailability& other) const;
@@ -415,6 +416,9 @@ public:
bool getEqPredsOnly() const;
void setEqPredsOnly(bool value);
+ bool hasProperInterval() const;
+ void setHasProperInterval(bool hasProperInterval);
+
private:
GroupIdType _scanGroupId;
const ProjectionName _scanProjection;
@@ -427,6 +431,9 @@ private:
// Set of indexes with partial indexes whose partial filters are satisfied for the current
// group.
opt::unordered_set<std::string> _satisfiedPartialIndexes;
+
+ // True if there is at least one proper interval in a sargable node in this group.
+ bool _hasProperInterval;
};
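With the new field, IndexingAvailability takes hasProperInterval just before the satisfied-partial-index set and exposes a getter/setter. A minimal sketch with values mirroring the explain output earlier in this patch (the concrete arguments are illustrative only):

    IndexingAvailability avail(0 /*scanGroupId*/,
                               ProjectionName{"ptest"},
                               "test" /*scanDefName*/,
                               false /*eqPredsOnly*/,
                               true /*hasProperInterval*/,
                               {} /*satisfiedPartialIndexes*/);
    avail.setHasProperInterval(false);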
diff --git a/src/mongo/db/query/optimizer/utils/abt_hash.cpp b/src/mongo/db/query/optimizer/utils/abt_hash.cpp
index 90585939882..b9f8d5a2517 100644
--- a/src/mongo/db/query/optimizer/utils/abt_hash.cpp
+++ b/src/mongo/db/query/optimizer/utils/abt_hash.cpp
@@ -195,8 +195,6 @@ public:
size_t rightChildResult) {
// Specifically always including children.
return computeHashSeq<45>(std::hash<ProjectionName>()(node.getScanProjectionName()),
- std::hash<bool>()(node.hasLeftIntervals()),
- std::hash<bool>()(node.hasRightIntervals()),
leftChildResult,
rightChildResult);
}
diff --git a/src/mongo/db/query/optimizer/utils/unit_test_pipeline_utils.cpp b/src/mongo/db/query/optimizer/utils/unit_test_pipeline_utils.cpp
index 446d975c16d..8c807453232 100644
--- a/src/mongo/db/query/optimizer/utils/unit_test_pipeline_utils.cpp
+++ b/src/mongo/db/query/optimizer/utils/unit_test_pipeline_utils.cpp
@@ -32,9 +32,9 @@
#include "mongo/db/pipeline/abt/document_source_visitor.h"
#include "mongo/db/pipeline/expression_context_for_test.h"
#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
-#include "mongo/db/query/optimizer/cascades/cost_derivation.h"
#include "mongo/db/query/optimizer/explain.h"
#include "mongo/db/query/optimizer/rewrites/const_eval.h"
+#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
#include "mongo/unittest/temp_dir.h"
@@ -212,15 +212,7 @@ ABT optimizeABT(ABT abt,
bool phaseManagerDisableScan) {
PrefixId prefixId;
- OptPhaseManager phaseManager(phaseSet,
- prefixId,
- false,
- metadata,
- std::make_unique<HeuristicCE>(),
- std::make_unique<DefaultCosting>(),
- pathToInterval,
- ConstEval::constFold,
- DebugInfo::kDefaultForTests);
+ auto phaseManager = makePhaseManager(phaseSet, prefixId, metadata, DebugInfo::kDefaultForTests);
if (phaseManagerDisableScan) {
phaseManager.getHints()._disableScan = true;
}
diff --git a/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp b/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp
index 5ad7c6615c5..538a4b1f036 100644
--- a/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp
+++ b/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp
@@ -29,30 +29,172 @@
#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
+#include <fstream>
+
+#include "mongo/db/pipeline/abt/utils.h"
+#include "mongo/db/query/cost_model/cost_estimator.h"
+#include "mongo/db/query/cost_model/cost_model_manager.h"
+#include "mongo/db/query/optimizer/cascades/ce_heuristic.h"
#include "mongo/db/query/optimizer/explain.h"
#include "mongo/db/query/optimizer/metadata.h"
#include "mongo/db/query/optimizer/node.h"
-#include "mongo/unittest/unittest.h"
+#include "mongo/db/query/optimizer/rewrites/const_eval.h"
+#include "mongo/util/str_escape.h"
namespace mongo::optimizer {
static constexpr bool kDebugAsserts = false;
+// DO NOT COMMIT WITH "TRUE".
+static constexpr bool kAutoUpdateOnFailure = false;
+static constexpr const char* kTempFileSuffix = ".tmp.txt";
+
+// Map from file name to a list of updates. We keep track of how many lines are added or deleted at
+// a particular line of a source file.
+using LineDeltaVector = std::vector<std::pair<uint64_t, int64_t>>;
+std::map<std::string, LineDeltaVector> gLineDeltaMap;
+
+
void maybePrintABT(const ABT& abt) {
// Always print using the supported versions to make sure we don't crash.
const std::string strV1 = ExplainGenerator::explain(abt);
const std::string strV2 = ExplainGenerator::explainV2(abt);
const std::string strV2Compact = ExplainGenerator::explainV2Compact(abt);
- auto [tag, val] = ExplainGenerator::explainBSON(abt);
- sbe::value::ValueGuard vg(tag, val);
+ const std::string strBSON = ExplainGenerator::explainBSONStr(abt);
if constexpr (kDebugAsserts) {
std::cout << "V1: " << strV1 << "\n";
std::cout << "V2: " << strV2 << "\n";
std::cout << "V2Compact: " << strV2Compact << "\n";
- std::cout << "BSON: " << ExplainGenerator::printBSON(tag, val) << "\n";
+ std::cout << "BSON: " << strBSON << "\n";
+ }
+}
+
+static std::vector<std::string> formatStr(const std::string& str) {
+ std::vector<std::string> replacementLines;
+ std::istringstream lineInput(str);
+
+ // Account for maximum line length after linting. We need to indent, add quotes, etc.
+ static constexpr size_t kEscapedLength = 88;
+
+ std::string line;
+ while (std::getline(lineInput, line)) {
+ // Read the string line by line and format it to match the test file's expected format. We
+ // have an initial indentation, followed by quotes and the escaped string itself.
+
+ std::string escaped = mongo::str::escapeForJSON(line);
+ for (;;) {
+ // If the line is estimated to exceed the maximum length allowed by the linter, break it
+ // up and make sure to insert '\n' only at the end of the last segment.
+ const bool breakupLine = escaped.size() > kEscapedLength;
+
+ std::ostringstream os;
+ os << " \"" << escaped.substr(0, kEscapedLength);
+ if (!breakupLine) {
+ os << "\\n";
+ }
+ os << "\"\n";
+ replacementLines.push_back(os.str());
+
+ if (breakupLine) {
+ escaped = escaped.substr(kEscapedLength);
+ } else {
+ break;
+ }
+ }
+ }
+
+ if (!replacementLines.empty() && !replacementLines.back().empty()) {
+ // Account for the fact that we need an extra comma after the string constant in the macro.
+ auto& lastLine = replacementLines.back();
+ lastLine.insert(lastLine.size() - 1, ",");
+
+ if (replacementLines.size() == 1) {
+ // For single lines, add a 'nolint' comment to prevent the linter from inlining the
+ // single line with the macro itself.
+ lastLine.insert(lastLine.size() - 1, " // NOLINT (test auto-update)");
+ }
+ }
+
+ return replacementLines;
+}
+
+bool handleAutoUpdate(const std::string& expected,
+ const std::string& actual,
+ const std::string& fileName,
+ const size_t lineNumber) {
+ if (expected == actual) {
+ return true;
+ }
+ if constexpr (!kAutoUpdateOnFailure) {
+ std::cout << "Auto-updating is disabled.\n";
+ return false;
+ }
+
+ const auto expectedFormatted = formatStr(expected);
+ const auto actualFormatted = formatStr(actual);
+
+ std::cout << "Updating expected result in file '" << fileName << "', line: " << lineNumber
+ << ".\n";
+ std::cout << "Replacement:\n";
+ for (const auto& line : actualFormatted) {
+ std::cout << line;
+ }
+
+ // Compute the total number of lines added or removed before the current macro line.
+ auto& lineDeltas = gLineDeltaMap.emplace(fileName, LineDeltaVector{}).first->second;
+ int64_t totalDelta = 0;
+ for (const auto& [line, delta] : lineDeltas) {
+ if (line < lineNumber) {
+ totalDelta += delta;
+ }
}
+
+ const size_t replacementEndLine = lineNumber + totalDelta;
+ // Treat an empty string as needing one line. Adjust for line delta.
+ const size_t replacementStartLine =
+ replacementEndLine - (expectedFormatted.empty() ? 1 : expectedFormatted.size());
+
+ const std::string tempFileName = fileName + kTempFileSuffix;
+ std::string line;
+ size_t lineIndex = 0;
+
+ try {
+ std::ifstream in;
+ in.open(fileName);
+ std::ofstream out;
+ out.open(tempFileName);
+
+ // Generate a new test file, updated with the replacement string.
+ while (std::getline(in, line)) {
+ lineIndex++;
+
+ if (lineIndex < replacementStartLine || lineIndex >= replacementEndLine) {
+ out << line << "\n";
+ } else if (lineIndex == replacementStartLine) {
+ for (const auto& line1 : actualFormatted) {
+ out << line1;
+ }
+ }
+ }
+
+ out.close();
+ in.close();
+
+ std::rename(tempFileName.c_str(), fileName.c_str());
+ } catch (const std::exception& ex) {
+ // Print and re-throw exception.
+ std::cout << "Caught an exception while manipulating files: " << ex.what();
+ throw ex;
+ }
+
+ // Add the current delta.
+ lineDeltas.emplace_back(
+ lineNumber, static_cast<int64_t>(actualFormatted.size()) - expectedFormatted.size());
+
+ // Do not assert in order to allow multiple tests to be updated.
+ return true;
}
ABT makeIndexPath(FieldPathType fieldPath, bool isMultiKey) {
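The lineDeltas bookkeeping in handleAutoUpdate() above is what keeps later rewrites in the same file aligned once an earlier expected block changes size. A worked example with invented numbers:

    // Suppose the first failing macro sits at source line 100 and its 5-line expected
    // string is replaced by 8 lines: gLineDeltaMap records (100, +3) for that file.
    // A second failing macro whose __LINE__ is 200 then sums the deltas recorded for
    // lines < 200, so its block is rewritten ending at line 200 + 3 = 203.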
@@ -96,4 +238,60 @@ IndexDefinition makeCompositeIndexDefinition(std::vector<TestIndexField> indexFi
return IndexDefinition{std::move(idxCollSpec), isMultiKey};
}
+std::unique_ptr<CostingInterface> makeCosting() {
+ return std::make_unique<cost_model::CostEstimator>(
+ cost_model::CostModelManager().getDefaultCoefficients());
+}
+
+OptPhaseManager makePhaseManager(OptPhaseManager::PhaseSet phaseSet,
+ PrefixId& prefixId,
+ Metadata metadata,
+ DebugInfo debugInfo,
+ QueryHints queryHints) {
+ return OptPhaseManager{std::move(phaseSet),
+ prefixId,
+ false /*requireRID*/,
+ std::move(metadata),
+ std::make_unique<HeuristicCE>(),
+ makeCosting(),
+ defaultConvertPathToInterval,
+ ConstEval::constFold,
+ std::move(debugInfo),
+ std::move(queryHints)};
+}
+
+OptPhaseManager makePhaseManager(OptPhaseManager::PhaseSet phaseSet,
+ PrefixId& prefixId,
+ Metadata metadata,
+ std::unique_ptr<CEInterface> ceDerivation,
+ DebugInfo debugInfo,
+ QueryHints queryHints) {
+ return OptPhaseManager{std::move(phaseSet),
+ prefixId,
+ false /*requireRID*/,
+ std::move(metadata),
+ std::move(ceDerivation),
+ makeCosting(),
+ defaultConvertPathToInterval,
+ ConstEval::constFold,
+ std::move(debugInfo),
+ std::move(queryHints)};
+}
+
+OptPhaseManager makePhaseManagerRequireRID(OptPhaseManager::PhaseSet phaseSet,
+ PrefixId& prefixId,
+ Metadata metadata,
+ DebugInfo debugInfo,
+ QueryHints queryHints) {
+ return OptPhaseManager{std::move(phaseSet),
+ prefixId,
+ true /*requireRID*/,
+ std::move(metadata),
+ std::make_unique<HeuristicCE>(),
+ makeCosting(),
+ defaultConvertPathToInterval,
+ ConstEval::constFold,
+ std::move(debugInfo),
+ std::move(queryHints)};
+}
} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/unit_test_utils.h b/src/mongo/db/query/optimizer/utils/unit_test_utils.h
index b0ad4704b84..e358dc13495 100644
--- a/src/mongo/db/query/optimizer/utils/unit_test_utils.h
+++ b/src/mongo/db/query/optimizer/utils/unit_test_utils.h
@@ -31,6 +31,7 @@
#include "mongo/db/bson/dotted_path_support.h"
#include "mongo/db/query/optimizer/defs.h"
+#include "mongo/db/query/optimizer/opt_phase_manager.h"
#include "mongo/db/query/optimizer/utils/utils.h"
@@ -38,6 +39,11 @@ namespace mongo::optimizer {
void maybePrintABT(const ABT& abt);
+bool handleAutoUpdate(const std::string& expected,
+ const std::string& actual,
+ const std::string& fileName,
+ size_t lineNumber);
+
#define ASSERT_EXPLAIN(expected, abt) \
maybePrintABT(abt); \
ASSERT_EQ(expected, ExplainGenerator::explain(abt))
@@ -46,13 +52,36 @@ void maybePrintABT(const ABT& abt);
maybePrintABT(abt); \
ASSERT_EQ(expected, ExplainGenerator::explainV2(abt))
+/**
+ * Auto-updates the expected result in the source file if the assertion fails.
+ * The expected result must be a multi-line string in the following form:
+ *
+ * ASSERT_EXPLAIN_V2_AUTO( // NOLINT
+ * "BinaryOp [Add]\n"
+ * "| Const [2]\n"
+ * "Const [1]\n",
+ * tree);
+ *
+ * Limitations:
+ * 1. There should not be any comments or other formatting inside the multi-line string
+ * constant other than 'NOLINT'. If we have a single-line constant, the auto-updating will
+ * generate a 'NOLINT' at the end of the line.
+ * 2. The expression which we are explaining ('tree' in the example above) must fit on a single
+ * line. The macro should be indented by 4 spaces.
+ *
+ * TODO: SERVER-71004: Extend the usability of the auto-update macro.
+ */
+#define ASSERT_EXPLAIN_V2_AUTO(expected, abt) \
+ maybePrintABT(abt); \
+ ASSERT(handleAutoUpdate(expected, ExplainGenerator::explainV2(abt), __FILE__, __LINE__))
+
#define ASSERT_EXPLAIN_V2Compact(expected, abt) \
maybePrintABT(abt); \
ASSERT_EQ(expected, ExplainGenerator::explainV2Compact(abt))
#define ASSERT_EXPLAIN_BSON(expected, abt) \
maybePrintABT(abt); \
- ASSERT_EQ(expected, ExplainGenerator::explainBSON(abt))
+ ASSERT_EQ(expected, ExplainGenerator::explainBSONStr(abt))
#define ASSERT_EXPLAIN_PROPS_V2(expected, phaseManager) \
ASSERT_EQ(expected, \
@@ -93,4 +122,37 @@ IndexDefinition makeIndexDefinition(FieldNameType fieldName,
IndexDefinition makeCompositeIndexDefinition(std::vector<TestIndexField> indexFields,
bool isMultiKey = true);
+/**
+ * A convenience factory function to create costing.
+ */
+std::unique_ptr<CostingInterface> makeCosting();
+
+/**
+ * A convenience factory function to create OptPhaseManager for unit tests.
+ */
+OptPhaseManager makePhaseManager(OptPhaseManager::PhaseSet phaseSet,
+ PrefixId& prefixId,
+ Metadata metadata,
+ DebugInfo debugInfo,
+ QueryHints queryHints = {});
+
+/**
+ * A convenience factory function to create OptPhaseManager for unit tests with CE hints.
+ */
+OptPhaseManager makePhaseManager(OptPhaseManager::PhaseSet phaseSet,
+ PrefixId& prefixId,
+ Metadata metadata,
+ std::unique_ptr<CEInterface> ceDerivation,
+ DebugInfo debugInfo,
+ QueryHints queryHints = {});
+
+/**
+ * A convenience factory function to create OptPhaseManager for unit tests which requires RID.
+ */
+OptPhaseManager makePhaseManagerRequireRID(OptPhaseManager::PhaseSet phaseSet,
+ PrefixId& prefixId,
+ Metadata metadata,
+ DebugInfo debugInfo,
+ QueryHints queryHints = {});
+
} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/utils.cpp b/src/mongo/db/query/optimizer/utils/utils.cpp
index 6c2dac94c96..a676e28179b 100644
--- a/src/mongo/db/query/optimizer/utils/utils.cpp
+++ b/src/mongo/db/query/optimizer/utils/utils.cpp
@@ -102,6 +102,7 @@ properties::LogicalProps createInitialScanProps(const ProjectionName& projection
projectionName,
scanDefName,
true /*eqPredsOnly*/,
+ false /*hasProperInterval*/,
{} /*satisfiedPartialIndexes*/),
properties::CollectionAvailability({scanDefName}),
properties::DistributionAvailability(std::move(distributions)));
@@ -1928,4 +1929,14 @@ bool pathEndsInTraverse(const optimizer::ABT& path) {
return optimizer::algebra::transport<false>(path, t);
}
+bool hasProperIntervals(const PartialSchemaRequirements& reqMap) {
+ // Compute if this node has any proper (not fully open) intervals.
+ for (const auto& [key, req] : reqMap) {
+ if (!isIntervalReqFullyOpenDNF(req.getIntervals())) {
+ return true;
+ }
+ }
+ return false;
+}
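
hasProperIntervals() returns true as soon as any requirement constrains its path to something narrower than the fully open interval. A toy analogue using plain standard-library types (not the optimizer's PartialSchemaRequirements) to illustrate the idea:

#include <cmath>
#include <limits>
#include <map>
#include <string>

// Illustrative interval: defaults to fully open, i.e. (-inf, +inf).
struct ToyInterval {
    double low = -std::numeric_limits<double>::infinity();
    double high = std::numeric_limits<double>::infinity();
};

bool isFullyOpen(const ToyInterval& i) {
    return std::isinf(i.low) && i.low < 0 && std::isinf(i.high) && i.high > 0;
}

// True when at least one requirement actually constrains its path.
bool hasProperIntervalsToy(const std::map<std::string, ToyInterval>& reqMap) {
    for (const auto& entry : reqMap) {
        if (!isFullyOpen(entry.second)) {
            return true;
        }
    }
    return false;
}
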
+
} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/utils/utils.h b/src/mongo/db/query/optimizer/utils/utils.h
index e4a2b2367a6..48589995a14 100644
--- a/src/mongo/db/query/optimizer/utils/utils.h
+++ b/src/mongo/db/query/optimizer/utils/utils.h
@@ -347,4 +347,5 @@ ABT lowerIntervals(PrefixId& prefixId,
*/
bool pathEndsInTraverse(const optimizer::ABT& path);
+bool hasProperIntervals(const PartialSchemaRequirements& reqMap);
} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/parsed_distinct.cpp b/src/mongo/db/query/parsed_distinct.cpp
index 21781211771..03f2e3bf9e9 100644
--- a/src/mongo/db/query/parsed_distinct.cpp
+++ b/src/mongo/db/query/parsed_distinct.cpp
@@ -322,7 +322,8 @@ StatusWith<ParsedDistinct> ParsedDistinct::parse(OperationContext* opCtx,
return ParsedDistinct(std::move(cq.getValue()),
parsedDistinct.getKey().toString(),
- parsedDistinct.getMirrored().value_or(false));
+ parsedDistinct.getMirrored().value_or(false),
+ parsedDistinct.getSampleId());
}
} // namespace mongo
diff --git a/src/mongo/db/query/parsed_distinct.h b/src/mongo/db/query/parsed_distinct.h
index b5cabef4680..49809d98440 100644
--- a/src/mongo/db/query/parsed_distinct.h
+++ b/src/mongo/db/query/parsed_distinct.h
@@ -55,8 +55,12 @@ public:
ParsedDistinct(std::unique_ptr<CanonicalQuery> query,
const std::string key,
- const bool mirrored = false)
- : _query(std::move(query)), _key(std::move(key)), _mirrored(std::move(mirrored)) {}
+ const bool mirrored = false,
+ const boost::optional<UUID> sampleId = boost::none)
+ : _query(std::move(query)),
+ _key(std::move(key)),
+ _mirrored(std::move(mirrored)),
+ _sampleId(std::move(sampleId)) {}
const CanonicalQuery* getQuery() const {
return _query.get();
@@ -74,6 +78,10 @@ public:
return _key;
}
+ boost::optional<UUID> getSampleId() const {
+ return _sampleId;
+ }
+
bool isMirrored() const {
return _mirrored;
}
@@ -102,6 +110,9 @@ private:
// Indicates that this was a mirrored operation.
bool _mirrored = false;
+
+ // The unique sample id for this operation if it has been chosen for sampling.
+ boost::optional<UUID> _sampleId;
};
} // namespace mongo
diff --git a/src/mongo/db/query/plan_cache_key_factory.cpp b/src/mongo/db/query/plan_cache_key_factory.cpp
index 88d8f8a3665..c3a01418819 100644
--- a/src/mongo/db/query/plan_cache_key_factory.cpp
+++ b/src/mongo/db/query/plan_cache_key_factory.cpp
@@ -106,7 +106,7 @@ boost::optional<Timestamp> computeNewestVisibleIndexTimestamp(OperationContext*
Timestamp currentNewestVisible = Timestamp::min();
auto ii = collection->getIndexCatalog()->getIndexIterator(
- opCtx, IndexCatalog::InclusionPolicy::kReady | IndexCatalog::InclusionPolicy::kUnfinished);
+ opCtx, IndexCatalog::InclusionPolicy::kReady);
while (ii->more()) {
const IndexCatalogEntry* ice = ii->next();
auto minVisibleSnapshot = ice->getMinimumVisibleSnapshot();
diff --git a/src/mongo/db/query/query_knobs.idl b/src/mongo/db/query/query_knobs.idl
index fb8d93b9027..869ecbd1695 100644
--- a/src/mongo/db/query/query_knobs.idl
+++ b/src/mongo/db/query/query_knobs.idl
@@ -963,6 +963,15 @@ server_parameters:
validator:
callback: telemetry_util::validateTelemetryCacheSize
+ internalQueryDisableExclusionProjectionFastPath:
+      description: "If true, then ExclusionProjectionExecutor won't use the fast path implementation. This
+ is needed to pass generational fuzzers that are sensitive to field order and other corner cases
+ when switching from Document to BSONObj."
+ set_at: [ startup ]
+ cpp_varname: "internalQueryDisableExclusionProjectionFastPath"
+ cpp_vartype: bool
+ default: false
+ test_only: true
# Note for adding additional query knobs:
#
diff --git a/src/mongo/db/query/query_planner.cpp b/src/mongo/db/query/query_planner.cpp
index 6902650a1ef..a3a63e66449 100644
--- a/src/mongo/db/query/query_planner.cpp
+++ b/src/mongo/db/query/query_planner.cpp
@@ -399,6 +399,11 @@ StatusWith<std::unique_ptr<QuerySolution>> tryToBuildColumnScan(
}
return Status{ErrorCodes::Error{6298502}, "columnstore index is not applicable for this query"};
}
+
+bool collscanIsBounded(const CollectionScanNode* collscan) {
+ return collscan->minRecord || collscan->maxRecord;
+}
+
} // namespace
using std::numeric_limits;
@@ -615,13 +620,23 @@ static BSONObj finishMaxObj(const IndexEntry& indexEntry,
}
}
+std::pair<std::unique_ptr<QuerySolution>, const CollectionScanNode*> buildCollscanSolnWithNode(
+ const CanonicalQuery& query,
+ bool tailable,
+ const QueryPlannerParams& params,
+ int direction = 1) {
+ std::unique_ptr<QuerySolutionNode> solnRoot(
+ QueryPlannerAccess::makeCollectionScan(query, tailable, params, direction));
+ const auto* collscanNode = checked_cast<const CollectionScanNode*>(solnRoot.get());
+ return std::make_pair(
+ QueryPlannerAnalysis::analyzeDataAccess(query, params, std::move(solnRoot)), collscanNode);
+}
+
std::unique_ptr<QuerySolution> buildCollscanSoln(const CanonicalQuery& query,
bool tailable,
const QueryPlannerParams& params,
int direction = 1) {
- std::unique_ptr<QuerySolutionNode> solnRoot(
- QueryPlannerAccess::makeCollectionScan(query, tailable, params, direction));
- return QueryPlannerAnalysis::analyzeDataAccess(query, params, std::move(solnRoot));
+ return buildCollscanSolnWithNode(query, tailable, params, direction).first;
}
std::unique_ptr<QuerySolution> buildWholeIXSoln(
@@ -1518,6 +1533,8 @@ StatusWith<std::vector<std::unique_ptr<QuerySolution>>> QueryPlanner::plan(
"No indexed plans available, and running with 'notablescan'");
}
+ bool clusteredCollection = params.clusteredInfo.has_value();
+
// geoNear and text queries *require* an index.
// Also, if a hint is specified it indicates that we MUST use it.
bool possibleToCollscan =
@@ -1527,21 +1544,23 @@ StatusWith<std::vector<std::unique_ptr<QuerySolution>>> QueryPlanner::plan(
return Status(ErrorCodes::NoQueryExecutionPlans, "No query solutions");
}
- if (possibleToCollscan && (collscanRequested || collScanRequired)) {
- auto collscan = buildCollscanSoln(query, isTailable, params);
- if (!collscan && collScanRequired) {
+ if (possibleToCollscan && (collscanRequested || collScanRequired || clusteredCollection)) {
+ auto [collscanSoln, collscanNode] = buildCollscanSolnWithNode(query, isTailable, params);
+ if (!collscanSoln && collScanRequired) {
return Status(ErrorCodes::NoQueryExecutionPlans,
"Failed to build collection scan soln");
}
- if (collscan) {
+
+ if (collscanSoln &&
+ (collscanRequested || collScanRequired || collscanIsBounded(collscanNode))) {
LOGV2_DEBUG(20984,
5,
"Planner: outputting a collection scan",
- "collectionScan"_attr = redact(collscan->toString()));
+ "collectionScan"_attr = redact(collscanSoln->toString()));
SolutionCacheData* scd = new SolutionCacheData();
scd->solnType = SolutionCacheData::COLLSCAN_SOLN;
- collscan->cacheData.reset(scd);
- out.push_back(std::move(collscan));
+ collscanSoln->cacheData.reset(scd);
+ out.push_back(std::move(collscanSoln));
}
}
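
The planner change above now also considers a collection scan for clustered collections, but only outputs it when the generated scan is actually bounded by a min or max record (or when a collection scan was requested or required anyway). A condensed sketch of that decision, assuming a collection scan is possible and the candidate solution built successfully:

#include <iostream>

// Restatement of the two nested checks above as a single predicate.
bool shouldOutputCollscan(bool requested, bool required, bool clustered, bool bounded) {
    const bool consider = requested || required || clustered;  // outer check
    return consider && (requested || required || bounded);     // inner check
}

int main() {
    // Clustered collection with a bounded scan: now emitted.
    std::cout << shouldOutputCollscan(false, false, true, true) << "\n";   // 1
    // Clustered but unbounded, neither requested nor required: still skipped.
    std::cout << shouldOutputCollscan(false, false, true, false) << "\n";  // 0
    return 0;
}
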
diff --git a/src/mongo/db/query/query_solution.h b/src/mongo/db/query/query_solution.h
index a24cff08f2f..d33794340fe 100644
--- a/src/mongo/db/query/query_solution.h
+++ b/src/mongo/db/query/query_solution.h
@@ -1396,15 +1396,15 @@ struct GroupNode : public QuerySolutionNode {
shouldProduceBson(shouldProduceBson) {
// Use the DepsTracker to extract the fields that the 'groupByExpression' and accumulator
// expressions depend on.
- for (auto& groupByExprField : expression::getDependencies(groupByExpression.get()).fields) {
- requiredFields.insert(groupByExprField);
- }
+ DepsTracker deps;
+ expression::addDependencies(groupByExpression.get(), &deps);
for (auto&& acc : accumulators) {
- auto argExpr = acc.expr.argument;
- for (auto& argExprField : expression::getDependencies(argExpr.get()).fields) {
- requiredFields.insert(argExprField);
- }
+ expression::addDependencies(acc.expr.argument.get(), &deps);
}
+
+ requiredFields = deps.fields;
+ needWholeDocument = deps.needWholeDocument;
+ needsAnyMetadata = deps.getNeedsAnyMetadata();
}
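
With the shared DepsTracker, the constructor now records not only the union of referenced fields but also whether the whole document or any metadata is needed, state the earlier per-expression field loops did not capture. A toy sketch of that accumulation using plain standard-library types (names here are illustrative, not DepsTracker's API):

#include <set>
#include <string>

// Illustrative stand-in for the tracker: referenced fields plus a whole-document flag.
struct ToyDeps {
    std::set<std::string> fields;
    bool needWholeDocument = false;
};

// Merge one expression's dependencies into the shared tracker, as the constructor
// above does for the group key and each accumulator argument.
void addToyDependencies(const ToyDeps& exprDeps, ToyDeps* out) {
    out->fields.insert(exprDeps.fields.begin(), exprDeps.fields.end());
    out->needWholeDocument = out->needWholeDocument || exprDeps.needWholeDocument;
}
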
StageType getType() const override {
@@ -1438,7 +1438,9 @@ struct GroupNode : public QuerySolutionNode {
// Carries the fields this GroupNode depends on. Namely, 'requiredFields' contains the union of
// the fields in the 'groupByExpressions' and the fields in the input Expressions of the
// 'accumulators'.
- StringSet requiredFields;
+ OrderedPathSet requiredFields;
+ bool needWholeDocument = false;
+ bool needsAnyMetadata = false;
// If set to true, generated SBE plan will produce result as BSON object. If false,
// 'sbe::Object' is produced instead.
diff --git a/src/mongo/db/query/sbe_shard_filter_test.cpp b/src/mongo/db/query/sbe_shard_filter_test.cpp
index 82585905f6e..a5d49bf4849 100644
--- a/src/mongo/db/query/sbe_shard_filter_test.cpp
+++ b/src/mongo/db/query/sbe_shard_filter_test.cpp
@@ -234,7 +234,7 @@ TEST_F(SbeShardFilterTest, MissingFieldsAtBottomDottedPathFilledCorrectly) {
TEST_F(SbeShardFilterTest, CoveredShardFilterPlan) {
auto indexKeyPattern = BSON("a" << 1 << "b" << 1 << "c" << 1 << "d" << 1);
- auto projection = BSON("a" << 1 << "c" << 1);
+ auto projection = BSON("a" << 1 << "c" << 1 << "_id" << 0);
auto mockedIndexKeys =
std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 2 << "b" << 2 << "c" << 2 << "d" << 2)),
BSON_ARRAY(BSON("a" << 3 << "b" << 3 << "c" << 3 << "d" << 3))};
diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp
index ce8344a0d36..a196821e3dc 100644
--- a/src/mongo/db/query/sbe_stage_builder.cpp
+++ b/src/mongo/db/query/sbe_stage_builder.cpp
@@ -84,129 +84,9 @@
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery
-
namespace mongo::stage_builder {
namespace {
/**
- * For covered projections, each of the projection field paths represent respective index key. To
- * rehydrate index keys into the result object, we first need to convert projection AST into
- * 'IndexKeyPatternTreeNode' structure. Context structure and visitors below are used for this
- * purpose.
- */
-struct IndexKeysBuilderContext {
- // Contains resulting tree of index keys converted from projection AST.
- IndexKeyPatternTreeNode root;
-
- // Full field path of the currently visited projection node.
- std::vector<StringData> currentFieldPath;
-
- // Each projection node has a vector of field names. This stack contains indexes of the
- // currently visited field names for each of the projection nodes.
- std::vector<size_t> currentFieldIndex;
-};
-
-/**
- * Covered projections are always inclusion-only, so we ban all other operators.
- */
-class IndexKeysBuilder : public projection_ast::ProjectionASTConstVisitor {
-public:
- using projection_ast::ProjectionASTConstVisitor::visit;
-
- IndexKeysBuilder(IndexKeysBuilderContext* context) : _context{context} {}
-
- void visit(const projection_ast::ProjectionPositionalASTNode* node) final {
- tasserted(5474501, "Positional projection is not allowed in covered projection");
- }
-
- void visit(const projection_ast::ProjectionSliceASTNode* node) final {
- tasserted(5474502, "$slice is not allowed in covered projection");
- }
-
- void visit(const projection_ast::ProjectionElemMatchASTNode* node) final {
- tasserted(5474503, "$elemMatch is not allowed in covered projection");
- }
-
- void visit(const projection_ast::ExpressionASTNode* node) final {
- tasserted(5474504, "Expressions are not allowed in covered projection");
- }
-
- void visit(const projection_ast::MatchExpressionASTNode* node) final {
- tasserted(5474505,
- "$elemMatch and positional projection are not allowed in covered projection");
- }
-
- void visit(const projection_ast::BooleanConstantASTNode* node) override {}
-
-protected:
- IndexKeysBuilderContext* _context;
-};
-
-class IndexKeysPreBuilder final : public IndexKeysBuilder {
-public:
- using IndexKeysBuilder::IndexKeysBuilder;
- using IndexKeysBuilder::visit;
-
- void visit(const projection_ast::ProjectionPathASTNode* node) final {
- _context->currentFieldIndex.push_back(0);
- _context->currentFieldPath.emplace_back(node->fieldNames().front());
- }
-};
-
-class IndexKeysInBuilder final : public IndexKeysBuilder {
-public:
- using IndexKeysBuilder::IndexKeysBuilder;
- using IndexKeysBuilder::visit;
-
- void visit(const projection_ast::ProjectionPathASTNode* node) final {
- auto& currentIndex = _context->currentFieldIndex.back();
- currentIndex++;
- _context->currentFieldPath.back() = node->fieldNames()[currentIndex];
- }
-};
-
-class IndexKeysPostBuilder final : public IndexKeysBuilder {
-public:
- using IndexKeysBuilder::IndexKeysBuilder;
- using IndexKeysBuilder::visit;
-
- void visit(const projection_ast::ProjectionPathASTNode* node) final {
- _context->currentFieldIndex.pop_back();
- _context->currentFieldPath.pop_back();
- }
-
- void visit(const projection_ast::BooleanConstantASTNode* constantNode) final {
- if (!constantNode->value()) {
- // Even though only inclusion is allowed in covered projection, there still can be
- // {_id: 0} component. We do not need to generate any nodes for it.
- return;
- }
-
- // Insert current field path into the index keys tree if it does not exist yet.
- auto* node = &_context->root;
- for (const auto& part : _context->currentFieldPath) {
- if (auto it = node->children.find(part); it != node->children.end()) {
- node = it->second.get();
- } else {
- node = node->emplace(part);
- }
- }
- }
-};
-
-sbe::value::SlotVector getSlotsToForward(const PlanStageReqs& reqs, const PlanStageSlots& outputs) {
- std::vector<std::pair<StringData, sbe::value::SlotId>> pairs;
- outputs.forEachSlot(
- reqs, [&](auto&& slot, const StringData& name) { pairs.emplace_back(name, slot); });
- std::sort(pairs.begin(), pairs.end());
-
- auto outputSlots = sbe::makeSV();
- for (auto&& p : pairs) {
- outputSlots.emplace_back(p.second);
- }
- return outputSlots;
-}
-
-/**
* Generates an EOF plan. Note that even though this plan will return nothing, it will still define
* the slots specified by 'reqs'.
*/
@@ -271,6 +151,26 @@ std::unique_ptr<sbe::RuntimeEnvironment> makeRuntimeEnvironment(
return env;
}
+sbe::value::SlotVector getSlotsToForward(const PlanStageReqs& reqs,
+ const PlanStageSlots& outputs,
+ const sbe::value::SlotVector& exclude) {
+ auto excludeSet = sbe::value::SlotSet{exclude.begin(), exclude.end()};
+
+ std::vector<std::pair<PlanStageSlots::Name, sbe::value::SlotId>> pairs;
+ outputs.forEachSlot(reqs, [&](auto&& slot, const PlanStageSlots::Name& name) {
+ if (!excludeSet.count(slot)) {
+ pairs.emplace_back(name, slot);
+ }
+ });
+ std::sort(pairs.begin(), pairs.end());
+
+ auto outputSlots = sbe::makeSV();
+ for (auto&& p : pairs) {
+ outputSlots.emplace_back(p.second);
+ }
+ return outputSlots;
+}
+
void prepareSlotBasedExecutableTree(OperationContext* opCtx,
sbe::PlanStage* root,
PlanStageData* data,
@@ -524,22 +424,22 @@ std::unique_ptr<sbe::PlanStage> SlotBasedStageBuilder::build(const QuerySolution
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildCollScan(
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
- invariant(!reqs.getIndexKeyBitset());
+ tassert(6023400, "buildCollScan() does not support kKey", !reqs.hasKeys());
+ auto fields = reqs.getFields();
auto csn = static_cast<const CollectionScanNode*>(root);
auto [stage, outputs] = generateCollScan(_state,
getCurrentCollection(reqs),
csn,
+ fields,
_yieldPolicy,
reqs.getIsTailableCollScanResumeBranch());
if (reqs.has(kReturnKey)) {
// Assign the 'returnKeySlot' to be the empty object.
outputs.set(kReturnKey, _slotIdGenerator.generate());
- stage = sbe::makeProjectStage(std::move(stage),
- root->nodeId(),
- outputs.get(kReturnKey),
- sbe::makeE<sbe::EFunction>("newObj", sbe::makeEs()));
+ stage = sbe::makeProjectStage(
+ std::move(stage), root->nodeId(), outputs.get(kReturnKey), makeFunction("newObj"_sd));
}
// Don't advertize the RecordId output if none of our ancestors are going to use it.
if (!reqs.has(kRecordId)) {
@@ -552,12 +452,20 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildVirtualScan(
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
using namespace std::literals;
+
auto vsn = static_cast<const VirtualScanNode*>(root);
- // The caller should only have requested components of the index key if the virtual scan is
- // mocking an index scan.
- if (vsn->scanType == VirtualScanNode::ScanType::kCollScan) {
- invariant(!reqs.getIndexKeyBitset());
- }
+ auto reqKeys = reqs.getKeys();
+
+ // The caller should only request kKey slots if the virtual scan is mocking an index scan.
+ tassert(6023401,
+ "buildVirtualScan() does not support kKey when 'scanType' is not ixscan",
+ vsn->scanType == VirtualScanNode::ScanType::kIxscan || reqKeys.empty());
+
+ tassert(6023423,
+ "buildVirtualScan() does not support dotted paths for kKey slots",
+ std::all_of(reqKeys.begin(), reqKeys.end(), [](auto&& s) {
+ return s.find('.') == std::string::npos;
+ }));
auto [inputTag, inputVal] = sbe::value::makeNewArray();
sbe::value::ValueGuard inputGuard{inputTag, inputVal};
@@ -572,7 +480,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
}
inputGuard.reset();
- auto [scanSlots, stage] =
+ auto [scanSlots, scanStage] =
generateVirtualScanMulti(&_slotIdGenerator, vsn->hasRecordId ? 2 : 1, inputTag, inputVal);
sbe::value::SlotId resultSlot;
@@ -586,41 +494,27 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
PlanStageSlots outputs;
- if (reqs.has(kResult)) {
+ if (reqs.has(kResult) || reqs.hasFields()) {
outputs.set(kResult, resultSlot);
- } else if (reqs.getIndexKeyBitset()) {
- // The caller wanted individual slots for certain components of a mock index scan. Use a
- // project stage to produce those slots. Since the test will represent index keys as BSON
- // objects, we use 'getField' expressions to extract the necessary fields.
- invariant(!vsn->indexKeyPattern.isEmpty());
-
- sbe::value::SlotVector indexKeySlots;
- sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projections;
-
- size_t indexKeyPos = 0;
- for (auto&& field : vsn->indexKeyPattern) {
- if (reqs.getIndexKeyBitset()->test(indexKeyPos)) {
- indexKeySlots.push_back(_slotIdGenerator.generate());
- projections.emplace(indexKeySlots.back(),
- makeFunction("getField"_sd,
- sbe::makeE<sbe::EVariable>(resultSlot),
- makeConstant(field.fieldName())));
- }
- ++indexKeyPos;
- }
-
- stage =
- sbe::makeS<sbe::ProjectStage>(std::move(stage), std::move(projections), root->nodeId());
-
- outputs.setIndexKeySlots(indexKeySlots);
}
-
if (reqs.has(kRecordId)) {
invariant(vsn->hasRecordId);
invariant(scanSlots.size() == 2);
outputs.set(kRecordId, scanSlots[0]);
}
+ auto stage = std::move(scanStage);
+
+ // The caller wants individual slots for certain components of the mock index scan. Retrieve
+ // the values for these paths and project them to slots.
+ auto [projectStage, slots] = projectTopLevelFields(
+ std::move(stage), reqKeys, resultSlot, root->nodeId(), &_slotIdGenerator);
+ stage = std::move(projectStage);
+
+ for (size_t i = 0; i < reqKeys.size(); ++i) {
+ outputs.set(std::make_pair(PlanStageSlots::kKey, std::move(reqKeys[i])), slots[i]);
+ }
+
return {std::move(stage), std::move(outputs)};
}
@@ -629,17 +523,35 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
auto ixn = static_cast<const IndexScanNode*>(root);
invariant(reqs.has(kReturnKey) || !ixn->addKeyMetadata);
- sbe::IndexKeysInclusionSet indexKeyBitset;
+ auto reqKeys = reqs.getKeys();
+ auto reqKeysSet = StringDataSet{reqKeys.begin(), reqKeys.end()};
- if (reqs.has(PlanStageSlots::kReturnKey) || reqs.has(PlanStageSlots::kResult)) {
- // If either 'reqs.result' or 'reqs.returnKey' is true, we need to get all parts of the
- // index key (regardless of what was requested by 'reqs.indexKeyBitset') so that we can
- // create the inflated index key (keyExpr).
- for (int i = 0; i < ixn->index.keyPattern.nFields(); ++i) {
- indexKeyBitset.set(i);
+ std::vector<StringData> keys;
+ sbe::IndexKeysInclusionSet keysBitset;
+ StringDataSet indexKeyPatternSet;
+ size_t i = 0;
+ for (const auto& elt : ixn->index.keyPattern) {
+ StringData name = elt.fieldNameStringData();
+ indexKeyPatternSet.emplace(name);
+ if (reqKeysSet.count(name)) {
+ keysBitset.set(i);
+ keys.emplace_back(name);
}
- } else if (reqs.getIndexKeyBitset()) {
- indexKeyBitset = *reqs.getIndexKeyBitset();
+ ++i;
+ }
+
+ auto additionalKeys =
+ filterVector(reqKeys, [&](const std::string& s) { return !indexKeyPatternSet.count(s); });
+
+ sbe::IndexKeysInclusionSet indexKeyBitset;
+ if (reqs.has(kReturnKey) || reqs.has(kResult) || reqs.hasFields()) {
+ // If either 'reqs.result' or 'reqs.returnKey' or 'reqs.hasFields()' is true, we need to
+ // get all parts of the index key so that we can create the inflated index key.
+ for (int j = 0; j < ixn->index.keyPattern.nFields(); ++j) {
+ indexKeyBitset.set(j);
+ }
+ } else {
+ indexKeyBitset = keysBitset;
}
// If the slots necessary for performing an index consistency check were not requested in
@@ -652,28 +564,32 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
const auto generateIndexScanFunc =
ixn->iets.empty() ? generateIndexScan : generateIndexScanWithDynamicBounds;
- auto&& [stage, outputs] = generateIndexScanFunc(_state,
- getCurrentCollection(reqs),
- ixn,
- indexKeyBitset,
- _yieldPolicy,
- iamMap,
- reqs.has(kIndexKeyPattern));
+ auto&& [scanStage, scanOutputs] = generateIndexScanFunc(_state,
+ getCurrentCollection(reqs),
+ ixn,
+ indexKeyBitset,
+ _yieldPolicy,
+ iamMap,
+ reqs.has(kIndexKeyPattern));
+
+ auto stage = std::move(scanStage);
+ auto outputs = std::move(scanOutputs);
// Remove the RecordId from the output if we were not requested to produce it.
if (!reqs.has(PlanStageSlots::kRecordId) && outputs.has(kRecordId)) {
outputs.clear(kRecordId);
}
- if (reqs.has(PlanStageSlots::kReturnKey)) {
- sbe::EExpression::Vector mkObjArgs;
- size_t i = 0;
+ if (reqs.has(PlanStageSlots::kReturnKey)) {
+ sbe::EExpression::Vector args;
for (auto&& elem : ixn->index.keyPattern) {
- mkObjArgs.emplace_back(sbe::makeE<sbe::EConstant>(elem.fieldNameStringData()));
- mkObjArgs.emplace_back(sbe::makeE<sbe::EVariable>((*outputs.getIndexKeySlots())[i++]));
+ StringData name = elem.fieldNameStringData();
+ args.emplace_back(sbe::makeE<sbe::EConstant>(name));
+ args.emplace_back(
+ makeVariable(outputs.get(std::make_pair(PlanStageSlots::kKey, name))));
}
- auto rawKeyExpr = sbe::makeE<sbe::EFunction>("newObj", std::move(mkObjArgs));
+ auto rawKeyExpr = sbe::makeE<sbe::EFunction>("newObj"_sd, std::move(args));
outputs.set(PlanStageSlots::kReturnKey, _slotIdGenerator.generate());
stage = sbe::makeProjectStage(std::move(stage),
ixn->nodeId(),
@@ -681,25 +597,30 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
std::move(rawKeyExpr));
}
- if (reqs.has(PlanStageSlots::kResult)) {
- outputs.set(PlanStageSlots::kResult, _slotIdGenerator.generate());
- stage = rehydrateIndexKey(std::move(stage),
- ixn->index.keyPattern,
- ixn->nodeId(),
- *outputs.getIndexKeySlots(),
- outputs.get(PlanStageSlots::kResult));
+ if (reqs.has(kResult) || reqs.hasFields()) {
+ auto indexKeySlots = sbe::makeSV();
+ for (auto&& elem : ixn->index.keyPattern) {
+ StringData name = elem.fieldNameStringData();
+ indexKeySlots.emplace_back(outputs.get(std::make_pair(PlanStageSlots::kKey, name)));
+ }
+
+ auto resultSlot = _slotIdGenerator.generate();
+ outputs.set(kResult, resultSlot);
+
+ stage = rehydrateIndexKey(
+ std::move(stage), ixn->index.keyPattern, ixn->nodeId(), indexKeySlots, resultSlot);
}
- if (reqs.getIndexKeyBitset()) {
- outputs.setIndexKeySlots(
- makeIndexKeyOutputSlotsMatchingParentReqs(ixn->index.keyPattern,
- *reqs.getIndexKeyBitset(),
- indexKeyBitset,
- *outputs.getIndexKeySlots()));
- } else {
- outputs.setIndexKeySlots(boost::none);
+ auto [outStage, nothingSlots] = projectNothingToSlots(
+ std::move(stage), additionalKeys.size(), root->nodeId(), &_slotIdGenerator);
+ stage = std::move(outStage);
+ for (size_t i = 0; i < additionalKeys.size(); ++i) {
+ outputs.set(std::make_pair(PlanStageSlots::kKey, std::move(additionalKeys[i])),
+ nothingSlots[i]);
}
+ outputs.clearNonRequiredSlots(reqs);
+
return {std::move(stage), std::move(outputs)};
}
@@ -838,9 +759,7 @@ std::unique_ptr<sbe::EExpression> generatePerColumnFilterExpr(StageBuilderState&
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildColumnScan(
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
- tassert(6312404,
- "Unexpected index key bitset provided for column scan stage",
- !reqs.getIndexKeyBitset());
+ tassert(6023403, "buildColumnScan() does not support kKey", !reqs.hasKeys());
auto csn = static_cast<const ColumnIndexScanNode*>(root);
tassert(6312405,
@@ -976,19 +895,22 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
auto fn = static_cast<const FetchNode*>(root);
- // The child must produce all of the slots required by the parent of this FetchNode, except for
- // 'resultSlot' which will be produced by the call to makeLoopJoinForFetch() below. In addition
- // to that, the child must always produce a 'recordIdSlot' because it's needed for the call to
- // makeLoopJoinForFetch() below.
+ // The child must produce a kRecordId slot, as well as all the kMeta and kKey slots required
+ // by the parent of this FetchNode except for 'resultSlot'. Note that the child does _not_
+ // need to produce any kField slots. Any kField requests by the parent will be handled by the
+ // logic below.
+ auto child = fn->children[0].get();
+
auto childReqs = reqs.copy()
.clear(kResult)
+ .clearAllFields()
.set(kRecordId)
.set(kSnapshotId)
.set(kIndexId)
.set(kIndexKey)
.set(kIndexKeyPattern);
- auto [stage, outputs] = build(fn->children[0].get(), childReqs);
+ auto [stage, outputs] = build(child, childReqs);
auto iamMap = _data.iamMap;
uassert(4822880, "RecordId slot is not defined", outputs.has(kRecordId));
@@ -999,44 +921,48 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
uassert(5290711, "Index key slot is not defined", outputs.has(kIndexKey));
uassert(5113713, "Index key pattern slot is not defined", outputs.has(kIndexKeyPattern));
- auto forwardingReqs = reqs.copy().clear(kResult).clear(kRecordId);
+ auto forwardingReqs = reqs.copy().clear(kResult).clear(kRecordId).clearAllFields();
auto relevantSlots = getSlotsToForward(forwardingReqs, outputs);
- // Forward slots for components of the index key if our parent requested them.
- if (auto indexKeySlots = outputs.getIndexKeySlots()) {
- relevantSlots.insert(relevantSlots.end(), indexKeySlots->begin(), indexKeySlots->end());
- }
-
- sbe::value::SlotId fetchResultSlot, fetchRecordIdSlot;
- std::tie(fetchResultSlot, fetchRecordIdSlot, stage) =
- makeLoopJoinForFetch(std::move(stage),
- outputs.get(kRecordId),
- outputs.get(kSnapshotId),
- outputs.get(kIndexId),
- outputs.get(kIndexKey),
- outputs.get(kIndexKeyPattern),
- getCurrentCollection(reqs),
- std::move(iamMap),
- root->nodeId(),
- std::move(relevantSlots),
- _slotIdGenerator);
+ auto fields = reqs.getFields();
+
+ auto childRecordId = outputs.get(kRecordId);
+ auto fetchResultSlot = _slotIdGenerator.generate();
+ auto fetchRecordIdSlot = _slotIdGenerator.generate();
+ auto fieldSlots = _slotIdGenerator.generateMultiple(fields.size());
+
+ stage = makeLoopJoinForFetch(std::move(stage),
+ fetchResultSlot,
+ fetchRecordIdSlot,
+ fields,
+ fieldSlots,
+ childRecordId,
+ outputs.get(kSnapshotId),
+ outputs.get(kIndexId),
+ outputs.get(kIndexKey),
+ outputs.get(kIndexKeyPattern),
+ getCurrentCollection(reqs),
+ std::move(iamMap),
+ root->nodeId(),
+ std::move(relevantSlots));
outputs.set(kResult, fetchResultSlot);
- // Propagate the RecordId output only if requested.
+
+ // Only propagate kRecordId if requested.
if (reqs.has(kRecordId)) {
outputs.set(kRecordId, fetchRecordIdSlot);
} else {
outputs.clear(kRecordId);
}
+ for (size_t i = 0; i < fields.size(); ++i) {
+ outputs.set(std::make_pair(PlanStageSlots::kField, fields[i]), fieldSlots[i]);
+ }
+
if (fn->filter) {
forwardingReqs = reqs.copy().set(kResult);
- auto relevantSlots = getSlotsToForward(forwardingReqs, outputs);
- // Forward slots for components of the index key if our parent requested them.
- if (auto indexKeySlots = outputs.getIndexKeySlots()) {
- relevantSlots.insert(relevantSlots.end(), indexKeySlots->begin(), indexKeySlots->end());
- }
+ auto relevantSlots = getSlotsToForward(forwardingReqs, outputs);
auto [_, outputStage] = generateFilter(_state,
fn->filter.get(),
@@ -1046,6 +972,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
stage = outputStage.extractStage(root->nodeId());
}
+ outputs.clearNonRequiredSlots(reqs);
+
return {std::move(stage), std::move(outputs)};
}
@@ -1197,22 +1125,12 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
sortPattern.size() > 0);
auto child = sn->children[0].get();
- if (reqs.getIndexKeyBitset().has_value() ||
- (!reqs.has(kResult) && child->getType() == STAGE_IXSCAN)) {
- // We decide to IndexKeySlots when building the child if:
- // 1) The query is covered; or
- // 2) This is a SORT->IXSCAN and kResult wasn't requested.
+
+ if (auto [ixn, ct] = getFirstNodeByType(root, STAGE_IXSCAN);
+ !sn->fetched() && !reqs.has(kResult) && ixn && ct == 1) {
return buildSortCovered(root, reqs);
}
- // The child must produce the kResult slot as well as all slots required by the parent of
- // this SortNode.
- auto childReqs = reqs.copy().set(kResult);
-
- auto [stage, outputs] = build(child, childReqs);
-
- auto collatorSlot = _data.env->getSlotIfExists("collator"_sd);
-
StringDataSet prefixSet;
bool hasPartsWithCommonPrefix = false;
for (const auto& part : sortPattern) {
@@ -1226,14 +1144,18 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
}
}
+ auto childReqs = reqs.copy().set(kResult);
+ auto [stage, childOutputs] = build(child, childReqs);
+ auto outputs = std::move(childOutputs);
+
+ auto collatorSlot = _data.env->getSlotIfExists("collator"_sd);
+
sbe::value::SlotVector orderBy;
std::vector<sbe::value::SortDirection> direction;
sbe::value::SlotId outputSlotId = outputs.get(kResult);
if (!hasPartsWithCommonPrefix) {
// Handle the case where we are using kResult and there are no common prefixes.
- sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projectMap;
-
orderBy.reserve(sortPattern.size());
// Sorting has a limitation where only one of the sort patterns can involve arrays.
@@ -1329,8 +1251,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
makeConstant(sbe::value::TypeTags::sortSpec,
sbe::value::bitcastFrom<sbe::value::SortSpec*>(sortSpec.release()));
- auto collatorSlot = _data.env->getSlotIfExists("collator"_sd);
-
// generateSortKey() will handle the parallel arrays check and sort key traversal for us,
// so we don't need to generate our own sort key traversal logic in the SBE plan.
stage = sbe::makeProjectStage(std::move(stage),
@@ -1347,7 +1267,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
// Slots for sort stage to forward to parent stage. Values in these slots are not used during
// sorting.
- auto forwardedSlots = getSlotsToForward(childReqs, outputs);
+ auto forwardedSlots = getSlotsToForward(childReqs, outputs, orderBy);
stage =
sbe::makeS<sbe::SortStage>(std::move(stage),
@@ -1359,52 +1279,40 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
_cq.getExpCtx()->allowDiskUse,
root->nodeId());
+ outputs.clearNonRequiredSlots(reqs);
+
return {std::move(stage), std::move(outputs)};
}
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildSortCovered(
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
+ tassert(6023404, "buildSortCovered() does not support kResult", !reqs.has(kResult));
+
const auto sn = static_cast<const SortNode*>(root);
auto sortPattern = SortPattern{sn->pattern, _cq.getExpCtx()};
tassert(7047600,
"QueryPlannerAnalysis should not produce a SortNode with an empty sort pattern",
sortPattern.size() > 0);
+ tassert(6023422, "buildSortCovered() expected 'sn' to not be fetched", !sn->fetched());
- // The child must produce all of the slots required by the parent of this SortNode.
- auto childReqs = reqs.copy();
auto child = sn->children[0].get();
-
- auto parentIndexKeyBitset = reqs.getIndexKeyBitset().get_value_or({});
- BSONObj indexKeyPattern;
- sbe::IndexKeysInclusionSet sortPatternKeyBitset;
- StringMap<size_t> sortSlotPosMap;
-
- // Set IndexKeyBitset to request each part of the index requested by the parent and to request
- // each part of the sort pattern.
auto indexScan = static_cast<const IndexScanNode*>(getLoneNodeByType(child, STAGE_IXSCAN));
tassert(7047601, "Expected index scan below sort", indexScan);
- indexKeyPattern = indexScan->index.keyPattern;
+ auto indexKeyPattern = indexScan->index.keyPattern;
+
+ // The child must produce all of the slots required by the parent of this SortNode.
+ auto childReqs = reqs.copy();
+ std::vector<std::string> keys;
StringDataSet sortPathsSet;
for (const auto& part : sortPattern) {
- sortPathsSet.insert(part.fieldPath->fullPath());
+ const auto& key = part.fieldPath->fullPath();
+ keys.emplace_back(key);
+ sortPathsSet.emplace(key);
}
- // 'sortPatternKeyBitset' will contain a bit pattern indicating which parts of the index
- // pattern are needed for sorting. 'sortPaths' will contain the field paths used in the
- // sort pattern, ordered according to the index pattern.
- std::vector<std::string> sortPaths;
- std::tie(sortPatternKeyBitset, sortPaths) =
- makeIndexKeyInclusionSet(indexKeyPattern, sortPathsSet);
- childReqs.getIndexKeyBitset() = sortPatternKeyBitset | parentIndexKeyBitset;
-
- // Build a map that maps each field path to its position in 'sortPaths'. This will help us
- // later when we need to convert the output of makeIndexKeyOutputSlotsMatchingParentReqs()
- // from the index pattern's order to sort pattern's order.
- for (size_t i = 0; i < sortPaths.size(); ++i) {
- sortSlotPosMap.emplace(std::move(sortPaths[i]), i);
- }
+ childReqs.setKeys(std::move(keys));
auto [stage, outputs] = build(child, childReqs);
@@ -1412,26 +1320,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
sbe::value::SlotVector orderBy;
std::vector<sbe::value::SortDirection> direction;
-
- // Handle the case where we are using IndexKeySlots.
- auto indexKeySlots = *outputs.extractIndexKeySlots();
-
- // Currently, 'indexKeySlots' contains slots for two kinds of index keys:
- // 1. Keys requested for sort pattern
- // 2. Keys requested by parent (if any)
- // The first category of slots needs to go into the 'orderBy' vector. The second category
- // of slots goes into 'outputs' to be used by the parent.
- auto& childIndexKeyBitset = *childReqs.getIndexKeyBitset();
-
- // The query planner does not support covered queries or SORT->IXSCAN plans on array fields
- // for multikey indexes. Therefore we don't have to generate the parallel arrays check or
- // the sort key traversal logic.
- auto sortIndexKeySlots = makeIndexKeyOutputSlotsMatchingParentReqs(
- indexKeyPattern, sortPatternKeyBitset, childIndexKeyBitset, indexKeySlots);
-
- // 'sortIndexKeySlots' is ordered according to the index pattern, but 'orderBy' needs to
- // be ordered according to the sort pattern. We use 'sortIndexKeySlots' to convert from
- // the index pattern's order to the sort pattern's order.
orderBy.reserve(sortPattern.size());
direction.reserve(sortPattern.size());
for (const auto& part : sortPattern) {
@@ -1439,19 +1327,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
// contains $meta, so this assertion should always be true.
tassert(7047602, "Sort with $meta is not supported in SBE", part.fieldPath);
- auto it = sortSlotPosMap.find(part.fieldPath->fullPath());
- tassert(6843200,
- str::stream() << "Did not find sort path '" << part.fieldPath->fullPath()
- << "' in sort path map",
- it != sortSlotPosMap.end());
-
- auto slotPos = it->second;
- tassert(6843201,
- str::stream() << "Sort path map for '" << part.fieldPath->fullPath()
- << "' returned an index '" << slotPos << "' that is out of bounds",
- slotPos < sortIndexKeySlots.size());
-
- orderBy.push_back(sortIndexKeySlots[slotPos]);
+ orderBy.push_back(
+ outputs.get(std::make_pair(PlanStageSlots::kKey, part.fieldPath->fullPath())));
direction.push_back(part.isAscending ? sbe::value::SortDirection::Ascending
: sbe::value::SortDirection::Descending);
}
@@ -1478,24 +1355,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
// Slots for sort stage to forward to parent stage. Values in these slots are not used during
// sorting.
- auto forwardedSlots = getSlotsToForward(childReqs, outputs);
-
- if (parentIndexKeyBitset.any()) {
- auto parentIndexKeySlots = makeIndexKeyOutputSlotsMatchingParentReqs(
- indexKeyPattern, parentIndexKeyBitset, childIndexKeyBitset, indexKeySlots);
-
- // The 'forwardedSlots' vector should include all slots requested by parent excluding
- // any slots that appear in the 'orderBy' vector.
- sbe::value::SlotSet orderBySlotsSet(orderBy.begin(), orderBy.end());
- for (auto slot : parentIndexKeySlots) {
- if (!orderBySlotsSet.count(slot)) {
- forwardedSlots.push_back(slot);
- }
- }
-
- // Make sure to store all of the slots requested by the parent into 'outputs'.
- outputs.setIndexKeySlots(std::move(parentIndexKeySlots));
- }
+ auto forwardedSlots = getSlotsToForward(childReqs, outputs, orderBy);
stage =
sbe::makeS<sbe::SortStage>(std::move(stage),
@@ -1507,11 +1367,13 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
_cq.getExpCtx()->allowDiskUse,
root->nodeId());
+ outputs.clearNonRequiredSlots(reqs);
+
return {std::move(stage), std::move(outputs)};
}
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots>
-SlotBasedStageBuilder::buildSortKeyGeneraror(const QuerySolutionNode* root,
+SlotBasedStageBuilder::buildSortKeyGenerator(const QuerySolutionNode* root,
const PlanStageReqs& reqs) {
    uasserted(4822883, "Sort key generator is not supported in SBE yet");
}
@@ -1534,64 +1396,21 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
std::vector<sbe::value::SlotVector> inputKeys;
std::vector<sbe::value::SlotVector> inputVals;
- // Children must produce all of the slots required by the parent of this SortMergeNode. In
- // addition, children must always produce a 'recordIdSlot' if the 'dedup' flag is true.
- auto childReqs = reqs.copy().setIf(kRecordId, mergeSortNode->dedup);
+ std::vector<std::string> keys;
+ StringSet sortPatternSet;
+ for (auto&& sortPart : sortPattern) {
+ sortPatternSet.emplace(sortPart.fieldPath->fullPath());
+ keys.emplace_back(sortPart.fieldPath->fullPath());
+ }
- // If a parent node has requested an index key bitset, then we will produce index keys
- // corresponding to the sort pattern parts needed by said parent node.
- bool parentRequestsIdxKeys = reqs.getIndexKeyBitset().has_value();
+ // Children must produce all of the slots required by the parent of this SortMergeNode. In
+ // addition, children must always produce a 'recordIdSlot' if the 'dedup' flag is true, and
+ // they must produce kKey slots for each part of the sort pattern.
+ auto childReqs = reqs.copy().setIf(kRecordId, mergeSortNode->dedup).setKeys(std::move(keys));
for (auto&& child : mergeSortNode->children) {
sbe::value::SlotVector inputKeysForChild;
- // Retrieve the sort pattern provided by the subtree rooted at 'child'. In particular, if
- // our child is a MERGE_SORT, it will provide the sort directly. If instead it's a tree
- // containing a lone IXSCAN, we have to check the key pattern because 'providedSorts()' may
- // or may not provide the baseSortPattern depending on the index bounds (in particular,
- // if the bounds are fixed, the fields will be marked as 'ignored'). Otherwise, we attempt
- // to retrieve it from 'providedSorts'.
- auto childSortPattern = [&]() {
- if (auto [msn, _] = getFirstNodeByType(child.get(), STAGE_SORT_MERGE); msn) {
- auto node = static_cast<const MergeSortNode*>(msn);
- return node->sort;
- } else {
- auto [ixn, ct] = getFirstNodeByType(child.get(), STAGE_IXSCAN);
- if (ixn && ct == 1) {
- auto node = static_cast<const IndexScanNode*>(ixn);
- return node->index.keyPattern;
- }
- }
- auto baseSort = child->providedSorts().getBaseSortPattern();
- tassert(6149600,
- str::stream() << "Did not find sort pattern for child " << child->toString(),
- !baseSort.isEmpty());
- return baseSort;
- }();
-
- // Map of field name to position within the index key. This is used to account for
- // mismatches between the sort pattern and the index key pattern. For instance, suppose
- // the requested sort is {a: 1, b: 1} and the index key pattern is {c: 1, b: 1, a: 1}.
- // When the slots for the relevant components of the index key are generated (i.e.
- // extract keys for 'b' and 'a'), we wish to insert them into 'inputKeys' in the order
- // that they appear in the sort pattern.
- StringMap<size_t> indexKeyPositionMap;
-
- sbe::IndexKeysInclusionSet indexKeyBitset;
- size_t i = 0;
- for (auto&& elt : childSortPattern) {
- for (auto&& sortPart : sortPattern) {
- auto path = sortPart.fieldPath->fullPath();
- if (elt.fieldNameStringData() == path) {
- indexKeyBitset.set(i);
- indexKeyPositionMap.emplace(path, indexKeyPositionMap.size());
- break;
- }
- }
- ++i;
- }
- childReqs.getIndexKeyBitset() = indexKeyBitset;
-
// Children must produce a 'resultSlot' if they produce fetched results.
auto [stage, outputs] = build(child.get(), childReqs);
@@ -1600,30 +1419,9 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
" if the 'dedup' flag is set",
!mergeSortNode->dedup || outputs.has(kRecordId));
- // Clear the index key bitset after building the child stage.
- childReqs.getIndexKeyBitset() = boost::none;
-
- // Insert the index key slots in the order of the sort pattern.
- auto indexKeys = outputs.extractIndexKeySlots();
- tassert(5184302,
- "SORT_MERGE must receive index key slots as input from its child stages",
- indexKeys);
-
for (const auto& part : sortPattern) {
- auto partPath = part.fieldPath->fullPath();
- auto index = indexKeyPositionMap.find(partPath);
- tassert(5184303,
- str::stream() << "Could not find index key position for sort key part "
- << partPath,
- index != indexKeyPositionMap.end());
- auto indexPos = index->second;
- tassert(5184304,
- str::stream() << "Index position " << indexPos
- << " is not less than number of index components "
- << indexKeys->size(),
- indexPos < indexKeys->size());
- auto indexKeyPart = indexKeys->at(indexPos);
- inputKeysForChild.push_back(indexKeyPart);
+ inputKeysForChild.push_back(
+ outputs.get(std::make_pair(PlanStageSlots::kKey, part.fieldPath->fullPath())));
}
inputKeys.push_back(std::move(inputKeysForChild));
@@ -1631,32 +1429,12 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
auto sv = getSlotsToForward(childReqs, outputs);
- // If the parent of 'root' has requested index keys, then we need to pass along our input
- // keys as input values, as they will be part of the output 'root' provides its parent.
- if (parentRequestsIdxKeys) {
- for (auto&& inputKey : inputKeys.back()) {
- sv.push_back(inputKey);
- }
- }
inputVals.push_back(std::move(sv));
}
PlanStageSlots outputs(childReqs, &_slotIdGenerator);
- auto outputVals = getSlotsToForward(childReqs, outputs);
- // If the parent of 'root' has requested index keys, then we need to generate output slots to
- // hold the index keys that will be used as input to the parent of 'root'.
- if (parentRequestsIdxKeys) {
- auto idxKeySv = sbe::makeSV();
- for (int idx = 0; idx < mergeSortNode->sort.nFields(); ++idx) {
- idxKeySv.emplace_back(_slotIdGenerator.generate());
- }
- outputs.setIndexKeySlots(idxKeySv);
-
- for (auto keySlot : idxKeySv) {
- outputVals.push_back(keySlot);
- }
- }
+ auto outputVals = getSlotsToForward(childReqs, outputs);
auto stage = sbe::makeS<sbe::SortedMergeStage>(std::move(inputStages),
std::move(inputKeys),
@@ -1674,6 +1452,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
}
}
+ outputs.clearNonRequiredSlots(reqs);
+
return {std::move(stage), std::move(outputs)};
}
@@ -1681,48 +1461,63 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots>
SlotBasedStageBuilder::buildProjectionSimple(const QuerySolutionNode* root,
const PlanStageReqs& reqs) {
using namespace std::literals;
- invariant(!reqs.getIndexKeyBitset());
+ tassert(6023405, "buildProjectionSimple() does not support kKey", !reqs.hasKeys());
auto pn = static_cast<const ProjectionNodeSimple*>(root);
- // The child must produce all of the slots required by the parent of this ProjectionNodeSimple.
- // In addition to that, the child must always produce a 'resultSlot' because it's needed by the
- // projection logic below.
- auto childReqs = reqs.copy().set(kResult);
- auto [inputStage, outputs] = build(pn->children[0].get(), childReqs);
+ auto [fields, additionalFields] = splitVector(reqs.getFields(), [&](const std::string& s) {
+ return pn->proj.type() == projection_ast::ProjectType::kInclusion
+ ? pn->proj.getRequiredFields().count(s)
+ : !pn->proj.getExcludedPaths().count(s);
+ });
- const auto childResult = outputs.get(kResult);
+ auto childReqs = reqs.copy().clearAllFields().setFields(std::move(fields));
+ auto [stage, childOutputs] = build(pn->children[0].get(), childReqs);
+ auto outputs = std::move(childOutputs);
- sbe::MakeBsonObjStage::FieldBehavior behaviour;
- const OrderedPathSet* fields;
- if (pn->proj.type() == projection_ast::ProjectType::kInclusion) {
- behaviour = sbe::MakeBsonObjStage::FieldBehavior::keep;
- fields = &pn->proj.getRequiredFields();
- } else {
- behaviour = sbe::MakeBsonObjStage::FieldBehavior::drop;
- fields = &pn->proj.getExcludedPaths();
- }
-
- outputs.set(kResult, _slotIdGenerator.generate());
- inputStage = sbe::makeS<sbe::MakeBsonObjStage>(std::move(inputStage),
- outputs.get(kResult),
- childResult,
- behaviour,
- *fields,
- OrderedPathSet{},
- sbe::value::SlotVector{},
- true,
- false,
- root->nodeId());
+ if (reqs.has(kResult)) {
+ const auto childResult = outputs.get(kResult);
+
+ sbe::MakeBsonObjStage::FieldBehavior behaviour;
+ const OrderedPathSet* fields;
+ if (pn->proj.type() == projection_ast::ProjectType::kInclusion) {
+ behaviour = sbe::MakeBsonObjStage::FieldBehavior::keep;
+ fields = &pn->proj.getRequiredFields();
+ } else {
+ behaviour = sbe::MakeBsonObjStage::FieldBehavior::drop;
+ fields = &pn->proj.getExcludedPaths();
+ }
+
+ outputs.set(kResult, _slotIdGenerator.generate());
+ stage = sbe::makeS<sbe::MakeBsonObjStage>(std::move(stage),
+ outputs.get(kResult),
+ childResult,
+ behaviour,
+ *fields,
+ OrderedPathSet{},
+ sbe::value::SlotVector{},
+ true,
+ false,
+ root->nodeId());
+ }
+
+ auto [outStage, nothingSlots] = projectNothingToSlots(
+ std::move(stage), additionalFields.size(), root->nodeId(), &_slotIdGenerator);
+ for (size_t i = 0; i < additionalFields.size(); ++i) {
+ outputs.set(std::make_pair(PlanStageSlots::kField, std::move(additionalFields[i])),
+ nothingSlots[i]);
+ }
+
+ outputs.clearNonRequiredSlots(reqs);
- return {std::move(inputStage), std::move(outputs)};
+ return {std::move(outStage), std::move(outputs)};
}
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots>
SlotBasedStageBuilder::buildProjectionCovered(const QuerySolutionNode* root,
const PlanStageReqs& reqs) {
using namespace std::literals;
- invariant(!reqs.getIndexKeyBitset());
+ tassert(6023406, "buildProjectionCovered() does not support kKey", !reqs.hasKeys());
auto pn = static_cast<const ProjectionNodeCovered*>(root);
invariant(pn->proj.isSimple());
@@ -1733,51 +1528,86 @@ SlotBasedStageBuilder::buildProjectionCovered(const QuerySolutionNode* root,
!pn->children[0]->fetched());
// This is a ProjectionCoveredNode, so we will be pulling all the data we need from one index.
- // Prepare a bitset to indicate which parts of the index key we need for the projection.
- StringSet requiredFields = {pn->proj.getRequiredFields().begin(),
- pn->proj.getRequiredFields().end()};
-
- // The child must produce all of the slots required by the parent of this ProjectionNodeSimple,
- // except for 'resultSlot' which will be produced by the MakeBsonObjStage below. In addition to
- // that, the child must produce the index key slots that are needed by this covered projection.
- //
// pn->coveredKeyObj is the "index.keyPattern" from the child (which is either an IndexScanNode
// or DistinctNode). pn->coveredKeyObj lists all the fields that the index can provide, not the
- // fields that the projection wants. requiredFields lists all of the fields that the projection
- // needs. Since this is a covered projection, we're guaranteed that pn->coveredKeyObj contains
- // all of the fields that the projection needs.
- auto childReqs = reqs.copy().clear(kResult);
-
- auto [indexKeyBitset, keyFieldNames] =
- makeIndexKeyInclusionSet(pn->coveredKeyObj, requiredFields);
- childReqs.getIndexKeyBitset() = std::move(indexKeyBitset);
-
- auto [inputStage, outputs] = build(pn->children[0].get(), childReqs);
-
- // Assert that the index scan produced index key slots for this covered projection.
- auto indexKeySlots = *outputs.extractIndexKeySlots();
-
- outputs.set(kResult, _slotIdGenerator.generate());
- inputStage = sbe::makeS<sbe::MakeBsonObjStage>(std::move(inputStage),
- outputs.get(kResult),
- boost::none,
- boost::none,
- std::vector<std::string>{},
- std::move(keyFieldNames),
- std::move(indexKeySlots),
- true,
- false,
- root->nodeId());
+ // fields that the projection wants. 'pn->proj.getRequiredFields()' lists all of the fields
+ // that the projection needs. Since this is a simple covered projection, we're guaranteed that
+ // 'pn->proj.getRequiredFields()' is a subset of pn->coveredKeyObj.
+
+ // List out the projected fields in the order they appear in 'coveredKeyObj'.
+ std::vector<std::string> keys;
+ StringDataSet keysSet;
+ for (auto&& elt : pn->coveredKeyObj) {
+ std::string key(elt.fieldNameStringData());
+ if (pn->proj.getRequiredFields().count(key)) {
+ keys.emplace_back(std::move(key));
+ keysSet.emplace(elt.fieldNameStringData());
+ }
+ }
+
+ // The child must produce all of the slots required by the parent of this ProjectionNodeSimple,
+ // except for 'resultSlot' which will be produced by the MakeBsonObjStage below if requested by
+ // the caller. In addition to that, the child must produce the index key slots that are needed
+ // by this covered projection.
+ auto childReqs = reqs.copy().clear(kResult).clearAllFields().setKeys(keys);
+ auto [stage, childOutputs] = build(pn->children[0].get(), childReqs);
+ auto outputs = std::move(childOutputs);
+
+ if (reqs.has(kResult)) {
+ auto indexKeySlots = sbe::makeSV();
+ std::vector<std::string> keyFieldNames;
- return {std::move(inputStage), std::move(outputs)};
+ if (keysSet.count("_id"_sd)) {
+ keyFieldNames.emplace_back("_id"_sd);
+ indexKeySlots.emplace_back(outputs.get(std::make_pair(PlanStageSlots::kKey, "_id"_sd)));
+ }
+
+ for (const auto& key : keys) {
+ if (key != "_id"_sd) {
+ keyFieldNames.emplace_back(key);
+ indexKeySlots.emplace_back(
+ outputs.get(std::make_pair(PlanStageSlots::kKey, StringData(key))));
+ }
+ }
+
+ auto resultSlot = _slotIdGenerator.generate();
+ stage = sbe::makeS<sbe::MakeBsonObjStage>(std::move(stage),
+ resultSlot,
+ boost::none,
+ boost::none,
+ std::vector<std::string>{},
+ std::move(keyFieldNames),
+ std::move(indexKeySlots),
+ true,
+ false,
+ root->nodeId());
+
+ outputs.set(kResult, resultSlot);
+ }
+
+ auto [fields, additionalFields] =
+ splitVector(reqs.getFields(), [&](const std::string& s) { return keysSet.count(s); });
+ for (size_t i = 0; i < fields.size(); ++i) {
+ auto slot = outputs.get(std::make_pair(PlanStageSlots::kKey, StringData(fields[i])));
+ outputs.set(std::make_pair(PlanStageSlots::kField, std::move(fields[i])), slot);
+ }
+
+ auto [outStage, nothingSlots] = projectNothingToSlots(
+ std::move(stage), additionalFields.size(), root->nodeId(), &_slotIdGenerator);
+ for (size_t i = 0; i < additionalFields.size(); ++i) {
+ outputs.set(std::make_pair(PlanStageSlots::kField, std::move(additionalFields[i])),
+ nothingSlots[i]);
+ }
+
+ outputs.clearNonRequiredSlots(reqs);
+
+ return {std::move(outStage), std::move(outputs)};
}
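
A minimal standalone sketch of the key-ordering step in buildProjectionCovered() above, using plain standard-library containers in place of BSONObj and StringDataSet; the field names and the coveredKeyFields/requiredFields variables are illustrative only:

    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    int main() {
        // Stand-ins for pn->coveredKeyObj (index key pattern order) and
        // pn->proj.getRequiredFields().
        std::vector<std::string> coveredKeyFields{"a", "b", "c"};
        std::set<std::string> requiredFields{"c", "a"};

        std::vector<std::string> keys;
        for (const auto& field : coveredKeyFields) {
            if (requiredFields.count(field)) {
                keys.push_back(field);
            }
        }

        // Prints "a c": the covered projection requests its kKey slots in
        // key-pattern order, regardless of the order the projection listed them.
        for (const auto& k : keys) {
            std::cout << k << ' ';
        }
        std::cout << '\n';
        return 0;
    }
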
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots>
SlotBasedStageBuilder::buildProjectionDefault(const QuerySolutionNode* root,
const PlanStageReqs& reqs) {
- tassert(7055400,
- "buildProjectionDefault() does not support index key bitsets",
- !reqs.getIndexKeyBitset());
+ tassert(6023407, "buildProjectionDefault() does not support kKey", !reqs.hasKeys());
auto pn = static_cast<const ProjectionNodeDefault*>(root);
const auto& projection = pn->proj;
@@ -1791,7 +1621,7 @@ SlotBasedStageBuilder::buildProjectionDefault(const QuerySolutionNode* root,
// The child must produce all of the slots required by the parent of this ProjectionNodeDefault.
// In addition to that, the child must always produce 'kResult' because it's needed by the
// projection logic below.
- auto childReqs = reqs.copy().set(kResult);
+ auto childReqs = reqs.copy().set(kResult).clearAllFields();
auto [stage, outputs] = build(pn->children[0].get(), childReqs);
@@ -1805,8 +1635,10 @@ SlotBasedStageBuilder::buildProjectionDefault(const QuerySolutionNode* root,
root->nodeId());
stage = resultStage.extractStage(root->nodeId());
-
outputs.set(kResult, resultSlot);
+
+ outputs.clearNonRequiredSlots(reqs);
+
return {std::move(stage), std::move(outputs)};
}
@@ -1814,9 +1646,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots>
SlotBasedStageBuilder::buildProjectionDefaultCovered(const QuerySolutionNode* root,
const PlanStageReqs& reqs,
const IndexScanNode* ixn) {
- tassert(7055401,
- "buildProjectionDefaultCovered() does not support index key bitsets",
- !reqs.getIndexKeyBitset());
+ tassert(6023408, "buildProjectionDefaultCovered() does not support kKey", !reqs.hasKeys());
auto pn = static_cast<const ProjectionNodeDefault*>(root);
const auto& projection = pn->proj;
@@ -1827,64 +1657,60 @@ SlotBasedStageBuilder::buildProjectionDefaultCovered(const QuerySolutionNode* ro
tassert(
7055403, "buildProjectionDefaultCovered() expected 'pn' to not be fetched", !pn->fetched());
- // Convert projection fieldpaths into the tree of 'IndexKeyPatternTreeNode'.
- IndexKeysBuilderContext context;
- IndexKeysPreBuilder preVisitor{&context};
- IndexKeysInBuilder inVisitor{&context};
- IndexKeysPostBuilder postVisitor{&context};
- projection_ast::ProjectionASTConstWalker walker{&preVisitor, &inVisitor, &postVisitor};
- tree_walker::walk<true, projection_ast::ASTNode>(projection.root(), &walker);
-
- IndexKeyPatternTreeNode patternRoot = std::move(context.root);
-
- // Construct a bitset requesting slots from the underlying index scan. These slots
- // correspond to index keys for projection fieldpaths.
if (!ixn) {
ixn = static_cast<const IndexScanNode*>(getLoneNodeByType(root, STAGE_IXSCAN));
}
+
auto& indexKeyPattern = ixn->index.keyPattern;
- sbe::IndexKeysInclusionSet patternBitSet;
+ auto patternRoot = buildPatternTree(pn->proj);
+ std::vector<std::string> keys;
+ StringDataSet keysSet;
std::vector<IndexKeyPatternTreeNode*> patternNodesForSlots;
- size_t i = 0;
for (const auto& element : indexKeyPattern) {
sbe::MatchPath fieldRef{element.fieldNameStringData()};
// Projection field paths are always leaf nodes. In other words, projection like
// {a: 1, 'a.b': 1} would produce a path collision error.
if (auto node = patternRoot.findLeafNode(fieldRef); node) {
- patternBitSet.set(i);
+ keys.emplace_back(element.fieldNameStringData());
+ keysSet.emplace(element.fieldNameStringData());
patternNodesForSlots.push_back(node);
}
-
- ++i;
}
- // We do not need index scan to restore the entire object. Instead, we will restore only
- // necessary parts of it below.
- auto childReqs = reqs.copy().clear(kResult);
- childReqs.getIndexKeyBitset() = patternBitSet;
+ auto childReqs = reqs.copy().clear(kResult).clearAllFields().setKeys(keys);
auto [stage, outputs] = build(pn->children[0].get(), childReqs);
- auto indexKeySlots = *outputs.extractIndexKeySlots();
+ auto [fields, additionalFields] =
+ splitVector(reqs.getFields(), [&](const std::string& s) { return keysSet.count(s); });
+ for (size_t i = 0; i < fields.size(); ++i) {
+ auto slot = outputs.get(std::make_pair(PlanStageSlots::kKey, StringData(fields[i])));
+ outputs.set(std::make_pair(PlanStageSlots::kField, std::move(fields[i])), slot);
+ }
- // Extract slots corresponding to each of the projection fieldpaths.
- invariant(indexKeySlots.size() == patternNodesForSlots.size());
- for (size_t i = 0; i < indexKeySlots.size(); i++) {
- patternNodesForSlots[i]->indexKeySlot = indexKeySlots[i];
+ if (reqs.has(kResult) || !additionalFields.empty()) {
+ // Extract slots corresponding to each of the projection fieldpaths.
+ for (size_t i = 0; i < keys.size(); i++) {
+ patternNodesForSlots[i]->indexKeySlot =
+ outputs.get(std::make_pair(PlanStageSlots::kKey, StringData(keys[i])));
+ }
+
+ // Finally, build the expression to create object with requested projection fieldpaths.
+ auto resultSlot = _slotIdGenerator.generate();
+ outputs.set(kResult, resultSlot);
+
+ stage = sbe::makeProjectStage(
+ std::move(stage), root->nodeId(), resultSlot, buildNewObjExpr(&patternRoot));
}
- // Finally, build the expression to create object with requested projection fieldpaths.
- auto resultSlot = _slotIdGenerator.generate();
- stage = sbe::makeProjectStage(
- std::move(stage), root->nodeId(), resultSlot, buildNewObjExpr(&patternRoot));
+ outputs.clearNonRequiredSlots(reqs);
- outputs.set(kResult, resultSlot);
return {std::move(stage), std::move(outputs)};
}
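
splitVector() is not shown in this diff; judging from its call sites in buildProjectionCovered(), buildProjectionDefaultCovered() and buildGroup(), it appears to partition a vector by a predicate while preserving order. A standalone sketch under that assumption (splitVectorSketch and the sample data are hypothetical names):

    #include <cassert>
    #include <string>
    #include <utility>
    #include <vector>

    // Assumed contract: elements satisfying the predicate go into the first
    // vector, the rest into the second, with relative order preserved.
    template <typename T, typename Pred>
    std::pair<std::vector<T>, std::vector<T>> splitVectorSketch(std::vector<T> v, Pred pred) {
        std::pair<std::vector<T>, std::vector<T>> result;
        for (auto& elem : v) {
            (pred(elem) ? result.first : result.second).push_back(std::move(elem));
        }
        return result;
    }

    int main() {
        std::vector<std::string> reqFields{"a", "x", "b"};
        // "a" and "b" are covered by the index key pattern; "x" is not.
        auto [fields, additionalFields] = splitVectorSketch(
            reqFields, [](const std::string& s) { return s == "a" || s == "b"; });
        assert((fields == std::vector<std::string>{"a", "b"}));
        assert((additionalFields == std::vector<std::string>{"x"}));
        return 0;
    }
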
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildOr(
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
- invariant(!reqs.getIndexKeyBitset());
+ tassert(6023409, "buildOr() does not support kKey", !reqs.hasKeys());
sbe::PlanStage::Vector inputStages;
std::vector<sbe::value::SlotVector> inputSlots;
@@ -1941,7 +1767,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
auto textNode = static_cast<const TextMatchNode*>(root);
const auto& coll = getCurrentCollection(reqs);
tassert(5432212, "no collection object", coll);
- tassert(5432213, "index keys requested for text match node", !reqs.getIndexKeyBitset());
+ tassert(6023410, "buildTextMatch() does not support kKey", !reqs.hasKeys());
tassert(5432215,
str::stream() << "text match node must have one child, but got "
<< root->children.size(),
@@ -1992,7 +1818,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildReturnKey(
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
- invariant(!reqs.getIndexKeyBitset());
+ tassert(6023411, "buildReturnKey() does not support kKey", !reqs.hasKeys());
// TODO SERVER-49509: If the projection includes {$meta: "sortKey"}, the result of this stage
// should also include the sort key. Everything else in the projection is ignored.
@@ -2002,7 +1828,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
// for 'resultSlot'. In addition to that, the child must always produce a 'returnKeySlot'.
// After build() returns, we take the 'returnKeySlot' produced by the child and store it into
// 'resultSlot' for the parent of this ReturnKeyNode to consume.
- auto childReqs = reqs.copy().clear(kResult).set(kReturnKey);
+ auto childReqs = reqs.copy().clear(kResult).clearAllFields().set(kReturnKey);
auto [stage, outputs] = build(returnKeyNode->children[0].get(), childReqs);
outputs.set(kResult, outputs.get(kReturnKey));
@@ -2020,9 +1846,10 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
auto andHashNode = static_cast<const AndHashNode*>(root);
+ tassert(6023412, "buildAndHash() does not support kKey", !reqs.hasKeys());
tassert(5073711, "need at least two children for AND_HASH", andHashNode->children.size() >= 2);
- auto childReqs = reqs.copy().set(kResult).set(kRecordId);
+ auto childReqs = reqs.copy().set(kResult).set(kRecordId).clearAllFields();
auto outerChild = andHashNode->children[0].get();
auto innerChild = andHashNode->children[1].get();
@@ -2049,13 +1876,13 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
auto collatorSlot = _data.env->getSlotIfExists("collator"_sd);
// Designate outputs.
- PlanStageSlots outputs(reqs, &_slotIdGenerator);
+ PlanStageSlots outputs;
+
+ outputs.set(kResult, innerResultSlot);
+
if (reqs.has(kRecordId)) {
outputs.set(kRecordId, innerIdSlot);
}
- if (reqs.has(kResult)) {
- outputs.set(kResult, innerResultSlot);
- }
if (reqs.has(kSnapshotId) && innerSnapshotIdSlot) {
auto slot = *innerSnapshotIdSlot;
innerProjectSlots.push_back(slot);
@@ -2071,26 +1898,25 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
innerProjectSlots.push_back(slot);
outputs.set(kIndexKey, slot);
}
-
if (reqs.has(kIndexKeyPattern) && innerIndexKeyPatternSlot) {
auto slot = *innerIndexKeyPatternSlot;
innerProjectSlots.push_back(slot);
outputs.set(kIndexKeyPattern, slot);
}
- auto hashJoinStage = sbe::makeS<sbe::HashJoinStage>(std::move(outerStage),
- std::move(innerStage),
- outerCondSlots,
- outerProjectSlots,
- innerCondSlots,
- innerProjectSlots,
- collatorSlot,
- root->nodeId());
+ auto stage = sbe::makeS<sbe::HashJoinStage>(std::move(outerStage),
+ std::move(innerStage),
+ outerCondSlots,
+ outerProjectSlots,
+ innerCondSlots,
+ innerProjectSlots,
+ collatorSlot,
+ root->nodeId());
// If there are more than 2 children, iterate all remaining children and hash
// join together.
for (size_t i = 2; i < andHashNode->children.size(); i++) {
- auto [stage, outputs] = build(andHashNode->children[i].get(), childReqs);
+ auto [childStage, outputs] = build(andHashNode->children[i].get(), childReqs);
tassert(5073714, "outputs must contain kRecordId slot", outputs.has(kRecordId));
tassert(5073715, "outputs must contain kResult slot", outputs.has(kResult));
auto idSlot = outputs.get(kRecordId);
@@ -2100,32 +1926,30 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
// The previous HashJoinStage is always set as the inner stage, so that we can reuse the
// innerIdSlot and innerResultSlot that have been designated as outputs.
- hashJoinStage = sbe::makeS<sbe::HashJoinStage>(std::move(stage),
- std::move(hashJoinStage),
- condSlots,
- projectSlots,
- innerCondSlots,
- innerProjectSlots,
- collatorSlot,
- root->nodeId());
- }
- // Stop propagating the RecordId output if none of our ancestors are going to use it.
- if (!reqs.has(kRecordId)) {
- outputs.clear(kRecordId);
+ stage = sbe::makeS<sbe::HashJoinStage>(std::move(childStage),
+ std::move(stage),
+ condSlots,
+ projectSlots,
+ innerCondSlots,
+ innerProjectSlots,
+ collatorSlot,
+ root->nodeId());
}
- return {std::move(hashJoinStage), std::move(outputs)};
+ return {std::move(stage), std::move(outputs)};
}
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildAndSorted(
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
+ tassert(6023413, "buildAndSorted() does not support kKey", !reqs.hasKeys());
+
auto andSortedNode = static_cast<const AndSortedNode*>(root);
// Need at least two children.
tassert(
5073706, "need at least two children for AND_SORTED", andSortedNode->children.size() >= 2);
- auto childReqs = reqs.copy().set(kResult).set(kRecordId);
+ auto childReqs = reqs.copy().set(kResult).set(kRecordId).clearAllFields();
auto outerChild = andSortedNode->children[0].get();
auto innerChild = andSortedNode->children[1].get();
@@ -2161,13 +1985,14 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
auto innerKeySlots = sbe::makeSV(innerIdSlot);
auto innerProjectSlots = sbe::makeSV(innerResultSlot);
- PlanStageSlots outputs(reqs, &_slotIdGenerator);
+ // Designate outputs.
+ PlanStageSlots outputs;
+
+ outputs.set(kResult, innerResultSlot);
+
if (reqs.has(kRecordId)) {
outputs.set(kRecordId, innerIdSlot);
}
- if (reqs.has(kResult)) {
- outputs.set(kResult, innerResultSlot);
- }
if (reqs.has(kSnapshotId)) {
auto innerSnapshotSlot = innerOutputs.get(kSnapshotId);
innerProjectSlots.push_back(innerSnapshotSlot);
@@ -2183,7 +2008,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
innerProjectSlots.push_back(innerIndexKeySlot);
outputs.set(kIndexKey, innerIndexKeySlot);
}
-
if (reqs.has(kIndexKeyPattern)) {
auto innerIndexKeyPatternSlot = innerOutputs.get(kIndexKeyPattern);
innerProjectSlots.push_back(innerIndexKeyPatternSlot);
@@ -2193,19 +2017,19 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
std::vector<sbe::value::SortDirection> sortDirs(outerKeySlots.size(),
sbe::value::SortDirection::Ascending);
- auto mergeJoinStage = sbe::makeS<sbe::MergeJoinStage>(std::move(outerStage),
- std::move(innerStage),
- outerKeySlots,
- outerProjectSlots,
- innerKeySlots,
- innerProjectSlots,
- sortDirs,
- root->nodeId());
+ auto stage = sbe::makeS<sbe::MergeJoinStage>(std::move(outerStage),
+ std::move(innerStage),
+ outerKeySlots,
+ outerProjectSlots,
+ innerKeySlots,
+ innerProjectSlots,
+ sortDirs,
+ root->nodeId());
// If there are more than 2 children, iterate all remaining children and merge
// join together.
for (size_t i = 2; i < andSortedNode->children.size(); i++) {
- auto [stage, outputs] = build(andSortedNode->children[i].get(), childReqs);
+ auto [childStage, outputs] = build(andSortedNode->children[i].get(), childReqs);
tassert(5073709, "outputs must contain kRecordId slot", outputs.has(kRecordId));
tassert(5073710, "outputs must contain kResult slot", outputs.has(kResult));
auto idSlot = outputs.get(kRecordId);
@@ -2213,21 +2037,17 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
auto keySlots = sbe::makeSV(idSlot);
auto projectSlots = sbe::makeSV(resultSlot);
- mergeJoinStage = sbe::makeS<sbe::MergeJoinStage>(std::move(stage),
- std::move(mergeJoinStage),
- keySlots,
- projectSlots,
- innerKeySlots,
- innerProjectSlots,
- sortDirs,
- root->nodeId());
- }
- // Stop propagating the RecordId output if none of our ancestors are going to use it.
- if (!reqs.has(kRecordId)) {
- outputs.clear(kRecordId);
+ stage = sbe::makeS<sbe::MergeJoinStage>(std::move(childStage),
+ std::move(stage),
+ keySlots,
+ projectSlots,
+ innerKeySlots,
+ innerProjectSlots,
+ sortDirs,
+ root->nodeId());
}
- return {std::move(mergeJoinStage), std::move(outputs)};
+ return {std::move(stage), std::move(outputs)};
}
namespace {
@@ -2313,39 +2133,6 @@ void walkAndActOnFieldPaths(Expression* expr, const F& fn) {
}
/**
- * Checks whether all field paths in 'idExpr' and all accumulator expressions are top-level ones.
- */
-bool areAllFieldPathsOptimizable(const boost::intrusive_ptr<Expression>& idExpr,
- const std::vector<AccumulationStatement>& accStmts) {
- auto areFieldPathsOptimizable = true;
-
- auto checkFieldPath = [&](const ExpressionFieldPath* fieldExpr, int32_t nestedCondLevel) {
- // We optimize neither a field path for the top-level document itself (getPathLength() == 1)
- // nor a field path that refers to a variable. We can optimize only top-level fields
- // (getPathLength() == 2).
- //
- // The 'nestedCondLevel' being > 0 means that a field path is refered to below conditional
- // expressions at the parent $group node, when we cannot optimize field path access and
- // therefore, cannot avoid materialization.
- if (nestedCondLevel > 0 || fieldExpr->getFieldPath().getPathLength() != 2 ||
- fieldExpr->isVariableReference()) {
- areFieldPathsOptimizable = false;
- return;
- }
- };
-
- // Checks field paths from the group-by expression.
- walkAndActOnFieldPaths(idExpr.get(), checkFieldPath);
-
- // Checks field paths from the accumulator expressions.
- for (auto&& accStmt : accStmts) {
- walkAndActOnFieldPaths(accStmt.expr.argument.get(), checkFieldPath);
- }
-
- return areFieldPathsOptimizable;
-}
-
-/**
* If there are adjacent $group stages in a pipeline and two $group stages are pushed down together,
* the first $group becomes a child GROUP node and the second $group becomes a parent GROUP node in
* a query solution tree. In the case that all field paths are top-level fields for the parent GROUP
@@ -2362,10 +2149,7 @@ EvalStage optimizeFieldPaths(StageBuilderState& state,
const PlanStageSlots& childOutputs,
PlanNodeId nodeId) {
using namespace fmt::literals;
- auto optionalRootSlot = childOutputs.getIfExists(SlotBasedStageBuilder::kResult);
- // Absent root slot means that we optimized away mkbson stage and so, we need to search
- // top-level field names in child outputs.
- auto searchInChildOutputs = !optionalRootSlot.has_value();
+ auto rootSlot = childOutputs.getIfExists(PlanStageSlots::kResult);
auto retEvalStage = std::move(childEvalStage);
walkAndActOnFieldPaths(expr.get(), [&](const ExpressionFieldPath* fieldExpr, int32_t) {
@@ -2376,24 +2160,10 @@ EvalStage optimizeFieldPaths(StageBuilderState& state,
}
auto fieldPathStr = fieldExpr->getFieldPath().fullPath();
- if (searchInChildOutputs) {
- if (auto optionalFieldPathSlot = childOutputs.getIfExists(fieldPathStr);
- optionalFieldPathSlot) {
- state.preGeneratedExprs.emplace(fieldPathStr, *optionalFieldPathSlot);
- } else {
- // getField('fieldPathStr') on a slot containing a BSONObj would have produced
- // 'Nothing' if mkbson stage removal optimization didn't occur. So, generates a
- // 'Nothing' const expression to simulate such a result.
- state.preGeneratedExprs.emplace(
- fieldPathStr, sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Nothing, 0));
- }
-
- return;
- }
if (!state.preGeneratedExprs.contains(fieldPathStr)) {
auto [curEvalExpr, curEvalStage] = generateExpression(
- state, fieldExpr, std::move(retEvalStage), optionalRootSlot, nodeId);
+ state, fieldExpr, std::move(retEvalStage), rootSlot, nodeId, &childOutputs);
auto [slot, stage] = projectEvalExpr(
std::move(curEvalExpr), std::move(curEvalStage), nodeId, state.slotIdGenerator);
@@ -2418,7 +2188,7 @@ std::pair<EvalExpr, EvalStage> generateGroupByKeyImpl(
optimizeFieldPaths(state, idExpr, std::move(childEvalStage), childOutputs, nodeId);
auto [groupByEvalExpr, groupByEvalStage] = stage_builder::generateExpression(
- state, idExpr.get(), std::move(evalStage), optionalRootSlot, nodeId);
+ state, idExpr.get(), std::move(evalStage), optionalRootSlot, nodeId, &childOutputs);
return {std::move(groupByEvalExpr), std::move(groupByEvalStage)};
}
@@ -2430,7 +2200,7 @@ std::tuple<sbe::value::SlotVector, EvalStage, std::unique_ptr<sbe::EExpression>>
std::unique_ptr<sbe::PlanStage> childStage,
PlanNodeId nodeId,
sbe::value::SlotIdGenerator* slotIdGenerator) {
- auto optionalRootSlot = childOutputs.getIfExists(SlotBasedStageBuilder::kResult);
+ auto optionalRootSlot = childOutputs.getIfExists(PlanStageSlots::kResult);
EvalStage retEvalStage{std::move(childStage),
optionalRootSlot ? sbe::value::SlotVector{*optionalRootSlot}
: sbe::value::SlotVector{}};
@@ -2510,10 +2280,10 @@ std::tuple<sbe::value::SlotVector, EvalStage> generateAccumulator(
PlanNodeId nodeId,
sbe::value::SlotIdGenerator* slotIdGenerator,
sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>>& accSlotToExprMap) {
- // Input fields may need field traversal which ends up being a complex tree.
+ // Input fields may need field traversal.
auto evalStage = optimizeFieldPaths(
state, accStmt.expr.argument, std::move(childEvalStage), childOutputs, nodeId);
- auto optionalRootSlot = childOutputs.getIfExists(SlotBasedStageBuilder::kResult);
+ auto optionalRootSlot = childOutputs.getIfExists(PlanStageSlots::kResult);
auto [argExpr, accArgEvalStage] = stage_builder::buildArgument(
state, accStmt, std::move(evalStage), optionalRootSlot, nodeId);
@@ -2642,6 +2412,7 @@ sbe::value::SlotVector dedupGroupBySlots(const sbe::value::SlotVector& groupBySl
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildGroup(
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
using namespace fmt::literals;
+ tassert(6023414, "buildGroup() does not support kKey", !reqs.hasKeys());
auto groupNode = static_cast<const GroupNode*>(root);
auto nodeId = groupNode->nodeId();
@@ -2657,14 +2428,20 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
const auto& childNode = groupNode->children[0].get();
const auto& accStmts = groupNode->accumulators;
- auto childStageType = childNode->getType();
- auto childReqs = reqs.copy();
- if (childStageType == StageType::STAGE_GROUP && areAllFieldPathsOptimizable(idExpr, accStmts)) {
- // Does not ask the GROUP child for the result slot to avoid unnecessary materialization if
- // all fields are top-level fields. See the end of this function. For example, GROUP - GROUP
- // - COLLSCAN case.
+ auto childReqs = reqs.copy().set(kResult).clearAllFields();
+
+ // Don't ask the GROUP child for the result slot to avoid unnecessary materialization if it's
+ // possible to get everything we need from top-level field slots.
+ if (childNode->getType() == StageType::STAGE_GROUP && !groupNode->needWholeDocument &&
+ !groupNode->needsAnyMetadata) {
childReqs.clear(kResult);
+
+ for (auto&& pathStr : groupNode->requiredFields) {
+ auto path = sbe::MatchPath{pathStr};
+ const auto& topLevelField = path.getPart(0);
+ childReqs.set(std::make_pair(PlanStageSlots::kField, topLevelField));
+ }
}
// Builds the child and gets the child result slot.
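
For the loop that fills in kField requirements above, only the first component of each required path matters when asking the GROUP child for slots ('a.b' and 'a.c' both reduce to the top-level field 'a'). A standalone sketch of that reduction; the contents of requiredFields and the use of std::set are illustrative:

    #include <iostream>
    #include <set>
    #include <string>

    int main() {
        // Illustrative stand-in for groupNode->requiredFields.
        std::set<std::string> requiredFields{"a.b", "a.c", "d"};

        std::set<std::string> topLevelFields;
        for (const auto& path : requiredFields) {
            // Equivalent of sbe::MatchPath{path}.getPart(0): everything up to
            // the first dot.
            topLevelFields.insert(path.substr(0, path.find('.')));
        }

        // Prints "a d": only two kField requirements end up on the child reqs.
        for (const auto& f : topLevelFields) {
            std::cout << f << ' ';
        }
        std::cout << '\n';
        return 0;
    }
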
@@ -2733,6 +2510,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
aggSlotsVec,
nodeId,
&_slotIdGenerator);
+ auto stage = groupFinalEvalStage.extractStage(nodeId);
tassert(5851605,
"The number of final slots must be as 1 (the final group-by slot) + the number of acc "
@@ -2742,21 +2520,37 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
// Cleans up optimized expressions.
_state.preGeneratedExprs.clear();
+ auto fieldNamesSet = StringDataSet{fieldNames.begin(), fieldNames.end()};
+ auto [fields, additionalFields] =
+ splitVector(reqs.getFields(), [&](const std::string& s) { return fieldNamesSet.count(s); });
+ auto fieldsSet = StringDataSet{fields.begin(), fields.end()};
+
PlanStageSlots outputs;
- std::unique_ptr<sbe::PlanStage> outStage;
- // Builds a stage to create a result object out of a group-by slot and gathered accumulator
- // result slots if the parent node requests so. Otherwise, returns field names and associated
- // slots so that a parent stage above can directly refer to a slot by its name because there's
- // no returned object.
+ for (size_t i = 0; i < fieldNames.size(); ++i) {
+ if (fieldsSet.count(fieldNames[i])) {
+ outputs.set(std::make_pair(PlanStageSlots::kField, fieldNames[i]), finalSlots[i]);
+ }
+ }
+
+ auto [outStage, nothingSlots] = projectNothingToSlots(
+ std::move(stage), additionalFields.size(), root->nodeId(), &_slotIdGenerator);
+ for (size_t i = 0; i < additionalFields.size(); ++i) {
+ outputs.set(std::make_pair(PlanStageSlots::kField, std::move(additionalFields[i])),
+ nothingSlots[i]);
+ }
+
+ // Builds an 'outStage' to create a result object out of the group-by slot and the gathered
+ // accumulator result slots, if the parent node requests one.
if (reqs.has(kResult)) {
- outputs.set(kResult, _slotIdGenerator.generate());
+ auto resultSlot = _slotIdGenerator.generate();
+ outputs.set(kResult, resultSlot);
// This mkbson stage combines 'finalSlots' into a bsonObject result slot which has
// 'fieldNames' fields.
if (groupNode->shouldProduceBson) {
- outStage = sbe::makeS<sbe::MakeBsonObjStage>(groupFinalEvalStage.extractStage(nodeId),
- outputs.get(kResult), // objSlot
- boost::none, // rootSlot
- boost::none, // fieldBehavior
+ outStage = sbe::makeS<sbe::MakeBsonObjStage>(std::move(outStage),
+ resultSlot, // objSlot
+ boost::none, // rootSlot
+ boost::none, // fieldBehavior
std::vector<std::string>{}, // fields
std::move(fieldNames), // projectFields
std::move(finalSlots), // projectVars
@@ -2764,8 +2558,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
false, // returnOldObject
nodeId);
} else {
- outStage = sbe::makeS<sbe::MakeObjStage>(groupFinalEvalStage.extractStage(nodeId),
- outputs.get(kResult), // objSlot
+ outStage = sbe::makeS<sbe::MakeObjStage>(std::move(outStage),
+ resultSlot, // objSlot
boost::none, // rootSlot
boost::none, // fieldBehavior
std::vector<std::string>{}, // fields
@@ -2775,12 +2569,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
false, // returnOldObject
nodeId);
}
- } else {
- for (size_t i = 0; i < finalSlots.size(); ++i) {
- outputs.set("CURRENT." + fieldNames[i], finalSlots[i]);
- };
-
- outStage = groupFinalEvalStage.extractStage(nodeId);
}
return {std::move(outStage), std::move(outputs)};
@@ -2790,7 +2578,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots>
SlotBasedStageBuilder::makeUnionForTailableCollScan(const QuerySolutionNode* root,
const PlanStageReqs& reqs) {
using namespace std::literals;
- invariant(!reqs.getIndexKeyBitset());
+ tassert(6023415, "makeUnionForTailableCollScan() does not support kKey", !reqs.hasKeys());
// Register a SlotId in the global environment which would contain a recordId to resume a
// tailable collection scan from. A PlanStage executor will track the last seen recordId and
@@ -2814,7 +2602,7 @@ SlotBasedStageBuilder::makeUnionForTailableCollScan(const QuerySolutionNode* roo
childReqs.setIsTailableCollScanResumeBranch(isTailableCollScanResumeBranch);
auto [branch, outputs] = build(root, childReqs);
- auto branchSlots = getSlotsToForward(reqs, outputs);
+ auto branchSlots = getSlotsToForward(childReqs, outputs);
return {std::move(branchSlots), std::move(branch)};
};
@@ -2877,57 +2665,50 @@ auto buildShardFilterGivenShardKeySlot(sbe::value::SlotId shardKeySlot,
} // namespace
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots>
-SlotBasedStageBuilder::buildShardFilterCovered(const ShardingFilterNode* filterNode,
- sbe::value::SlotId shardFiltererSlot,
- BSONObj shardKeyPattern,
- BSONObj indexKeyPattern,
- const QuerySolutionNode* child,
- PlanStageReqs childReqs) {
- StringDataSet shardKeyFields;
- for (auto&& shardKeyElt : shardKeyPattern) {
- shardKeyFields.insert(shardKeyElt.fieldNameStringData());
- }
+SlotBasedStageBuilder::buildShardFilterCovered(const QuerySolutionNode* root,
+ const PlanStageReqs& reqs) {
+ // Constructs an optimized SBE plan for 'filterNode' in the case that the fields of the
+ // 'shardKeyPattern' are provided by 'child'. In this case, the SBE tree for 'child' will
+ // fill out slots for the necessary components of the index key. These slots can be read
+ // directly in order to determine the shard key that should be passed to the
+ // 'shardFiltererSlot'.
+ const auto filterNode = static_cast<const ShardingFilterNode*>(root);
+ auto child = filterNode->children[0].get();
+ tassert(6023416,
+ "buildShardFilterCovered() expects ixscan below shard filter",
+ child->getType() == STAGE_IXSCAN || child->getType() == STAGE_VIRTUAL_SCAN);
- // Save the bit vector describing the fields from the index that our parent requires. The shard
- // filtering process may require additional fields that are not needed by the parent (for
- // example, if the parent is projecting field "a" but the shard key is {a: 1, b: 1}). We will
- // need the parent's reqs later on so that we can hand the correct slot vector for these fields
- // back to our parent.
- auto parentIndexKeyReqs = childReqs.getIndexKeyBitset();
+ // Extract the child's key pattern.
+ BSONObj indexKeyPattern = child->getType() == STAGE_IXSCAN
+ ? static_cast<const IndexScanNode*>(child)->index.keyPattern
+ : static_cast<const VirtualScanNode*>(child)->indexKeyPattern;
- // Determine the set of fields from the index required to obtain the shard key and union those
- // with the set of fields from the index required by the parent stage.
- auto [shardKeyIndexReqs, _] = makeIndexKeyInclusionSet(indexKeyPattern, shardKeyFields);
- const auto ixKeyBitset =
- parentIndexKeyReqs.value_or(sbe::IndexKeysInclusionSet{}) | shardKeyIndexReqs;
- childReqs.getIndexKeyBitset() = ixKeyBitset;
+ auto childReqs = reqs.copy();
+
+ // If we're sharded make sure that we don't return data that isn't owned by the shard. This
+ // situation can occur when pending documents from in-progress migrations are inserted and when
+ // there are orphaned documents from aborted migrations. To check if the document is owned by
+ // the shard, we need to own a 'ShardFilterer', and extract the document's shard key as a
+ // BSONObj.
+ auto shardKeyPattern = _collections.getMainCollection().getShardKeyPattern();
+ // We register the "shardFilterer" slot but we don't construct the ShardFilterer here, because
+ // once constructed the ShardFilterer will prevent orphaned documents from being deleted. We
+ // will construct the ShardFilterer later while preparing the SBE tree for execution.
+ auto shardFiltererSlot = _data.env->registerSlot(
+ "shardFilterer"_sd, sbe::value::TypeTags::Nothing, 0, false, &_slotIdGenerator);
+
+ for (auto&& shardKeyElt : shardKeyPattern) {
+ childReqs.set(std::make_pair(PlanStageSlots::kKey, shardKeyElt.fieldNameStringData()));
+ }
auto [stage, outputs] = build(child, childReqs);
- tassert(5562302, "Expected child to produce index key slots", outputs.getIndexKeySlots());
-
- // Maps from key name -> (index in outputs.getIndexKeySlots(), is hashed).
- auto ixKeyPatternFieldToSlotIdx = [&, childOutputs = std::ref(outputs)]() {
- StringDataMap<std::pair<sbe::value::SlotId, bool>> ret;
-
- // Keeps track of which component we're reading in the index key pattern.
- size_t i = 0;
- // Keeps track of the index we are in the slot vector produced by the ix scan. The slot
- // vector produced by the ix scan may be a subset of the key pattern.
- size_t slotIdx = 0;
- for (auto&& ixPatternElt : indexKeyPattern) {
- if (shardKeyFields.count(ixPatternElt.fieldNameStringData())) {
- const bool isHashed = ixPatternElt.valueStringData() == IndexNames::HASHED;
- const auto slotId = (*childOutputs.get().getIndexKeySlots())[slotIdx];
- ret.emplace(ixPatternElt.fieldNameStringData(), std::make_pair(slotId, isHashed));
- }
- if (ixKeyBitset[i]) {
- ++slotIdx;
- }
- ++i;
- }
- return ret;
- }();
+ // Maps from key name to a bool that indicates whether the key is hashed.
+ StringDataMap<bool> indexKeyPatternMap;
+ for (auto&& ixPatternElt : indexKeyPattern) {
+ indexKeyPatternMap.emplace(ixPatternElt.fieldNameStringData(),
+ ShardKeyPattern::isHashedPatternEl(ixPatternElt));
+ }
// Build a project stage to deal with hashed shard keys. This step *could* be skipped if we're
// dealing with non-hashed sharding, but it's done this way for sake of simplicity.
@@ -2935,33 +2716,28 @@ SlotBasedStageBuilder::buildShardFilterCovered(const ShardingFilterNode* filterN
sbe::value::SlotVector fieldSlots;
std::vector<std::string> projectFields;
for (auto&& shardKeyPatternElt : shardKeyPattern) {
- auto it = ixKeyPatternFieldToSlotIdx.find(shardKeyPatternElt.fieldNameStringData());
- tassert(5562303, "Could not find element", it != ixKeyPatternFieldToSlotIdx.end());
- auto [slotId, ixKeyEltHashed] = it->second;
+ auto it = indexKeyPatternMap.find(shardKeyPatternElt.fieldNameStringData());
+ tassert(5562303, "Could not find element", it != indexKeyPatternMap.end());
+ const auto ixKeyEltHashed = it->second;
+ const auto slotId = outputs.get(
+ std::make_pair(PlanStageSlots::kKey, shardKeyPatternElt.fieldNameStringData()));
// Get the value stored in the index for this component of the shard key. We may have to
// hash it.
auto elem = makeVariable(slotId);
+ // Handle the case where the index key or shard key is hashed.
const bool shardKeyEltHashed = ShardKeyPattern::isHashedPatternEl(shardKeyPatternElt);
-
- if (shardKeyEltHashed) {
- if (ixKeyEltHashed) {
- // (1) The index stores hashed data and the shard key field is hashed.
- // Nothing to do here. We can apply shard filtering with no other changes.
- } else {
- // (2) The shard key field is hashed but the index stores unhashed data. We must
- // apply the hash function before passing this off to the shard filter.
- elem = makeFunction("shardHash"_sd, std::move(elem));
- }
- } else {
- if (ixKeyEltHashed) {
- // (3) The index stores hashed data but the shard key is not hashed. This is a bug.
- MONGO_UNREACHABLE_TASSERT(5562300);
- } else {
- // (4) The shard key field is not hashed, and the index does not store hashed data.
- // Again, we do nothing here.
- }
+ if (ixKeyEltHashed) {
+ // If the index stores hashed data, then we know the shard key field is hashed as
+ // well. Nothing to do here. We can apply shard filtering with no other changes.
+ tassert(6023421,
+ "Index key is hashed, expected corresponding shard key to be hashed",
+ shardKeyEltHashed);
+ } else if (shardKeyEltHashed) {
+ // The shard key field is hashed but the index stores unhashed data. We must apply
+ // the hash function before passing this off to the shard filter.
+ elem = makeFunction("shardHash"_sd, std::move(elem));
}
fieldSlots.push_back(_slotIdGenerator.generate());
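
The rewritten branch above collapses the old four-way comment into two cases: hash the raw index value only when the shard key field is hashed but the index stores unhashed data, and treat a hashed index key under a non-hashed shard key field as a bug. A standalone sketch of that decision; needsShardHash is a hypothetical helper standing in for wrapping 'elem' in the shardHash() function:

    #include <cassert>
    #include <stdexcept>

    // Mirrors the (ixKeyEltHashed, shardKeyEltHashed) cases handled above.
    bool needsShardHash(bool ixKeyEltHashed, bool shardKeyEltHashed) {
        if (ixKeyEltHashed) {
            if (!shardKeyEltHashed) {
                // A hashed index key for a non-hashed shard key field is invalid.
                throw std::logic_error("hashed index key for non-hashed shard key field");
            }
            return false;  // Value is already hashed; use the slot as-is.
        }
        return shardKeyEltHashed;  // Hash only if the shard key expects a hash.
    }

    int main() {
        assert(!needsShardHash(/*ixKeyEltHashed=*/true, /*shardKeyEltHashed=*/true));
        assert(needsShardHash(/*ixKeyEltHashed=*/false, /*shardKeyEltHashed=*/true));
        assert(!needsShardHash(/*ixKeyEltHashed=*/false, /*shardKeyEltHashed=*/false));
        return 0;
    }
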
@@ -2987,46 +2763,17 @@ SlotBasedStageBuilder::buildShardFilterCovered(const ShardingFilterNode* filterN
auto filterStage = buildShardFilterGivenShardKeySlot(
shardKeySlot, std::move(mkObjStage), shardFiltererSlot, filterNode->nodeId());
- outputs.setIndexKeySlots(!parentIndexKeyReqs ? boost::none
- : boost::optional<sbe::value::SlotVector>{
- makeIndexKeyOutputSlotsMatchingParentReqs(
- indexKeyPattern,
- *parentIndexKeyReqs,
- *childReqs.getIndexKeyBitset(),
- *outputs.getIndexKeySlots())});
+ outputs.clearNonRequiredSlots(reqs);
return {std::move(filterStage), std::move(outputs)};
}
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildShardFilter(
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
- const auto filterNode = static_cast<const ShardingFilterNode*>(root);
-
- // If we're sharded make sure that we don't return data that isn't owned by the shard. This
- // situation can occur when pending documents from in-progress migrations are inserted and when
- // there are orphaned documents from aborted migrations. To check if the document is owned by
- // the shard, we need to own a 'ShardFilterer', and extract the document's shard key as a
- // BSONObj.
- auto shardKeyPattern = _collections.getMainCollection().getShardKeyPattern();
- // We register the "shardFilterer" slot but not construct the ShardFilterer here is because once
- // constructed the ShardFilterer will prevent orphaned documents from being deleted. We will
- // construct the 'ShardFiltered' later while preparing the SBE tree for execution.
- auto shardFiltererSlot = _data.env->registerSlot(
- "shardFilterer"_sd, sbe::value::TypeTags::Nothing, 0, false, &_slotIdGenerator);
-
- // Determine if our child is an index scan and extract it's key pattern, or empty BSONObj if our
- // child is not an IXSCAN node.
- BSONObj indexKeyPattern = [&]() {
- auto childNode = filterNode->children[0].get();
- switch (childNode->getType()) {
- case StageType::STAGE_IXSCAN:
- return static_cast<const IndexScanNode*>(childNode)->index.keyPattern;
- case StageType::STAGE_VIRTUAL_SCAN:
- return static_cast<const VirtualScanNode*>(childNode)->indexKeyPattern;
- default:
- return BSONObj{};
- }
- }();
+ auto child = root->children[0].get();
+ bool childIsIndexScan = child->getType() == STAGE_IXSCAN ||
+ (child->getType() == STAGE_VIRTUAL_SCAN &&
+ !static_cast<const VirtualScanNode*>(child)->indexKeyPattern.isEmpty());
// If we're not required to fill out the 'kResult' slot, then instead we can request a slot from
// the child for each of the fields which constitute the shard key. This allows us to avoid
@@ -3036,17 +2783,25 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
// We only apply this optimization in the special case that the child QSN is an IXSCAN, since in
// this case we can request exactly the fields we need according to their position in the index
// key pattern.
- auto childReqs = reqs.copy().setIf(kResult, indexKeyPattern.isEmpty());
- if (!childReqs.has(kResult)) {
- return buildShardFilterCovered(filterNode,
- shardFiltererSlot,
- shardKeyPattern,
- std::move(indexKeyPattern),
- filterNode->children[0].get(),
- std::move(childReqs));
+ if (!reqs.has(kResult) && childIsIndexScan) {
+ return buildShardFilterCovered(root, reqs);
}
- auto [stage, outputs] = build(filterNode->children[0].get(), childReqs);
+ auto childReqs = reqs.copy().set(kResult);
+
+ // If we're sharded make sure that we don't return data that isn't owned by the shard. This
+ // situation can occur when pending documents from in-progress migrations are inserted and when
+ // there are orphaned documents from aborted migrations. To check if the document is owned by
+ // the shard, we need to own a 'ShardFilterer', and extract the document's shard key as a
+ // BSONObj.
+ auto shardKeyPattern = _collections.getMainCollection().getShardKeyPattern();
+ // We register the "shardFilterer" slot but we don't construct the ShardFilterer here, because
+ // once constructed the ShardFilterer will prevent orphaned documents from being deleted. We
+ // will construct the ShardFilterer later while preparing the SBE tree for execution.
+ auto shardFiltererSlot = _data.env->registerSlot(
+ "shardFilterer"_sd, sbe::value::TypeTags::Nothing, 0, false, &_slotIdGenerator);
+
+ auto [stage, outputs] = build(child, childReqs);
// Build an expression to extract the shard key from the document based on the shard key
// pattern. To do this, we iterate over the shard key pattern parts and build nested 'getField'
@@ -3136,7 +2891,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
{STAGE_SKIP, &SlotBasedStageBuilder::buildSkip},
{STAGE_SORT_SIMPLE, &SlotBasedStageBuilder::buildSort},
{STAGE_SORT_DEFAULT, &SlotBasedStageBuilder::buildSort},
- {STAGE_SORT_KEY_GENERATOR, &SlotBasedStageBuilder::buildSortKeyGeneraror},
+ {STAGE_SORT_KEY_GENERATOR, &SlotBasedStageBuilder::buildSortKeyGenerator},
{STAGE_PROJECTION_SIMPLE, &SlotBasedStageBuilder::buildProjectionSimple},
{STAGE_PROJECTION_DEFAULT, &SlotBasedStageBuilder::buildProjectionDefault},
{STAGE_PROJECTION_COVERED, &SlotBasedStageBuilder::buildProjectionCovered},
@@ -3178,6 +2933,29 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
break;
}
- return std::invoke(kStageBuilders.at(root->getType()), *this, root, reqs);
+ auto [stage, slots] = std::invoke(kStageBuilders.at(root->getType()), *this, root, reqs);
+ auto outputs = std::move(slots);
+
+ auto fields = filterVector(reqs.getFields(), [&](const std::string& s) {
+ return !outputs.has(std::make_pair(PlanStageSlots::kField, StringData(s)));
+ });
+
+ if (!fields.empty()) {
+ tassert(6023424,
+ str::stream() << "Expected build() for " << stageTypeToString(root->getType())
+ << " to either produce a kResult slot or to satisfy all kField reqs",
+ outputs.has(PlanStageSlots::kResult));
+
+ auto resultSlot = outputs.get(PlanStageSlots::kResult);
+ auto [outStage, slots] = projectTopLevelFields(
+ std::move(stage), fields, resultSlot, root->nodeId(), &_slotIdGenerator);
+ stage = std::move(outStage);
+
+ for (size_t i = 0; i < fields.size(); ++i) {
+ outputs.set(std::make_pair(PlanStageSlots::kField, std::move(fields[i])), slots[i]);
+ }
+ }
+
+ return {std::move(stage), std::move(outputs)};
}
} // namespace mongo::stage_builder
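
The new tail of build() asks which kField requirements the dispatched buildXXX() method left unsatisfied and, if any remain, projects them out of the kResult object (the diff routes this through filterVector() and projectTopLevelFields(), neither of which is shown here). A standalone sketch of the "which fields are still missing" step, with illustrative data:

    #include <cassert>
    #include <set>
    #include <string>
    #include <vector>

    int main() {
        // Fields required by the parent, and fields the buildXXX() method
        // already bound to kField slots (illustrative values).
        std::vector<std::string> requiredFields{"a", "b", "c"};
        std::set<std::string> alreadyProduced{"b"};

        // Equivalent of the filterVector() call: keep only the requirements that
        // still have to be satisfied from the kResult object.
        std::vector<std::string> missing;
        for (const auto& f : requiredFields) {
            if (!alreadyProduced.count(f)) {
                missing.push_back(f);
            }
        }

        // "a" and "c" would each get a getField() projection off the result slot
        // in the real builder.
        assert((missing == std::vector<std::string>{"a", "c"}));
        return 0;
    }
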
diff --git a/src/mongo/db/query/sbe_stage_builder.h b/src/mongo/db/query/sbe_stage_builder.h
index 14bffbc1aa3..081660c43b1 100644
--- a/src/mongo/db/query/sbe_stage_builder.h
+++ b/src/mongo/db/query/sbe_stage_builder.h
@@ -40,6 +40,7 @@
#include "mongo/db/query/sbe_stage_builder_helpers.h"
#include "mongo/db/query/shard_filterer_factory_interface.h"
#include "mongo/db/query/stage_builder.h"
+#include "mongo/util/pair_map.h"
namespace mongo::stage_builder {
/**
@@ -51,6 +52,12 @@ std::unique_ptr<sbe::RuntimeEnvironment> makeRuntimeEnvironment(
OperationContext* opCtx,
sbe::value::SlotIdGenerator* slotIdGenerator);
+class PlanStageReqs;
+class PlanStageSlots;
+sbe::value::SlotVector getSlotsToForward(const PlanStageReqs& reqs,
+ const PlanStageSlots& outputs,
+ const sbe::value::SlotVector& exclude = sbe::makeSV());
+
/**
* This function prepares the SBE tree for execution, such as attaching the OperationContext,
* ensuring that the SBE tree is registered with the PlanYieldPolicySBE and populating the
@@ -105,78 +112,86 @@ struct ParameterizedIndexScanSlots {
*/
class PlanStageSlots {
public:
- static constexpr StringData kResult = "result"_sd;
- static constexpr StringData kRecordId = "recordId"_sd;
- static constexpr StringData kReturnKey = "returnKey"_sd;
- static constexpr StringData kSnapshotId = "snapshotId"_sd;
- static constexpr StringData kIndexId = "indexId"_sd;
- static constexpr StringData kIndexKey = "indexKey"_sd;
- static constexpr StringData kIndexKeyPattern = "indexKeyPattern"_sd;
+ // The _slots map is capable of holding different "classes" of slots:
+ // 1) kMeta slots are used to hold the current document (kResult), record ID (kRecordId), and
+ // various pieces of metadata.
+ // 2) kField slots represent the values of top-level fields or, in some cases, of dotted field
+ // paths (when we are getting the dotted field from a non-multikey index and we know no array
+ // traversal is needed). These slots hold the actual values of the fields / field paths (not
+ // the sort key or collation comparison key for the field).
+ // 3) kKey slots represent the raw key value that comes from an ixscan / ixseek stage for a
+ // given field path. This raw key value can be used for sorting / comparison, but it is not
+ // always equal to the actual value of the field path (for example, if the key is coming from
+ // an index that has a non-simple collation).
+ enum class Type {
+ kMeta,
+ kField,
+ kKey,
+ };
+
+ using Name = std::pair<Type, StringData>;
+ using OwnedName = std::pair<Type, std::string>;
+
+ static constexpr auto kField = Type::kField;
+ static constexpr auto kKey = Type::kKey;
+ static constexpr auto kMeta = Type::kMeta;
+
+ static constexpr Name kResult = {kMeta, "result"_sd};
+ static constexpr Name kRecordId = {kMeta, "recordId"_sd};
+ static constexpr Name kReturnKey = {kMeta, "returnKey"_sd};
+ static constexpr Name kSnapshotId = {kMeta, "snapshotId"_sd};
+ static constexpr Name kIndexId = {kMeta, "indexId"_sd};
+ static constexpr Name kIndexKey = {kMeta, "indexKey"_sd};
+ static constexpr Name kIndexKeyPattern = {kMeta, "indexKeyPattern"_sd};
PlanStageSlots() = default;
PlanStageSlots(const PlanStageReqs& reqs, sbe::value::SlotIdGenerator* slotIdGenerator);
- bool has(StringData str) const {
+ bool has(const Name& str) const {
return _slots.count(str);
}
- sbe::value::SlotId get(StringData str) const {
+ sbe::value::SlotId get(const Name& str) const {
auto it = _slots.find(str);
invariant(it != _slots.end());
return it->second;
}
- boost::optional<sbe::value::SlotId> getIfExists(StringData str) const {
+ boost::optional<sbe::value::SlotId> getIfExists(const Name& str) const {
if (auto it = _slots.find(str); it != _slots.end()) {
return it->second;
}
return boost::none;
}
- void set(StringData str, sbe::value::SlotId slot) {
- _slots[str] = slot;
- }
-
- void clear(StringData str) {
- _slots.erase(str);
+ void set(const Name& str, sbe::value::SlotId slot) {
+ _slots.insert_or_assign(str, slot);
}
- const boost::optional<sbe::value::SlotVector>& getIndexKeySlots() const {
- return _indexKeySlots;
+ void set(OwnedName str, sbe::value::SlotId slot) {
+ _slots.insert_or_assign(std::move(str), slot);
}
- boost::optional<sbe::value::SlotVector> extractIndexKeySlots() {
- ON_BLOCK_EXIT([this] { _indexKeySlots = boost::none; });
- return std::move(_indexKeySlots);
- }
-
- void setIndexKeySlots(sbe::value::SlotVector iks) {
- _indexKeySlots = std::move(iks);
- }
-
- void setIndexKeySlots(boost::optional<sbe::value::SlotVector> iks) {
- _indexKeySlots = std::move(iks);
+ void clear(const Name& str) {
+ _slots.erase(str);
}
/**
- * This method applies an action to some/all of the slots within this struct (excluding index
- * key slots). For each slot in this struct, the action is will be applied to the slot if (and
- * only if) the corresponding flag in 'reqs' is true.
+ * This method applies an action to some/all of the slots within this struct. For each slot in
+ * this struct, the action will be applied to the slot if (and only if) the corresponding
+ * flag in 'reqs' is true.
*/
inline void forEachSlot(const PlanStageReqs& reqs,
const std::function<void(sbe::value::SlotId)>& fn) const;
- inline void forEachSlot(
- const PlanStageReqs& reqs,
- const std::function<void(sbe::value::SlotId, const StringData&)>& fn) const;
+ inline void forEachSlot(const PlanStageReqs& reqs,
+ const std::function<void(sbe::value::SlotId, const Name&)>& fn) const;
-private:
- StringMap<sbe::value::SlotId> _slots;
+ inline void clearNonRequiredSlots(const PlanStageReqs& reqs);
- // When an index scan produces parts of an index key for a covered plan, this is where the
- // slots for the produced values are stored.
- boost::optional<sbe::value::SlotVector> _indexKeySlots;
+private:
+ PairMap<Type, std::string, sbe::value::SlotId> _slots;
};
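
Since slot names are now (Type, name) pairs, the same string can be tracked separately as a raw index key value and as a fetched field value. A minimal sketch of how the three name classes keep "a" distinct; PairMap is replaced by std::map here purely for illustration:

    #include <cassert>
    #include <map>
    #include <string>
    #include <utility>

    enum class Type { kMeta, kField, kKey };
    using SlotId = long long;
    using Name = std::pair<Type, std::string>;

    int main() {
        std::map<Name, SlotId> slots;
        slots[{Type::kKey, "a"}] = 1;    // raw index key for "a" (sort/compare form)
        slots[{Type::kField, "a"}] = 2;  // actual value of top-level field "a"
        slots[{Type::kMeta, "result"}] = 3;

        // The two "a" entries do not collide because the Type is part of the key.
        assert(slots.size() == 3);
        assert((slots[{Type::kKey, "a"}] != slots[{Type::kField, "a"}]));
        return 0;
    }
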
/**
@@ -185,38 +200,71 @@ private:
*/
class PlanStageReqs {
public:
+ using Type = PlanStageSlots::Type;
+ using Name = PlanStageSlots::Name;
+ using OwnedName = std::pair<Type, std::string>;
+
+ static constexpr auto kField = PlanStageSlots::Type::kField;
+ static constexpr auto kKey = PlanStageSlots::Type::kKey;
+ static constexpr auto kMeta = PlanStageSlots::Type::kMeta;
+
PlanStageReqs copy() const {
return *this;
}
- bool has(StringData str) const {
+ bool has(const Name& str) const {
auto it = _slots.find(str);
return it != _slots.end() && it->second;
}
- PlanStageReqs& set(StringData str) {
- _slots[str] = true;
+ PlanStageReqs& set(const Name& str) {
+ _slots.insert_or_assign(str, true);
+ return *this;
+ }
+
+ PlanStageReqs& set(OwnedName str) {
+ _slots.insert_or_assign(std::move(str), true);
+ return *this;
+ }
+
+ PlanStageReqs& set(const std::vector<Name>& strs) {
+ for (size_t i = 0; i < strs.size(); ++i) {
+ _slots.insert_or_assign(strs[i], true);
+ }
+ return *this;
+ }
+
+ PlanStageReqs& set(std::vector<OwnedName> strs) {
+ for (size_t i = 0; i < strs.size(); ++i) {
+ _slots.insert_or_assign(std::move(strs[i]), true);
+ }
return *this;
}
- PlanStageReqs& setIf(StringData str, bool condition) {
+ PlanStageReqs& setIf(const Name& str, bool condition) {
if (condition) {
- _slots[str] = true;
+ _slots.insert_or_assign(str, true);
}
return *this;
}
- PlanStageReqs& clear(StringData str) {
- _slots.erase(str);
+ PlanStageReqs& setFields(std::vector<std::string> strs) {
+ for (size_t i = 0; i < strs.size(); ++i) {
+ _slots.insert_or_assign(std::make_pair(kField, std::move(strs[i])), true);
+ }
return *this;
}
- boost::optional<sbe::IndexKeysInclusionSet>& getIndexKeyBitset() {
- return _indexKeyBitset;
+ PlanStageReqs& setKeys(std::vector<std::string> strs) {
+ for (size_t i = 0; i < strs.size(); ++i) {
+ _slots.insert_or_assign(std::make_pair(kKey, std::move(strs[i])), true);
+ }
+ return *this;
}
- const boost::optional<sbe::IndexKeysInclusionSet>& getIndexKeyBitset() const {
- return _indexKeyBitset;
+ PlanStageReqs& clear(const Name& str) {
+ _slots.erase(str);
+ return *this;
}
bool getIsBuildingUnionForTailableCollScan() const {
@@ -243,6 +291,52 @@ public:
return _targetNamespace;
}
+ bool hasType(Type t) const {
+ for (auto&& [name, isRequired] : _slots) {
+ if (isRequired && name.first == t) {
+ return true;
+ }
+ }
+ return false;
+ }
+ bool hasFields() const {
+ return hasType(kField);
+ }
+ bool hasKeys() const {
+ return hasType(kKey);
+ }
+
+ std::vector<std::string> getOfType(Type t) const {
+ std::vector<std::string> res;
+ for (auto&& [name, isRequired] : _slots) {
+ if (isRequired && name.first == t) {
+ res.push_back(name.second);
+ }
+ }
+ std::sort(res.begin(), res.end());
+ return res;
+ }
+ std::vector<std::string> getFields() const {
+ return getOfType(kField);
+ }
+ std::vector<std::string> getKeys() const {
+ return getOfType(kKey);
+ }
+
+ PlanStageReqs& clearAllOfType(Type t) {
+ auto fields = getOfType(t);
+ for (const auto& field : fields) {
+ _slots.erase(std::make_pair(t, StringData(field)));
+ }
+ return *this;
+ }
+ PlanStageReqs& clearAllFields() {
+ return clearAllOfType(kField);
+ }
+ PlanStageReqs& clearAllKeys() {
+ return clearAllOfType(kKey);
+ }
+
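
getOfType() and clearAllOfType() above act only on entries whose type component matches, so clearing all kField requirements leaves a kKey requirement for the same name untouched. A standalone sketch of that filtering over a std::map keyed by (Type, name) pairs (the Type enum and sample entries are illustrative):

    #include <cassert>
    #include <map>
    #include <string>
    #include <utility>
    #include <vector>

    enum class Type { kMeta, kField, kKey };
    using Name = std::pair<Type, std::string>;

    int main() {
        std::map<Name, bool> reqs{{{Type::kField, "a"}, true},
                                  {{Type::kKey, "a"}, true},
                                  {{Type::kField, "b"}, false}};

        // Equivalent of getOfType(kField): names of required entries of one class.
        std::vector<std::string> fields;
        for (const auto& [name, isRequired] : reqs) {
            if (isRequired && name.first == Type::kField) {
                fields.push_back(name.second);
            }
        }
        assert((fields == std::vector<std::string>{"a"}));

        // Equivalent of clearAllOfType(kField): erase those entries; the kKey
        // requirement for "a" is untouched.
        for (const auto& f : fields) {
            reqs.erase({Type::kField, f});
        }
        assert(reqs.count({Type::kKey, "a"}) == 1);
        return 0;
    }
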
friend PlanStageSlots::PlanStageSlots(const PlanStageReqs& reqs,
sbe::value::SlotIdGenerator* slotIdGenerator);
@@ -251,14 +345,12 @@ public:
friend void PlanStageSlots::forEachSlot(
const PlanStageReqs& reqs,
- const std::function<void(sbe::value::SlotId, const StringData&)>& fn) const;
+ const std::function<void(sbe::value::SlotId, const Name&)>& fn) const;
-private:
- StringMap<bool> _slots;
+ friend void PlanStageSlots::clearNonRequiredSlots(const PlanStageReqs& reqs);
- // A bitset here indicates that we have a covered projection that is expecting to parts of the
- // index key from an index scan.
- boost::optional<sbe::IndexKeysInclusionSet> _indexKeyBitset;
+private:
+ PairMap<Type, std::string, bool> _slots;
// When we're in the middle of building a special union sub-tree implementing a tailable cursor
// collection scan, this flag will be set to true. Otherwise this flag will be false.
@@ -283,11 +375,11 @@ void PlanStageSlots::forEachSlot(const PlanStageReqs& reqs,
// Clang raises an error if we attempt to use 'name' in the tassert() below, because
// tassert() is a macro that uses lambdas and 'name' is defined via "local binding".
// We work-around this by copying 'name' to a local variable 'slotName'.
- auto slotName = StringData(name);
+ auto slotName = Name(name);
auto it = _slots.find(slotName);
tassert(7050900,
- str::stream() << "Could not find '" << slotName
- << "' slot in the map, expected slot to exist",
+ str::stream() << "Could not find " << static_cast<int>(slotName.first) << ":'"
+ << slotName.second << "' in the slot map, expected slot to exist",
it != _slots.end());
fn(it->second);
@@ -297,17 +389,17 @@ void PlanStageSlots::forEachSlot(const PlanStageReqs& reqs,
void PlanStageSlots::forEachSlot(
const PlanStageReqs& reqs,
- const std::function<void(sbe::value::SlotId, const StringData&)>& fn) const {
+ const std::function<void(sbe::value::SlotId, const Name&)>& fn) const {
for (auto&& [name, isRequired] : reqs._slots) {
if (isRequired) {
// Clang raises an error if we attempt to use 'name' in the tassert() below, because
// tassert() is a macro that uses lambdas and 'name' is defined via "local binding".
// We work-around this by copying 'name' to a local variable 'slotName'.
- auto slotName = StringData(name);
+ auto slotName = Name(name);
auto it = _slots.find(slotName);
tassert(7050901,
- str::stream() << "Could not find '" << slotName
- << "' slot in the map, expected slot to exist",
+ str::stream() << "Could not find " << static_cast<int>(slotName.first) << ":'"
+ << slotName.second << "' in the slot map, expected slot to exist",
it != _slots.end());
fn(it->second, slotName);
@@ -315,6 +407,21 @@ void PlanStageSlots::forEachSlot(
}
}
+void PlanStageSlots::clearNonRequiredSlots(const PlanStageReqs& reqs) {
+ auto it = _slots.begin();
+ while (it != _slots.end()) {
+ auto& name = it->first;
+ auto reqIt = reqs._slots.find(name);
+ // We never clear kResult, regardless of whether it is required by 'reqs'.
+ if ((reqIt != reqs._slots.end() && reqIt->second) ||
+ (name.first == kResult.first && name.second == kResult.second)) {
+ ++it;
+ } else {
+ _slots.erase(it++);
+ }
+ }
+}
+
using InputParamToSlotMap = stdx::unordered_map<MatchExpression::InputParamId, sbe::value::SlotId>;
using VariableIdToSlotMap = stdx::unordered_map<Variables::Id, sbe::value::SlotId>;
@@ -443,13 +550,13 @@ private:
*/
class SlotBasedStageBuilder final : public StageBuilder<sbe::PlanStage> {
public:
- static constexpr StringData kResult = PlanStageSlots::kResult;
- static constexpr StringData kRecordId = PlanStageSlots::kRecordId;
- static constexpr StringData kReturnKey = PlanStageSlots::kReturnKey;
- static constexpr StringData kSnapshotId = PlanStageSlots::kSnapshotId;
- static constexpr StringData kIndexId = PlanStageSlots::kIndexId;
- static constexpr StringData kIndexKey = PlanStageSlots::kIndexKey;
- static constexpr StringData kIndexKeyPattern = PlanStageSlots::kIndexKeyPattern;
+ static constexpr auto kResult = PlanStageSlots::kResult;
+ static constexpr auto kRecordId = PlanStageSlots::kRecordId;
+ static constexpr auto kReturnKey = PlanStageSlots::kReturnKey;
+ static constexpr auto kSnapshotId = PlanStageSlots::kSnapshotId;
+ static constexpr auto kIndexId = PlanStageSlots::kIndexId;
+ static constexpr auto kIndexKey = PlanStageSlots::kIndexKey;
+ static constexpr auto kIndexKeyPattern = PlanStageSlots::kIndexKeyPattern;
SlotBasedStageBuilder(OperationContext* opCtx,
const MultipleCollectionAccessor& collections,
@@ -457,6 +564,12 @@ public:
const QuerySolution& solution,
PlanYieldPolicySBE* yieldPolicy);
+ /**
+ * This method will build an SBE PlanStage tree for QuerySolutionNode 'root' and its
+ * descendants.
+ *
+ * This method is a wrapper around 'build(const QuerySolutionNode*, const PlanStageReqs&)'.
+ */
std::unique_ptr<sbe::PlanStage> build(const QuerySolutionNode* root) final;
PlanStageData getPlanStageData() {
@@ -464,6 +577,14 @@ public:
}
private:
+ /**
+ * This method will build an SBE PlanStage tree for QuerySolutionNode 'root' and its
+ * descendants.
+ *
+ * Based on the type of 'root', this method will dispatch to the appropriate buildXXX() method.
+ * This method will also handle generating calls to getField() to satisfy kField reqs that were
+ * not satisfied by the buildXXX() method.
+ */
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> build(const QuerySolutionNode* node,
const PlanStageReqs& reqs);
@@ -494,7 +615,7 @@ private:
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildSortCovered(
const QuerySolutionNode* root, const PlanStageReqs& reqs);
- std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildSortKeyGeneraror(
+ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildSortKeyGenerator(
const QuerySolutionNode* root, const PlanStageReqs& reqs);
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildSortMerge(
@@ -539,19 +660,14 @@ private:
const QuerySolutionNode* root, const PlanStageReqs& reqs);
/**
- * Constructs an optimized SBE plan for 'filterNode' in the case that the fields of the
- * 'shardKeyPattern' are provided by 'childIxscan'. In this case, the SBE plan for the child
+ * Constructs an optimized SBE plan for 'root' in the case that the fields of the shard key
+ * pattern are provided by the child index scan. In this case, the SBE plan for the child
* index scan node will fill out slots for the necessary components of the index key. These
* slots can be read directly in order to determine the shard key that should be passed to the
* 'shardFiltererSlot'.
*/
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildShardFilterCovered(
- const ShardingFilterNode* filterNode,
- sbe::value::SlotId shardFiltererSlot,
- BSONObj shardKeyPattern,
- BSONObj indexKeyPattern,
- const QuerySolutionNode* child,
- PlanStageReqs childReqs);
+ const QuerySolutionNode* root, const PlanStageReqs& reqs);
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildGroup(
const QuerySolutionNode* root, const PlanStageReqs& reqs);
diff --git a/src/mongo/db/query/sbe_stage_builder_coll_scan.cpp b/src/mongo/db/query/sbe_stage_builder_coll_scan.cpp
index c44c088ac87..34a6b8a00a6 100644
--- a/src/mongo/db/query/sbe_stage_builder_coll_scan.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_coll_scan.cpp
@@ -158,7 +158,7 @@ std::unique_ptr<sbe::PlanStage> buildResumeFromRecordIdSubtree(
std::move(seekRecordIdExpression));
// Construct a 'seek' branch of the 'union'. If we succeeded in repositioning the cursor,
- // the branch will output the 'seekSlot' to start the real scan from, otherwise it will
+ // the branch will output the 'seekSlot' to start the real scan from, otherwise it will
// produce EOF.
auto seekBranch =
sbe::makeS<sbe::LoopJoinStage>(std::move(projStage),
@@ -232,7 +232,7 @@ std::unique_ptr<sbe::PlanStage> buildResumeFromRecordIdSubtree(
 * Creates a collection scan sub-tree optimized for oplog scans. We can build an optimized scan
* when any of the following scenarios apply:
*
- * 1. There is a predicted on the 'ts' field of the oplog collection.
+ * 1. There is a predicate on the 'ts' field of the oplog collection.
* 1.1 If a lower bound on 'ts' is present, the collection scan will seek directly to the
* RecordId of an oplog entry as close to this lower bound as possible without going higher.
* 1.2 If the query is *only* a lower bound on 'ts' on a forward scan, every document in the
@@ -248,6 +248,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
StageBuilderState& state,
const CollectionPtr& collection,
const CollectionScanNode* csn,
+ const std::vector<std::string>& fields,
PlanYieldPolicy* yieldPolicy,
bool isTailableResumeBranch) {
invariant(collection->ns().isOplog());
@@ -258,6 +259,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
// Oplog scan optimizations can only be done for a forward scan.
invariant(csn->direction == CollectionScanParams::FORWARD);
+ auto fieldSlots = state.slotIdGenerator->generateMultiple(fields.size());
+
auto resultSlot = state.slotId();
auto recordIdSlot = state.slotId();
@@ -290,9 +293,16 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
// of if we need to track the latest oplog timestamp.
const auto shouldTrackLatestOplogTimestamp =
(csn->maxRecord || csn->shouldTrackLatestOplogTimestamp);
- auto&& [fields, slots, tsSlot] = makeOplogTimestampSlotsIfNeeded(
+ auto&& [scanFields, scanFieldSlots, tsSlot] = makeOplogTimestampSlotsIfNeeded(
state.data->env, state.slotIdGenerator, shouldTrackLatestOplogTimestamp);
+ bool createScanWithAndWithoutFilter = (csn->filter && csn->stopApplyingFilterAfterFirstMatch);
+
+ if (!createScanWithAndWithoutFilter) {
+ scanFields.insert(scanFields.end(), fields.begin(), fields.end());
+ scanFieldSlots.insert(scanFieldSlots.end(), fieldSlots.begin(), fieldSlots.end());
+ }
+
sbe::ScanCallbacks callbacks({}, {}, makeOpenCallbackIfNeeded(collection, csn));
auto stage = sbe::makeS<sbe::ScanStage>(collection->uuid(),
resultSlot,
@@ -302,8 +312,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
boost::none /* indexKeySlot */,
boost::none /* keyPatternSlot */,
tsSlot,
- std::move(fields),
- std::move(slots),
+ std::move(scanFields),
+ std::move(scanFieldSlots),
seekRecordIdSlot,
true /* forward */,
yieldPolicy,
@@ -342,7 +352,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
auto oObjSlot = state.slotId();
auto minTsSlot = state.slotId();
sbe::value::SlotVector minTsSlots = {minTsSlot, opTypeSlot, oObjSlot};
- std::vector<std::string> fields = {repl::OpTime::kTimestampFieldName.toString(), "op", "o"};
+ std::vector<std::string> minTsFields = {
+ repl::OpTime::kTimestampFieldName.toString(), "op", "o"};
// If the first entry we see in the oplog is the replset initialization, then it doesn't
// matter if its timestamp is later than the specified minTs; no events earlier than the
@@ -375,7 +386,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
boost::none /* indexKeySlot */,
boost::none /* keyPatternSlot */,
boost::none /* oplogTsSlot*/,
- std::move(fields),
+ std::move(minTsFields),
minTsSlots, /* don't move this */
boost::none,
true /* forward */,
@@ -421,6 +432,18 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
tsSlot = state.slotId();
auto outputSlots = sbe::makeSV(resultSlot, recordIdSlot, *tsSlot);
+ if (!createScanWithAndWithoutFilter) {
+ auto unusedFieldSlots = state.slotIdGenerator->generateMultiple(fieldSlots.size());
+ minTsSlots.insert(minTsSlots.end(), unusedFieldSlots.begin(), unusedFieldSlots.end());
+
+ realSlots.insert(realSlots.end(), fieldSlots.begin(), fieldSlots.end());
+
+ size_t numFieldSlots = fieldSlots.size();
+ fieldSlots = state.slotIdGenerator->generateMultiple(numFieldSlots);
+
+ outputSlots.insert(outputSlots.end(), fieldSlots.begin(), fieldSlots.end());
+ }
+
// Create the union stage. The left branch, which runs first, is our resumability check.
stage = sbe::makeS<sbe::UnionStage>(
sbe::makeSs(std::move(minTsBranch), std::move(stage)),
@@ -454,6 +477,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
relevantSlots.push_back(*tsSlot);
}
+ relevantSlots.insert(relevantSlots.end(), fieldSlots.begin(), fieldSlots.end());
+
auto [_, outputStage] = generateFilter(state,
csn->filter.get(),
{std::move(stage), std::move(relevantSlots)},
@@ -483,7 +508,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
// matches. This RecordId is then used as a starting point of the collection scan in the
// inner branch, and the execution will continue from this point further on, without
// applying the filter.
- if (csn->stopApplyingFilterAfterFirstMatch) {
+ if (createScanWithAndWithoutFilter) {
invariant(!csn->maxRecord);
invariant(csn->direction == CollectionScanParams::FORWARD);
@@ -491,9 +516,12 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
resultSlot = state.slotId();
recordIdSlot = state.slotId();
- std::tie(fields, slots, tsSlot) = makeOplogTimestampSlotsIfNeeded(
+ std::tie(scanFields, scanFieldSlots, tsSlot) = makeOplogTimestampSlotsIfNeeded(
state.data->env, state.slotIdGenerator, shouldTrackLatestOplogTimestamp);
+ scanFields.insert(scanFields.end(), fields.begin(), fields.end());
+ scanFieldSlots.insert(scanFieldSlots.end(), fieldSlots.begin(), fieldSlots.end());
+
stage = sbe::makeS<sbe::LoopJoinStage>(
sbe::makeS<sbe::LimitSkipStage>(std::move(stage), 1, boost::none, csn->nodeId()),
sbe::makeS<sbe::ScanStage>(collection->uuid(),
@@ -504,8 +532,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
boost::none /* indexKeySlot */,
boost::none /* keyPatternSlot */,
tsSlot,
- std::move(fields),
- std::move(slots),
+ std::move(scanFields),
+ std::move(scanFieldSlots),
seekRecordIdSlot,
true /* forward */,
yieldPolicy,
@@ -524,6 +552,9 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo
PlanStageSlots outputs;
outputs.set(PlanStageSlots::kResult, resultSlot);
outputs.set(PlanStageSlots::kRecordId, recordIdSlot);
+ for (size_t i = 0; i < fields.size(); ++i) {
+ outputs.set(std::make_pair(PlanStageSlots::kField, fields[i]), fieldSlots[i]);
+ }
return {std::move(stage), std::move(outputs)};
}
@@ -540,6 +571,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc
StageBuilderState& state,
const CollectionPtr& collection,
const CollectionScanNode* csn,
+ const std::vector<std::string>& fields,
PlanYieldPolicy* yieldPolicy,
bool isTailableResumeBranch) {
const auto forward = csn->direction == CollectionScanParams::FORWARD;
@@ -548,6 +580,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc
invariant(!csn->resumeAfterRecordId || forward);
invariant(!csn->resumeAfterRecordId || !csn->tailable);
+ auto fieldSlots = state.slotIdGenerator->generateMultiple(fields.size());
+
auto resultSlot = state.slotId();
auto recordIdSlot = state.slotId();
auto [seekRecordIdSlot, seekRecordIdExpression] =
@@ -563,9 +597,12 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc
}();
// See if we need to project out an oplog latest timestamp.
- auto&& [fields, slots, tsSlot] = makeOplogTimestampSlotsIfNeeded(
+ auto&& [scanFields, scanFieldSlots, tsSlot] = makeOplogTimestampSlotsIfNeeded(
state.data->env, state.slotIdGenerator, csn->shouldTrackLatestOplogTimestamp);
+ scanFields.insert(scanFields.end(), fields.begin(), fields.end());
+ scanFieldSlots.insert(scanFieldSlots.end(), fieldSlots.begin(), fieldSlots.end());
+
sbe::ScanCallbacks callbacks({}, {}, makeOpenCallbackIfNeeded(collection, csn));
auto stage = sbe::makeS<sbe::ScanStage>(collection->uuid(),
resultSlot,
@@ -575,8 +612,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc
boost::none /* indexKeySlot */,
boost::none /* keyPatternSlot */,
tsSlot,
- std::move(fields),
- std::move(slots),
+ std::move(scanFields),
+ std::move(scanFieldSlots),
seekRecordIdSlot,
forward,
yieldPolicy,
@@ -602,6 +639,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc
invariant(!csn->stopApplyingFilterAfterFirstMatch);
auto relevantSlots = sbe::makeSV(resultSlot, recordIdSlot);
+ relevantSlots.insert(relevantSlots.end(), fieldSlots.begin(), fieldSlots.end());
auto [_, outputStage] = generateFilter(state,
csn->filter.get(),
@@ -614,6 +652,9 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc
PlanStageSlots outputs;
outputs.set(PlanStageSlots::kResult, resultSlot);
outputs.set(PlanStageSlots::kRecordId, recordIdSlot);
+ for (size_t i = 0; i < fields.size(); ++i) {
+ outputs.set(std::make_pair(PlanStageSlots::kField, fields[i]), fieldSlots[i]);
+ }
return {std::move(stage), std::move(outputs)};
}
@@ -623,13 +664,15 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateCollScan(
StageBuilderState& state,
const CollectionPtr& collection,
const CollectionScanNode* csn,
+ const std::vector<std::string>& fields,
PlanYieldPolicy* yieldPolicy,
bool isTailableResumeBranch) {
if (csn->minRecord || csn->maxRecord || csn->stopApplyingFilterAfterFirstMatch) {
return generateOptimizedOplogScan(
- state, collection, csn, yieldPolicy, isTailableResumeBranch);
+ state, collection, csn, fields, yieldPolicy, isTailableResumeBranch);
} else {
- return generateGenericCollScan(state, collection, csn, yieldPolicy, isTailableResumeBranch);
+ return generateGenericCollScan(
+ state, collection, csn, fields, yieldPolicy, isTailableResumeBranch);
}
}
} // namespace mongo::stage_builder
diff --git a/src/mongo/db/query/sbe_stage_builder_coll_scan.h b/src/mongo/db/query/sbe_stage_builder_coll_scan.h
index b204c5487a8..3b6e1b861ae 100644
--- a/src/mongo/db/query/sbe_stage_builder_coll_scan.h
+++ b/src/mongo/db/query/sbe_stage_builder_coll_scan.h
@@ -41,7 +41,10 @@ namespace mongo::stage_builder {
class PlanStageSlots;
/**
- * Generates an SBE plan stage sub-tree implementing an collection scan.
+ * Generates an SBE plan stage sub-tree implementing a collection scan. 'fields' can be used to
+ * specify top-level fields that should be retrieved during the scan. For each name in 'fields',
+ * there will be a corresponding kField slot with the same name in the returned PlanStageSlots
+ * object.
*
* On success, a tuple containing the following data is returned:
* * A slot to access a fetched document (a resultSlot)
@@ -56,6 +59,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateCollScan(
StageBuilderState& state,
const CollectionPtr& collection,
const CollectionScanNode* csn,
+ const std::vector<std::string>& fields,
PlanYieldPolicy* yieldPolicy,
bool isTailableResumeBranch);
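Illustrative usage only (not part of this patch), assuming 'state', 'collection', 'csn' and 'yieldPolicy' are in scope:

    std::vector<std::string> fields = {"a", "b"};
    auto [scanStage, scanOutputs] = generateCollScan(
        state, collection, csn, fields, yieldPolicy, false /* isTailableResumeBranch */);
    // Each requested field has a kField slot of the same name in the returned PlanStageSlots.
    auto slotForA = scanOutputs.getIfExists(std::make_pair(PlanStageSlots::kField, fields[0]));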
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index d96f7ea6a92..c1566970b6d 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -76,8 +76,9 @@ struct ExpressionVisitorContext {
ExpressionVisitorContext(StageBuilderState& state,
EvalStage inputStage,
boost::optional<sbe::value::SlotId> optionalRootSlot,
- PlanNodeId planNodeId)
- : state(state), optionalRootSlot(optionalRootSlot), planNodeId(planNodeId) {
+ PlanNodeId planNodeId,
+ const PlanStageSlots* slots = nullptr)
+ : state(state), optionalRootSlot(optionalRootSlot), planNodeId(planNodeId), slots(slots) {
evalStack.emplaceFrame(std::move(inputStage));
}
@@ -142,6 +143,9 @@ struct ExpressionVisitorContext {
std::stack<VarsFrame> varsFrameStack;
// The id of the QuerySolutionNode to which the expression we are converting to SBE is attached.
const PlanNodeId planNodeId;
+
+ const PlanStageSlots* slots = nullptr;
+
// This stack contains slot id for the current element variable of $filter expression.
std::stack<sbe::value::SlotId> filterExprSlotIdStack;
// We use this counter to track which children of $filter we've already processed.
@@ -149,17 +153,23 @@ struct ExpressionVisitorContext {
};
std::unique_ptr<sbe::EExpression> generateTraverseHelper(
- const sbe::EVariable& inputVar,
+ std::unique_ptr<sbe::EExpression> inputExpr,
const FieldPath& fp,
size_t level,
- sbe::value::FrameIdGenerator* frameIdGenerator) {
+ sbe::value::FrameIdGenerator* frameIdGenerator,
+ boost::optional<sbe::value::SlotId> topLevelFieldSlot = boost::none) {
using namespace std::literals;
invariant(level < fp.getPathLength());
+ tassert(6023417,
+ "Expected an input expression or top level field",
+ inputExpr.get() || topLevelFieldSlot.has_value());
// Generate an expression to read a sub-field at the current nested level.
auto fieldName = sbe::makeE<sbe::EConstant>(fp.getFieldName(level));
- auto fieldExpr = makeFunction("getField"_sd, inputVar.clone(), std::move(fieldName));
+ auto fieldExpr = topLevelFieldSlot
+ ? makeVariable(*topLevelFieldSlot)
+ : makeFunction("getField"_sd, std::move(inputExpr), std::move(fieldName));
if (level == fp.getPathLength() - 1) {
// For the last level, we can just return the field slot without the need for a
@@ -171,7 +181,7 @@ std::unique_ptr<sbe::EExpression> generateTraverseHelper(
auto lambdaFrameId = frameIdGenerator->generate();
auto lambdaParam = sbe::EVariable{lambdaFrameId, 0};
- auto resultExpr = generateTraverseHelper(lambdaParam, fp, level + 1, frameIdGenerator);
+ auto resultExpr = generateTraverseHelper(lambdaParam.clone(), fp, level + 1, frameIdGenerator);
auto lambdaExpr = sbe::makeE<sbe::ELocalLambda>(lambdaFrameId, std::move(resultExpr));
@@ -186,28 +196,32 @@ std::unique_ptr<sbe::EExpression> generateTraverseHelper(
* For the given MatchExpression 'expr', generates a path traversal SBE plan stage sub-tree
* implementing the comparison expression.
*/
-std::unique_ptr<sbe::EExpression> generateTraverse(const sbe::EVariable& inputVar,
- bool expectsDocumentInputOnly,
- const FieldPath& fp,
- sbe::value::FrameIdGenerator* frameIdGenerator) {
+std::unique_ptr<sbe::EExpression> generateTraverse(
+ std::unique_ptr<sbe::EExpression> inputExpr,
+ bool expectsDocumentInputOnly,
+ const FieldPath& fp,
+ sbe::value::FrameIdGenerator* frameIdGenerator,
+ boost::optional<sbe::value::SlotId> topLevelFieldSlot = boost::none) {
size_t level = 0;
if (expectsDocumentInputOnly) {
- // When we know for sure that 'inputVar' will be a document and _not_ an array (such as
- // when traversing the root document), we can generate a simpler expression.
- return generateTraverseHelper(inputVar, fp, level, frameIdGenerator);
+ // When we know for sure that 'inputExpr' will be a document and _not_ an array (such as
+ // when accessing a field on the root document), we can generate a simpler expression.
+ return generateTraverseHelper(
+ std::move(inputExpr), fp, level, frameIdGenerator, topLevelFieldSlot);
} else {
- // The general case: the value in the 'inputVar' may be an array that will require
+ tassert(6023418, "Expected an input expression", inputExpr.get());
+ // The general case: the value in the 'inputExpr' may be an array that will require
// traversal.
auto lambdaFrameId = frameIdGenerator->generate();
auto lambdaParam = sbe::EVariable{lambdaFrameId, 0};
- auto resultExpr = generateTraverseHelper(lambdaParam, fp, level, frameIdGenerator);
+ auto resultExpr = generateTraverseHelper(lambdaParam.clone(), fp, level, frameIdGenerator);
auto lambdaExpr = sbe::makeE<sbe::ELocalLambda>(lambdaFrameId, std::move(resultExpr));
return makeFunction("traverseP",
- inputVar.clone(),
+ std::move(inputExpr),
std::move(lambdaExpr),
makeConstant(sbe::value::TypeTags::NumberInt32, 1));
}
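A rough illustration of what this produces (not taken from actual output): for a path such as "a.b" whose input is known to be a document, the generated EExpression reduces to nested getField() calls; in the general case each level is instead wrapped in a traverseP lambda. Assuming 'rootSlot' and 'frameIdGenerator' are in scope:

    // Sketch: roughly getField(getField(<rootSlot>, "a"), "b") when the input is document-only.
    auto expr = generateTraverse(sbe::makeE<sbe::EVariable>(rootSlot),
                                 true /* expectsDocumentInputOnly */,
                                 FieldPath("a.b"),
                                 frameIdGenerator);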
@@ -1951,28 +1965,44 @@ public:
void visit(const ExpressionFieldPath* expr) final {
// There's a chance that we've already generated a SBE plan stage tree for this field path,
// in which case we avoid regeneration of the same plan stage tree.
- if (auto it = _context->state.preGeneratedExprs.find(expr->getFieldPath().fullPath());
- it != _context->state.preGeneratedExprs.end()) {
- tassert(6089301,
- "Expressions for top-level document or a variable must not be pre-generated",
- expr->getFieldPath().getPathLength() != 1 && !expr->isVariableReference());
- if (auto optionalSlot = it->second.getSlot(); optionalSlot) {
- _context->pushExpr(*optionalSlot);
- } else {
- auto preGeneratedExpr = it->second.extractExpr();
- _context->pushExpr(preGeneratedExpr->clone());
- it->second = std::move(preGeneratedExpr);
+ if (!_context->state.preGeneratedExprs.empty()) {
+ if (auto it = _context->state.preGeneratedExprs.find(expr->getFieldPath().fullPath());
+ it != _context->state.preGeneratedExprs.end()) {
+ tassert(6089301,
+ "Expressions for top-level documents / variables must not be pre-generated",
+ expr->getFieldPath().getPathLength() != 1 && !expr->isVariableReference());
+ if (auto optionalSlot = it->second.getSlot(); optionalSlot) {
+ _context->pushExpr(*optionalSlot);
+ } else {
+ auto preGeneratedExpr = it->second.extractExpr();
+ _context->pushExpr(preGeneratedExpr->clone());
+ it->second = std::move(preGeneratedExpr);
+ }
+ return;
}
- return;
}
- tassert(6075901, "Must have a valid root slot", _context->optionalRootSlot.has_value());
+ boost::optional<sbe::value::SlotId> slotId;
+ boost::optional<sbe::value::SlotId> topLevelFieldSlot;
+ boost::optional<FieldPath> fp;
+ bool expectsDocumentInputOnly = false;
- sbe::value::SlotId slotId;
+ if (expr->getFieldPath().getPathLength() > 1) {
+ fp = expr->getFieldPathWithoutCurrentPrefix();
+ }
+
+ if (expr->getVariableId() == Variables::kRootId &&
+ expr->getFieldPath().getPathLength() > 1) {
+ slotId = _context->optionalRootSlot;
+ expectsDocumentInputOnly = true;
- if (!Variables::isUserDefinedVariable(expr->getVariableId())) {
+ if (_context->slots) {
+ auto topLevelField = std::make_pair(PlanStageSlots::kField, fp->front());
+ topLevelFieldSlot = _context->slots->getIfExists(topLevelField);
+ }
+ } else if (!Variables::isUserDefinedVariable(expr->getVariableId())) {
if (expr->getVariableId() == Variables::kRootId) {
- slotId = *(_context->optionalRootSlot);
+ slotId = _context->optionalRootSlot;
} else if (expr->getVariableId() == Variables::kRemoveId) {
// For the field paths that begin with "$$REMOVE", we always produce Nothing,
// so no traversal is necessary.
@@ -1990,7 +2020,7 @@ public:
<< "Builtin variable '$$" << it->second << "' is not available",
variableSlot.has_value());
- slotId = *variableSlot;
+ slotId = variableSlot;
}
} else {
auto it = _context->environment.find(expr->getVariableId());
@@ -2002,19 +2032,27 @@ public:
}
if (expr->getFieldPath().getPathLength() == 1) {
+ tassert(6023419, "Must have a valid slot", slotId.has_value());
+
// A solo variable reference (e.g.: "$$ROOT" or "$$myvar") that doesn't need any
// traversal.
- _context->pushExpr(slotId);
+ _context->pushExpr(*slotId);
return;
}
- // Dereference a dotted path, which may contain arrays requiring implicit traversal.
- const bool expectsDocumentInputOnly = slotId == *(_context->optionalRootSlot);
+ tassert(6023420,
+ "Must have a valid root slot or field slot",
+ slotId.has_value() || topLevelFieldSlot.has_value());
+
+ auto inputExpr =
+ slotId ? sbe::makeE<sbe::EVariable>(*slotId) : std::unique_ptr<sbe::EExpression>{};
- auto resultExpr = generateTraverse(sbe::EVariable{slotId},
+ // Dereference a dotted path, which may contain arrays requiring implicit traversal.
+ auto resultExpr = generateTraverse(std::move(inputExpr),
expectsDocumentInputOnly,
- expr->getFieldPathWithoutCurrentPrefix(),
- _context->state.frameIdGenerator);
+ *fp,
+ _context->state.frameIdGenerator,
+ topLevelFieldSlot);
_context->pushExpr(std::move(resultExpr));
}
@@ -4126,8 +4164,9 @@ EvalExprStagePair generateExpression(StageBuilderState& state,
const Expression* expr,
EvalStage stage,
boost::optional<sbe::value::SlotId> optionalRootSlot,
- PlanNodeId planNodeId) {
- ExpressionVisitorContext context(state, std::move(stage), optionalRootSlot, planNodeId);
+ PlanNodeId planNodeId,
+ const PlanStageSlots* slots) {
+ ExpressionVisitorContext context(state, std::move(stage), optionalRootSlot, planNodeId, slots);
ExpressionPreVisitor preVisitor{&context};
ExpressionInVisitor inVisitor{&context};
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.h b/src/mongo/db/query/sbe_stage_builder_expression.h
index 3d148113f89..78b4774f807 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.h
+++ b/src/mongo/db/query/sbe_stage_builder_expression.h
@@ -38,6 +38,8 @@
#include "mongo/db/query/sbe_stage_builder_helpers.h"
namespace mongo::stage_builder {
+class PlanStageSlots;
+
/**
* Translates an input Expression into an SBE EExpression. The 'stage' parameter provides the input
* subtree to build on top of.
@@ -46,7 +48,8 @@ EvalExprStagePair generateExpression(StageBuilderState& state,
const Expression* expr,
EvalStage stage,
boost::optional<sbe::value::SlotId> optionalRootSlot,
- PlanNodeId planNodeId);
+ PlanNodeId planNodeId,
+ const PlanStageSlots* slots = nullptr);
/**
* Generate an EExpression that converts a value (contained in a variable bound to 'branchRef') that
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
index affc05691b1..c8acb608a42 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
@@ -958,27 +958,25 @@ bool indexKeyConsistencyCheckCallback(OperationContext* opCtx,
return true;
}
-std::tuple<sbe::value::SlotId, sbe::value::SlotId, std::unique_ptr<sbe::PlanStage>>
-makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage,
- sbe::value::SlotId seekKeySlot,
- sbe::value::SlotId snapshotIdSlot,
- sbe::value::SlotId indexIdSlot,
- sbe::value::SlotId indexKeySlot,
- sbe::value::SlotId indexKeyPatternSlot,
- const CollectionPtr& collToFetch,
- StringMap<const IndexAccessMethod*> iamMap,
- PlanNodeId planNodeId,
- sbe::value::SlotVector slotsToForward,
- sbe::value::SlotIdGenerator& slotIdGenerator) {
+std::unique_ptr<sbe::PlanStage> makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage,
+ sbe::value::SlotId resultSlot,
+ sbe::value::SlotId recordIdSlot,
+ std::vector<std::string> fields,
+ sbe::value::SlotVector fieldSlots,
+ sbe::value::SlotId seekKeySlot,
+ sbe::value::SlotId snapshotIdSlot,
+ sbe::value::SlotId indexIdSlot,
+ sbe::value::SlotId indexKeySlot,
+ sbe::value::SlotId indexKeyPatternSlot,
+ const CollectionPtr& collToFetch,
+ StringMap<const IndexAccessMethod*> iamMap,
+ PlanNodeId planNodeId,
+ sbe::value::SlotVector slotsToForward) {
// It is assumed that we are generating a fetch loop join over the main collection. If we are
// generating a fetch over a secondary collection, it is the responsibility of a parent node
// in the QSN tree to indicate which collection we are fetching over.
tassert(6355301, "Cannot fetch from a collection that doesn't exist", collToFetch);
- auto resultSlot = slotIdGenerator.generate();
- auto recordIdSlot = slotIdGenerator.generate();
-
- using namespace std::placeholders;
sbe::ScanCallbacks callbacks(
indexKeyCorruptionCheckCallback,
[=](auto&& arg1, auto&& arg2, auto&& arg3, auto&& arg4, auto&& arg5, auto&& arg6) {
@@ -995,8 +993,8 @@ makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage,
indexKeySlot,
indexKeyPatternSlot,
boost::none,
- std::vector<std::string>{},
- sbe::makeSV(),
+ std::move(fields),
+ std::move(fieldSlots),
seekKeySlot,
true,
nullptr,
@@ -1005,15 +1003,13 @@ makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage,
// Get the recordIdSlot from the outer side (e.g., IXSCAN) and feed it to the inner side,
// limiting the result set to 1 row.
- auto stage = sbe::makeS<sbe::LoopJoinStage>(
+ return sbe::makeS<sbe::LoopJoinStage>(
std::move(inputStage),
sbe::makeS<sbe::LimitSkipStage>(std::move(scanStage), 1, boost::none, planNodeId),
std::move(slotsToForward),
sbe::makeSV(seekKeySlot, snapshotIdSlot, indexIdSlot, indexKeySlot, indexKeyPatternSlot),
nullptr,
planNodeId);
-
- return {resultSlot, recordIdSlot, std::move(stage)};
}
sbe::value::SlotId StageBuilderState::registerInputParamSlot(
@@ -1121,4 +1117,172 @@ std::unique_ptr<sbe::PlanStage> rehydrateIndexKey(std::unique_ptr<sbe::PlanStage
return sbe::makeProjectStage(std::move(stage), nodeId, resultSlot, std::move(keyExpr));
}
+/**
+ * For covered projections, each of the projection field paths represents a respective index key.
+ * To rehydrate index keys into the result object, we first need to convert the projection AST
+ * into an 'IndexKeyPatternTreeNode' structure. The context structure and visitors below are used
+ * for this purpose.
+ */
+struct IndexKeysBuilderContext {
+ // Contains resulting tree of index keys converted from projection AST.
+ IndexKeyPatternTreeNode root;
+
+ // Full field path of the currently visited projection node.
+ std::vector<StringData> currentFieldPath;
+
+ // Each projection node has a vector of field names. This stack contains indexes of the
+ // currently visited field names for each of the projection nodes.
+ std::vector<size_t> currentFieldIndex;
+};
+
+/**
+ * Covered projections are always inclusion-only, so we ban all other operators.
+ */
+class IndexKeysBuilder : public projection_ast::ProjectionASTConstVisitor {
+public:
+ using projection_ast::ProjectionASTConstVisitor::visit;
+
+ IndexKeysBuilder(IndexKeysBuilderContext* context) : _context{context} {}
+
+ void visit(const projection_ast::ProjectionPositionalASTNode* node) final {
+ tasserted(5474501, "Positional projection is not allowed in simple or covered projections");
+ }
+
+ void visit(const projection_ast::ProjectionSliceASTNode* node) final {
+ tasserted(5474502, "$slice is not allowed in simple or covered projections");
+ }
+
+ void visit(const projection_ast::ProjectionElemMatchASTNode* node) final {
+ tasserted(5474503, "$elemMatch is not allowed in simple or covered projections");
+ }
+
+ void visit(const projection_ast::ExpressionASTNode* node) final {
+ tasserted(5474504, "Expressions are not allowed in simple or covered projections");
+ }
+
+ void visit(const projection_ast::MatchExpressionASTNode* node) final {
+ tasserted(
+ 5474505,
+ "$elemMatch / positional projection are not allowed in simple or covered projections");
+ }
+
+ void visit(const projection_ast::BooleanConstantASTNode* node) override {}
+
+protected:
+ IndexKeysBuilderContext* _context;
+};
+
+class IndexKeysPreBuilder final : public IndexKeysBuilder {
+public:
+ using IndexKeysBuilder::IndexKeysBuilder;
+ using IndexKeysBuilder::visit;
+
+ void visit(const projection_ast::ProjectionPathASTNode* node) final {
+ _context->currentFieldIndex.push_back(0);
+ _context->currentFieldPath.emplace_back(node->fieldNames().front());
+ }
+};
+
+class IndexKeysInBuilder final : public IndexKeysBuilder {
+public:
+ using IndexKeysBuilder::IndexKeysBuilder;
+ using IndexKeysBuilder::visit;
+
+ void visit(const projection_ast::ProjectionPathASTNode* node) final {
+ auto& currentIndex = _context->currentFieldIndex.back();
+ currentIndex++;
+ _context->currentFieldPath.back() = node->fieldNames()[currentIndex];
+ }
+};
+
+class IndexKeysPostBuilder final : public IndexKeysBuilder {
+public:
+ using IndexKeysBuilder::IndexKeysBuilder;
+ using IndexKeysBuilder::visit;
+
+ void visit(const projection_ast::ProjectionPathASTNode* node) final {
+ _context->currentFieldIndex.pop_back();
+ _context->currentFieldPath.pop_back();
+ }
+
+ void visit(const projection_ast::BooleanConstantASTNode* constantNode) final {
+ if (!constantNode->value()) {
+            // Even though only inclusions are allowed in covered projections, there can still be
+            // an {_id: 0} component. We do not need to generate any nodes for it.
+ return;
+ }
+
+ // Insert current field path into the index keys tree if it does not exist yet.
+ auto* node = &_context->root;
+ for (const auto& part : _context->currentFieldPath) {
+ if (auto it = node->children.find(part); it != node->children.end()) {
+ node = it->second.get();
+ } else {
+ node = node->emplace(part);
+ }
+ }
+ }
+};
+
+IndexKeyPatternTreeNode buildPatternTree(const projection_ast::Projection& projection) {
+ IndexKeysBuilderContext context;
+ IndexKeysPreBuilder preVisitor{&context};
+ IndexKeysInBuilder inVisitor{&context};
+ IndexKeysPostBuilder postVisitor{&context};
+
+ projection_ast::ProjectionASTConstWalker walker{&preVisitor, &inVisitor, &postVisitor};
+
+ tree_walker::walk<true, projection_ast::ASTNode>(projection.root(), &walker);
+
+ return std::move(context.root);
+}
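As a hedged example, a covered projection such as {"a.b": 1, "a.c": 1, _id: 0} would yield a tree whose root has a single child "a" with children "b" and "c" (assuming a parsed 'projection' is in scope):

    IndexKeyPatternTreeNode keyTree = buildPatternTree(projection);
    // One top-level child ("a") with two nested children ("b" and "c").
    size_t numTopLevelKeys = keyTree.children.size();                          // 1
    size_t numNestedKeys = keyTree.children.begin()->second->children.size();  // 2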
+
+std::pair<std::unique_ptr<sbe::PlanStage>, sbe::value::SlotVector> projectTopLevelFields(
+ std::unique_ptr<sbe::PlanStage> stage,
+ const std::vector<std::string>& fields,
+ sbe::value::SlotId resultSlot,
+ PlanNodeId nodeId,
+ sbe::value::SlotIdGenerator* slotIdGenerator) {
+ // 'outputSlots' will match the order of 'fields'.
+ sbe::value::SlotVector outputSlots;
+ outputSlots.reserve(fields.size());
+
+ sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projects;
+ for (size_t i = 0; i < fields.size(); ++i) {
+ const auto& field = fields[i];
+ auto slot = slotIdGenerator->generate();
+ auto getFieldExpr =
+ makeFunction("getField"_sd, makeVariable(resultSlot), makeConstant(field));
+ projects.insert({slot, std::move(getFieldExpr)});
+ outputSlots.emplace_back(slot);
+ }
+
+ if (!projects.empty()) {
+ stage = sbe::makeS<sbe::ProjectStage>(std::move(stage), std::move(projects), nodeId);
+ }
+
+ return {std::move(stage), std::move(outputSlots)};
+}
+
+std::pair<std::unique_ptr<sbe::PlanStage>, sbe::value::SlotVector> projectNothingToSlots(
+ std::unique_ptr<sbe::PlanStage> stage,
+ size_t n,
+ PlanNodeId nodeId,
+ sbe::value::SlotIdGenerator* slotIdGenerator) {
+ if (n == 0) {
+ return {std::move(stage), sbe::value::SlotVector{}};
+ }
+
+ auto outputSlots = slotIdGenerator->generateMultiple(n);
+
+ sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projects;
+ for (size_t i = 0; i < n; ++i) {
+ projects.insert(
+ {outputSlots[i], sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Nothing, 0)});
+ }
+
+ stage = sbe::makeS<sbe::ProjectStage>(std::move(stage), std::move(projects), nodeId);
+
+ return {std::move(stage), std::move(outputSlots)};
+}
} // namespace mongo::stage_builder
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h
index f3b5d9ca455..634005c504c 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.h
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.h
@@ -44,6 +44,10 @@
#include "mongo/db/query/sbe_stage_builder_eval_frame.h"
#include "mongo/db/query/stage_types.h"
+namespace mongo::projection_ast {
+class Projection;
+}
+
namespace mongo::stage_builder {
std::unique_ptr<sbe::EExpression> makeUnaryOp(sbe::EPrimUnary::Op unaryOp,
@@ -524,19 +528,20 @@ std::unique_ptr<sbe::EExpression> makeLocalBind(sbe::value::FrameIdGenerator* fr
return sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(innerExpr));
}
-
-std::tuple<sbe::value::SlotId, sbe::value::SlotId, std::unique_ptr<sbe::PlanStage>>
-makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage,
- sbe::value::SlotId seekKeySlot,
- sbe::value::SlotId snapshotIdSlot,
- sbe::value::SlotId indexIdSlot,
- sbe::value::SlotId indexKeySlot,
- sbe::value::SlotId indexKeyPatternSlot,
- const CollectionPtr& collToFetch,
- StringMap<const IndexAccessMethod*> iamMap,
- PlanNodeId planNodeId,
- sbe::value::SlotVector slotsToForward,
- sbe::value::SlotIdGenerator& slotIdGenerator);
+std::unique_ptr<sbe::PlanStage> makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage,
+ sbe::value::SlotId resultSlot,
+ sbe::value::SlotId recordIdSlot,
+ std::vector<std::string> fields,
+ sbe::value::SlotVector fieldSlots,
+ sbe::value::SlotId seekKeySlot,
+ sbe::value::SlotId snapshotIdSlot,
+ sbe::value::SlotId indexIdSlot,
+ sbe::value::SlotId indexKeySlot,
+ sbe::value::SlotId indexKeyPatternSlot,
+ const CollectionPtr& collToFetch,
+ StringMap<const IndexAccessMethod*> iamMap,
+ PlanNodeId planNodeId,
+ sbe::value::SlotVector slotsToForward);
/**
* Trees generated by 'generateFilter' maintain state during execution. There are two types of state
@@ -1010,4 +1015,60 @@ std::unique_ptr<sbe::PlanStage> rehydrateIndexKey(std::unique_ptr<sbe::PlanStage
const sbe::value::SlotVector& indexKeySlots,
sbe::value::SlotId resultSlot);
+IndexKeyPatternTreeNode buildPatternTree(const projection_ast::Projection& projection);
+
+/**
+ * This method retrieves the values of the specified top-level fields ('fields') from 'resultSlot'
+ * and stores the values into slots.
+ *
+ * This method returns a pair containing: (1) the updated SBE plan stage tree; and (2) a vector of
+ * the slots ('slots') containing the field values.
+ *
+ * The order of slots in 'slots' will match the order of fields in 'fields'.
+ */
+std::pair<std::unique_ptr<sbe::PlanStage>, sbe::value::SlotVector> projectTopLevelFields(
+ std::unique_ptr<sbe::PlanStage> stage,
+ const std::vector<std::string>& fields,
+ sbe::value::SlotId resultSlot,
+ PlanNodeId nodeId,
+ sbe::value::SlotIdGenerator* slotIdGenerator);
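A minimal usage sketch (not from this patch), assuming 'stage', 'resultSlot', 'nodeId' and 'slotIdGenerator' are in scope:

    std::vector<std::string> fields = {"a", "b", "c"};
    auto [newStage, fieldSlots] =
        projectTopLevelFields(std::move(stage), fields, resultSlot, nodeId, slotIdGenerator);
    // fieldSlots[i] now holds getField(result, fields[i]) for each i, in the same order.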
+
+/**
+ * This method projects the constant value Nothing to multiple slots (the number of slots is
+ * specified by the parameter 'n').
+ *
+ * This method returns a pair containing: (1) the updated SBE plan stage tree; and (2) a vector of
+ * slots ('slots') containing Nothing.
+ *
+ * The number of slots in 'slots' will always be equal to parameter 'n'.
+ */
+std::pair<std::unique_ptr<sbe::PlanStage>, sbe::value::SlotVector> projectNothingToSlots(
+ std::unique_ptr<sbe::PlanStage> stage,
+ size_t n,
+ PlanNodeId nodeId,
+ sbe::value::SlotIdGenerator* slotIdGenerator);
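Similarly, a hedged sketch of projectNothingToSlots(), e.g. for padding a union branch so both children expose the same number of slots:

    auto [paddedStage, nothingSlots] =
        projectNothingToSlots(std::move(stage), 3 /* n */, nodeId, slotIdGenerator);
    // 'nothingSlots' contains exactly three slots, each bound to the constant value Nothing.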
+
+template <typename T, typename FuncT>
+std::vector<T> filterVector(std::vector<T> vec, FuncT fn) {
+ std::vector<T> result;
+ std::copy_if(std::make_move_iterator(vec.begin()),
+ std::make_move_iterator(vec.end()),
+ std::back_inserter(result),
+ fn);
+ return result;
+}
+
+template <typename T, typename FuncT>
+std::pair<std::vector<T>, std::vector<T>> splitVector(std::vector<T> vec, FuncT fn) {
+ std::pair<std::vector<T>, std::vector<T>> result;
+ for (size_t i = 0; i < vec.size(); ++i) {
+ if (fn(vec[i])) {
+ result.first.emplace_back(std::move(vec[i]));
+ } else {
+ result.second.emplace_back(std::move(vec[i]));
+ }
+ }
+ return result;
+}
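For illustration, both helpers operate on plain std::vector values; a small self-contained example (inside the mongo::stage_builder namespace):

    std::vector<int> nums = {1, 2, 3, 4, 5};
    auto evens = filterVector(nums, [](int n) { return n % 2 == 0; });             // {2, 4}
    auto [low, high] = splitVector(std::move(nums), [](int n) { return n < 4; });  // {1, 2, 3} and {4, 5}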
+
} // namespace mongo::stage_builder
diff --git a/src/mongo/db/query/sbe_stage_builder_index_scan.cpp b/src/mongo/db/query/sbe_stage_builder_index_scan.cpp
index 51d37ceccf4..bfc52d055c5 100644
--- a/src/mongo/db/query/sbe_stage_builder_index_scan.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_index_scan.cpp
@@ -696,7 +696,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScan(
PlanYieldPolicy* yieldPolicy,
StringMap<const IndexAccessMethod*>* iamMap,
bool needsCorruptionCheck) {
-
auto indexName = ixn->index.identifier.catalogName;
auto descriptor = collection->getIndexCatalog()->findIndexByName(state.opCtx, indexName);
tassert(5483200,
@@ -855,8 +854,16 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScan(
stage = outputStage.extractStage(ixn->nodeId());
}
- outputs.setIndexKeySlots(makeIndexKeyOutputSlotsMatchingParentReqs(
- ixn->index.keyPattern, originalIndexKeyBitset, indexKeyBitset, indexKeySlots));
+ size_t i = 0;
+ size_t slotIdx = 0;
+ for (const auto& elt : ixn->index.keyPattern) {
+ StringData name = elt.fieldNameStringData();
+ if (indexKeyBitset.test(i)) {
+ outputs.set(std::make_pair(PlanStageSlots::kKey, name), indexKeySlots[slotIdx]);
+ ++slotIdx;
+ }
+ ++i;
+ }
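+    // Worked illustration with hypothetical values: for keyPattern {a: 1, b: 1, c: 1} with
+    // 'indexKeyBitset' bits set for "a" and "c", the loop above maps (kKey, "a") to
+    // indexKeySlots[0] and (kKey, "c") to indexKeySlots[1]; "b" gets no entry because its bit
+    // is unset.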
return {std::move(stage), std::move(outputs)};
}
@@ -981,8 +988,9 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScanWith
// Whenever possible we should prefer building simplified single interval index scan plans in
// order to get the best performance.
if (canGenerateSingleIntervalIndexScan(ixn->iets)) {
- auto makeSlot = [&](const bool cond,
- const StringData slotKey) -> boost::optional<sbe::value::SlotId> {
+ auto makeSlot =
+ [&](const bool cond,
+ const PlanStageSlots::Name slotKey) -> boost::optional<sbe::value::SlotId> {
if (!cond)
return boost::none;
@@ -1024,10 +1032,9 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScanWith
auto optimizedIndexScanSlots = optimizedIndexKeySlots;
auto branchOutputSlots = outputIndexKeySlots;
- auto makeSlotsForThenElseBranches =
- [&](const bool cond,
- const StringData slotKey) -> std::tuple<boost::optional<sbe::value::SlotId>,
- boost::optional<sbe::value::SlotId>> {
+ auto makeSlotsForThenElseBranches = [&](const bool cond, const PlanStageSlots::Name slotKey)
+ -> std::tuple<boost::optional<sbe::value::SlotId>,
+ boost::optional<sbe::value::SlotId>> {
if (!cond)
return {boost::none, boost::none};
@@ -1145,9 +1152,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScanWith
stage = outputStage.extractStage(ixn->nodeId());
}
- outputs.setIndexKeySlots(makeIndexKeyOutputSlotsMatchingParentReqs(
- ixn->index.keyPattern, originalIndexKeyBitset, indexKeyBitset, outputIndexKeySlots));
-
state.data->indexBoundsEvaluationInfos.emplace_back(
IndexBoundsEvaluationInfo{ixn->index,
accessMethod->getSortedDataInterface()->getKeyStringVersion(),
@@ -1156,6 +1160,17 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScanWith
std::move(ixn->iets),
std::move(parameterizedScanSlots)});
+ size_t i = 0;
+ size_t slotIdx = 0;
+ for (const auto& elt : ixn->index.keyPattern) {
+ StringData name = elt.fieldNameStringData();
+ if (indexKeyBitset.test(i)) {
+ outputs.set(std::make_pair(PlanStageSlots::kKey, name), outputIndexKeySlots[slotIdx]);
+ ++slotIdx;
+ }
+ ++i;
+ }
+
return {std::move(stage), std::move(outputs)};
}
} // namespace mongo::stage_builder
diff --git a/src/mongo/db/query/sbe_stage_builder_index_scan.h b/src/mongo/db/query/sbe_stage_builder_index_scan.h
index 29949cc1169..cded81563ac 100644
--- a/src/mongo/db/query/sbe_stage_builder_index_scan.h
+++ b/src/mongo/db/query/sbe_stage_builder_index_scan.h
@@ -48,14 +48,10 @@ using IndexIntervals =
std::vector<std::pair<std::unique_ptr<KeyString::Value>, std::unique_ptr<KeyString::Value>>>;
/**
- * This method generates an SBE plan stage tree implementing an index scan. It returns a tuple
- * containing: (1) a slot produced by the index scan that holds the record ID ('recordIdSlot');
- * (2) a slot vector produced by the index scan which hold parts of the index key ('indexKeySlots');
- * and (3) the SBE plan stage tree. 'indexKeySlots' will only contain slots for the parts of the
- * index key specified by the 'indexKeysToInclude' bitset.
- *
- * If the caller provides a slot ID for the 'returnKeySlot' parameter, this method will populate
- * the specified slot with the rehydrated index key for each record.
+ * This method returns a pair containing: (1) an SBE plan stage tree implementing an index scan;
+ * and (2) a PlanStageSlots object containing a kRecordId slot, possibly some other kMeta slots,
+ * and slots produced by the index scan that correspond to parts of the index key specified by
+ * the 'indexKeyBitset' bitset.
*/
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScan(
StageBuilderState& state,
@@ -87,40 +83,12 @@ std::pair<sbe::value::TypeTags, sbe::value::Value> packIndexIntervalsInSbeArray(
/**
* Constructs a generic multi-interval index scan. Depending on the intervals will either execute
- * the optimized or the generic index scan subplan. The generated subtree will have
- * the following form:
+ * the optimized or the generic index scan subplan.
*
- * branch {isGenericScanSlot} [recordIdSlot, resultSlot, ...]
- * then
- * filter {isRecordId(resultSlot)}
- * lspool sp1 [resultSlot] {!isRecordId(resultSlot)}
- * union [resultSlot]
- project [startKeySlot = anchorSlot, unusedVarSlot0 = Nothing, ...]
- * limit 1
- * coscan
- * [checkBoundsSlot]
- * nlj [] [seekKeySlot]
- * left
- * sspool sp1 [seekKeySlot]
- * right
- * chkbounds resultSlot recordIdSlot checkBoundsSlot
- * nlj [] [lowKeySlot]
- * left
- * project [lowKeySlot = seekKeySlot]
- * limit 1
- * coscan
- * right
- * ixseek lowKeySlot resultSlot recordIdSlot [] @coll @index
- * else
- * nlj [] [lowKeySlot, highKeySlot]
- * left
- * project [lowKeySlot = getField (unwindSlot, "l"),
- * highKeySlot = getField (unwindSlot, "h")]
- * unwind unwindSlot indexSlot boundsSlot false
- * limit 1
- * coscan
- * right
- * ixseek lowKeySlot highKeySlot recordIdSlot [] @coll @index
+ * This method returns a pair containing: (1) an SBE plan stage tree implementing a generic multi-
+ * interval index scan; and (2) a PlanStageSlots object containing a kRecordId slot, possibly some
+ * other kMeta slots, and slots produced by the index scan that correspond to parts of the index
+ * key specified by the 'indexKeyBitset' bitset.
*/
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScanWithDynamicBounds(
StageBuilderState& state,
diff --git a/src/mongo/db/query/sbe_stage_builder_lookup.cpp b/src/mongo/db/query/sbe_stage_builder_lookup.cpp
index 775090c9cdd..b4c431f26f5 100644
--- a/src/mongo/db/query/sbe_stage_builder_lookup.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_lookup.cpp
@@ -920,17 +920,21 @@ std::pair<SlotId, std::unique_ptr<sbe::PlanStage>> buildIndexJoinLookupStage(
// stored in 'foreignRecordSlot'. We also pass in 'snapshotIdSlot', 'indexIdSlot',
// 'indexKeySlot' and 'indexKeyPatternSlot' to perform index consistency check during the
// seek.
- auto [foreignRecordSlot, __, scanNljStage] = makeLoopJoinForFetch(std::move(ixScanNljStage),
- foreignRecordIdSlot,
- snapshotIdSlot,
- indexIdSlot,
- indexKeySlot,
- indexKeyPatternSlot,
- foreignColl,
- iamMap,
- nodeId,
- makeSV() /* slotsToForward */,
- slotIdGenerator);
+ auto foreignRecordSlot = slotIdGenerator.generate();
+ auto scanNljStage = makeLoopJoinForFetch(std::move(ixScanNljStage),
+ foreignRecordSlot,
+ slotIdGenerator.generate() /* unused recordId slot */,
+ std::vector<std::string>{},
+ makeSV(),
+ foreignRecordIdSlot,
+ snapshotIdSlot,
+ indexIdSlot,
+ indexKeySlot,
+ indexKeyPatternSlot,
+ foreignColl,
+ iamMap,
+ nodeId,
+ makeSV() /* slotsToForward */);
// 'buildForeignMatches()' filters the foreign records, returned by the index scan, to match
// those in 'localKeysSetSlot'. This is necessary because some values are encoded with the same
diff --git a/src/mongo/db/repl/SConscript b/src/mongo/db/repl/SConscript
index dd0aac54126..ad8be47af0a 100644
--- a/src/mongo/db/repl/SConscript
+++ b/src/mongo/db/repl/SConscript
@@ -1488,6 +1488,7 @@ env.Library(
],
LIBDEPS=[
'$BUILD_DIR/mongo/client/fetcher',
+ '$BUILD_DIR/mongo/executor/async_rpc',
'primary_only_service',
'tenant_migration_access_blocker',
'tenant_migration_statistics',
@@ -1717,6 +1718,7 @@ if wiredtiger:
'tenant_migration_recipient_access_blocker_test.cpp',
'tenant_migration_recipient_entry_helpers_test.cpp',
'tenant_migration_recipient_service_test.cpp',
+ 'tenant_migration_recipient_service_shard_merge_test.cpp',
],
LIBDEPS=[
'$BUILD_DIR/mongo/bson/mutable/mutable_bson',
diff --git a/src/mongo/db/repl/dbcheck.cpp b/src/mongo/db/repl/dbcheck.cpp
index aa295ce9ca7..0b0be849a5d 100644
--- a/src/mongo/db/repl/dbcheck.cpp
+++ b/src/mongo/db/repl/dbcheck.cpp
@@ -364,14 +364,26 @@ Status dbCheckBatchOnSecondary(OperationContext* opCtx,
// Set up the hasher,
boost::optional<DbCheckHasher> hasher;
try {
- auto lockMode = MODE_S;
- if (entry.getReadTimestamp()) {
- lockMode = MODE_IS;
- opCtx->recoveryUnit()->setTimestampReadSource(RecoveryUnit::ReadSource::kProvided,
- entry.getReadTimestamp());
+ // We may not have a read timestamp if the dbCheck command was run on an older version of
+ // the server with snapshotRead:false. Since we don't implement this feature, we'll log an
+ // error about skipping the batch to ensure an operator notices.
+ if (!entry.getReadTimestamp().has_value()) {
+ auto logEntry =
+ dbCheckErrorHealthLogEntry(entry.getNss(),
+ "dbCheck failed",
+ OplogEntriesEnum::Batch,
+ Status{ErrorCodes::Error(6769502),
+ "no readTimestamp in oplog entry. Ensure dbCheck "
+ "command is not using snapshotRead:false"},
+ entry.toBSON());
+ HealthLog::get(opCtx).log(*logEntry);
+ return Status::OK();
}
- AutoGetCollection coll(opCtx, entry.getNss(), lockMode);
+ opCtx->recoveryUnit()->setTimestampReadSource(RecoveryUnit::ReadSource::kProvided,
+ entry.getReadTimestamp());
+
+ AutoGetCollection coll(opCtx, entry.getNss(), MODE_IS);
const auto& collection = coll.getCollection();
if (!collection) {
@@ -434,12 +446,11 @@ Status dbCheckOplogCommand(OperationContext* opCtx,
if (!opCtx->writesAreReplicated()) {
opTime = entry.getOpTime();
}
- auto type = OplogEntries_parse(IDLParserContext("type"), cmd.getStringField("type"));
- IDLParserContext ctx("o");
-
+ const auto type = OplogEntries_parse(IDLParserContext("type"), cmd.getStringField("type"));
+ const IDLParserContext ctx("o", false /*apiStrict*/, entry.getTid());
switch (type) {
case OplogEntriesEnum::Batch: {
- auto invocation = DbCheckOplogBatch::parse(ctx, cmd);
+ const auto invocation = DbCheckOplogBatch::parse(ctx, cmd);
return dbCheckBatchOnSecondary(opCtx, opTime, invocation);
}
case OplogEntriesEnum::Collection: {
diff --git a/src/mongo/db/repl/dbcheck.idl b/src/mongo/db/repl/dbcheck.idl
index 9d3b20b68f5..5ded9ce6daf 100644
--- a/src/mongo/db/repl/dbcheck.idl
+++ b/src/mongo/db/repl/dbcheck.idl
@@ -38,42 +38,6 @@ imports:
- "mongo/db/write_concern_options.idl"
server_parameters:
- dbCheckCollectionTryLockTimeoutMillis:
- description: 'Timeout to acquire the collection for processing a dbCheck batch. Each subsequent attempt doubles the timeout'
- set_at: [ startup, runtime ]
- cpp_vartype: 'AtomicWord<int>'
- cpp_varname: gDbCheckCollectionTryLockTimeoutMillis
- default: 10
- validator:
- gte: 1
- lte: 10000
- dbCheckCollectionTryLockMaxAttempts:
- description: 'Maximum number of attempts with backoff to acquire the collection lock for processing a dbCheck batch'
- set_at: [ startup, runtime ]
- cpp_vartype: 'AtomicWord<int>'
- cpp_varname: gDbCheckCollectionTryLockMaxAttempts
- default: 5
- validator:
- gte: 1
- lte: 20
- dbCheckCollectionTryLockMinBackoffMillis:
- description: 'Initial backoff on failure to acquire the collection lock for processing a dbCheck batch. Grows exponentially'
- set_at: [ startup, runtime ]
- cpp_vartype: 'AtomicWord<int>'
- cpp_varname: gDbCheckCollectionTryLockMinBackoffMillis
- default: 10
- validator:
- gte: 2
- lte: 60000
- dbCheckCollectionTryLockMaxBackoffMillis:
- description: 'Maximum exponential backoff on failure to acquire the collection lock for processing a dbCheck batch.'
- set_at: [ startup, runtime ]
- cpp_vartype: 'AtomicWord<int>'
- cpp_varname: gDbCheckCollectionTryLockMaxBackoffMillis
- default: 60000
- validator:
- gte: 20
- lte: 120000
dbCheckHealthLogEveryNBatches:
description: 'Emit an info-severity health log batch every N batches processed'
set_at: [ startup, runtime ]
diff --git a/src/mongo/db/repl/noop_writer.cpp b/src/mongo/db/repl/noop_writer.cpp
index 0b94f0dc60a..c34fc425030 100644
--- a/src/mongo/db/repl/noop_writer.cpp
+++ b/src/mongo/db/repl/noop_writer.cpp
@@ -137,9 +137,9 @@ Status NoopWriter::startWritingPeriodicNoops(OpTime lastKnownOpTime) {
_noopRunner =
std::make_unique<PeriodicNoopRunner>(_writeInterval, [this](OperationContext* opCtx) {
// Noop writes are critical for the cluster stability, so we mark it as having Immediate
- // priority. As a result it will skip both flow control and normal ticket acquisition.
- SetTicketAquisitionPriorityForLock priority(opCtx,
- AdmissionContext::Priority::kImmediate);
+ // priority. As a result it will skip both flow control and waiting for ticket
+ // acquisition.
+ SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate);
_writeNoop(opCtx);
});
return Status::OK();
diff --git a/src/mongo/db/repl/oplog_applier_impl.cpp b/src/mongo/db/repl/oplog_applier_impl.cpp
index 6b1121b059a..f7f01d12988 100644
--- a/src/mongo/db/repl/oplog_applier_impl.cpp
+++ b/src/mongo/db/repl/oplog_applier_impl.cpp
@@ -317,9 +317,9 @@ void OplogApplierImpl::_run(OplogBuffer* oplogBuffer) {
OperationContext& opCtx = *opCtxPtr;
// The oplog applier is crucial for stability of the replica set. As a result we mark it as
- // having Immediate priority. This makes the operation skip ticket acquisition and flow
- // control.
- SetTicketAquisitionPriorityForLock priority(&opCtx, AdmissionContext::Priority::kImmediate);
+ // having Immediate priority. This makes the operation skip waiting for ticket acquisition
+ // and flow control.
+ SetAdmissionPriorityForLock priority(&opCtx, AdmissionContext::Priority::kImmediate);
// For pausing replication in tests.
if (MONGO_unlikely(rsSyncApplyStop.shouldFail())) {
@@ -432,9 +432,10 @@ void scheduleWritesToOplogAndChangeCollection(OperationContext* opCtx,
auto opCtx = cc().makeOperationContext();
// Oplog writes are crucial to the stability of the replica set. We mark the operations
- // as having Immediate priority so that it skips ticket acquisition and flow control.
- SetTicketAquisitionPriorityForLock priority(opCtx.get(),
- AdmissionContext::Priority::kImmediate);
+ // as having Immediate priority so that it skips waiting for ticket acquisition and flow
+ // control.
+ SetAdmissionPriorityForLock priority(opCtx.get(),
+ AdmissionContext::Priority::kImmediate);
UnreplicatedWritesBlock uwb(opCtx.get());
ShouldNotConflictWithSecondaryBatchApplicationBlock shouldNotConflictBlock(
@@ -589,10 +590,10 @@ StatusWith<OpTime> OplogApplierImpl::_applyOplogBatch(OperationContext* opCtx,
auto opCtx = cc().makeOperationContext();
// Applying an Oplog batch is crucial to the stability of the Replica Set. We
- // mark it as having Immediate priority so that it skips ticket acquisition and
- // flow control.
- SetTicketAquisitionPriorityForLock priority(
- opCtx.get(), AdmissionContext::Priority::kImmediate);
+ // mark it as having Immediate priority so that it skips waiting for ticket
+ // acquisition and flow control.
+ SetAdmissionPriorityForLock priority(opCtx.get(),
+ AdmissionContext::Priority::kImmediate);
opCtx->setEnforceConstraints(false);
diff --git a/src/mongo/db/repl/repl_set_request_votes.cpp b/src/mongo/db/repl/repl_set_request_votes.cpp
index f80027813f9..5f07c9e4d98 100644
--- a/src/mongo/db/repl/repl_set_request_votes.cpp
+++ b/src/mongo/db/repl/repl_set_request_votes.cpp
@@ -61,9 +61,9 @@ private:
uassertStatusOK(status);
// Operations that are part of Replica Set elections are crucial to the stability of the
- // cluster. Marking it as having Immediate priority will make it skip ticket acquisition and
- // Flow Control.
- SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate);
+ // cluster. Marking it as having Immediate priority will make it skip waiting for ticket
+ // acquisition and Flow Control.
+ SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate);
ReplSetRequestVotesResponse response;
status = ReplicationCoordinator::get(opCtx)->processReplSetRequestVotes(
opCtx, parsedArgs, &response);
diff --git a/src/mongo/db/repl/replication_consistency_markers_impl.cpp b/src/mongo/db/repl/replication_consistency_markers_impl.cpp
index b06238cb0d9..1e44889baf1 100644
--- a/src/mongo/db/repl/replication_consistency_markers_impl.cpp
+++ b/src/mongo/db/repl/replication_consistency_markers_impl.cpp
@@ -418,12 +418,11 @@ ReplicationConsistencyMarkersImpl::refreshOplogTruncateAfterPointIfPrimary(
}
ON_BLOCK_EXIT([&] { opCtx->recoveryUnit()->setPrepareConflictBehavior(originalBehavior); });
- // Exempt storage ticket acquisition in order to avoid starving upstream requests waiting
- // for durability. SERVER-60682 is an example with more pending prepared transactions than
- // storage tickets; the transaction coordinator could not persist the decision and
- // had to unnecessarily wait for prepared transactions to expire to make forward progress.
- SetTicketAquisitionPriorityForLock setTicketAquisition(opCtx,
- AdmissionContext::Priority::kImmediate);
+ // Exempt waiting for storage ticket acquisition in order to avoid starving upstream requests
+ // waiting for durability. SERVER-60682 is an example with more pending prepared transactions
+ // than storage tickets; the transaction coordinator could not persist the decision and had to
+ // unnecessarily wait for prepared transactions to expire to make forward progress.
+ SetAdmissionPriorityForLock setTicketAquisition(opCtx, AdmissionContext::Priority::kImmediate);
// The locks necessary to write to the oplog truncate after point's collection and read from the
// oplog collection must be taken up front so that the mutex can also be taken around both
diff --git a/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp b/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp
index eea20ed70ee..8fc86f695cf 100644
--- a/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp
+++ b/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp
@@ -855,7 +855,7 @@ void ReplicationCoordinatorExternalStateImpl::_stopAsyncUpdatesOfAndClearOplogTr
// As opCtx does not expose a method to allow skipping flow control on purpose we mark the
// operation as having Immediate priority. This will skip flow control and ticket acquisition.
// It is fine to do this since the system is essentially shutting down at this point.
- SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate);
+ SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate);
// Tell the system to stop updating the oplogTruncateAfterPoint asynchronously and to go
// back to using last applied to update repl's durable timestamp instead of the truncate
diff --git a/src/mongo/db/repl/replication_coordinator_impl_elect_v1.cpp b/src/mongo/db/repl/replication_coordinator_impl_elect_v1.cpp
index e50508e9f35..57959db018b 100644
--- a/src/mongo/db/repl/replication_coordinator_impl_elect_v1.cpp
+++ b/src/mongo/db/repl/replication_coordinator_impl_elect_v1.cpp
@@ -339,8 +339,7 @@ void ReplicationCoordinatorImpl::ElectionState::_writeLastVoteForMyElection(
// Any operation that occurs as part of an election process is critical to the operation of
// the cluster. We mark the operation as having Immediate priority to skip ticket
// acquisition and flow control.
- SetTicketAquisitionPriorityForLock priority(opCtx.get(),
- AdmissionContext::Priority::kImmediate);
+ SetAdmissionPriorityForLock priority(opCtx.get(), AdmissionContext::Priority::kImmediate);
LOGV2(6015300,
"Storing last vote document in local storage for my election",
diff --git a/src/mongo/db/repl/replication_coordinator_test_fixture.cpp b/src/mongo/db/repl/replication_coordinator_test_fixture.cpp
index cd6d0e60381..7fae176024f 100644
--- a/src/mongo/db/repl/replication_coordinator_test_fixture.cpp
+++ b/src/mongo/db/repl/replication_coordinator_test_fixture.cpp
@@ -438,7 +438,7 @@ void ReplCoordTest::simulateSuccessfulV1ElectionAt(Date_t electionTime) {
void ReplCoordTest::signalDrainComplete(OperationContext* opCtx) noexcept {
// Writes that occur in code paths that call signalDrainComplete are expected to have Immediate
// priority.
- SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate);
+ SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate);
getExternalState()->setFirstOpTimeOfMyTerm(OpTime(Timestamp(1, 1), getReplCoord()->getTerm()));
getReplCoord()->signalDrainComplete(opCtx, getReplCoord()->getTerm());
}
diff --git a/src/mongo/db/repl/replication_info.cpp b/src/mongo/db/repl/replication_info.cpp
index ed400698279..020cd913348 100644
--- a/src/mongo/db/repl/replication_info.cpp
+++ b/src/mongo/db/repl/replication_info.cpp
@@ -316,6 +316,10 @@ public:
return true;
}
+ bool allowedWithSecurityToken() const final {
+ return true;
+ }
+
AllowedOnSecondary secondaryAllowed(ServiceContext*) const final {
return AllowedOnSecondary::kAlways;
}
diff --git a/src/mongo/db/repl/storage_interface_impl.cpp b/src/mongo/db/repl/storage_interface_impl.cpp
index 24cd1419970..b1697942c24 100644
--- a/src/mongo/db/repl/storage_interface_impl.cpp
+++ b/src/mongo/db/repl/storage_interface_impl.cpp
@@ -1483,9 +1483,8 @@ Status StorageInterfaceImpl::isAdminDbValid(OperationContext* opCtx) {
void StorageInterfaceImpl::waitForAllEarlierOplogWritesToBeVisible(OperationContext* opCtx,
bool primaryOnly) {
// Waiting for oplog writes to be visible in the oplog does not use any storage engine resources
- // and must skip ticket acquisition to avoid deadlocks with updating oplog visibility.
- SetTicketAquisitionPriorityForLock setTicketAquisition(opCtx,
- AdmissionContext::Priority::kImmediate);
+ // and must not wait for ticket acquisition to avoid deadlocks with updating oplog visibility.
+ SetAdmissionPriorityForLock setTicketAquisition(opCtx, AdmissionContext::Priority::kImmediate);
AutoGetOplog oplogRead(opCtx, OplogAccessMode::kRead);
if (primaryOnly &&
@@ -1501,8 +1500,7 @@ void StorageInterfaceImpl::oplogDiskLocRegister(OperationContext* opCtx,
bool orderedCommit) {
// Setting the oplog visibility does not use any storage engine resources and must skip ticket
// acquisition to avoid deadlocks with updating oplog visibility.
- SetTicketAquisitionPriorityForLock setTicketAquisition(opCtx,
- AdmissionContext::Priority::kImmediate);
+ SetAdmissionPriorityForLock setTicketAquisition(opCtx, AdmissionContext::Priority::kImmediate);
AutoGetOplog oplogRead(opCtx, OplogAccessMode::kRead);
fassert(28557,
diff --git a/src/mongo/db/repl/tenant_migration_donor_service.cpp b/src/mongo/db/repl/tenant_migration_donor_service.cpp
index 9470ac2b036..920538fadd6 100644
--- a/src/mongo/db/repl/tenant_migration_donor_service.cpp
+++ b/src/mongo/db/repl/tenant_migration_donor_service.cpp
@@ -30,6 +30,7 @@
#include "mongo/db/repl/tenant_migration_donor_service.h"
+#include "mongo/client/async_remote_command_targeter_adapter.h"
#include "mongo/client/connection_string.h"
#include "mongo/client/replica_set_monitor.h"
#include "mongo/config.h"
@@ -49,6 +50,8 @@
#include "mongo/db/repl/tenant_migration_state_machine_gen.h"
#include "mongo/db/repl/tenant_migration_statistics.h"
#include "mongo/db/repl/wait_for_majority_service.h"
+#include "mongo/executor/async_rpc.h"
+#include "mongo/executor/async_rpc_retry_policy.h"
#include "mongo/executor/connection_pool.h"
#include "mongo/executor/network_interface_factory.h"
#include "mongo/logv2/log.h"
@@ -91,25 +94,50 @@ const ReadPreferenceSetting kPrimaryOnlyReadPreference(ReadPreference::PrimaryOn
const int kMaxRecipientKeyDocsFindAttempts = 10;
-bool shouldStopSendingRecipientForgetMigrationCommand(Status status) {
- return status.isOK() ||
- !(ErrorCodes::isRetriableError(status) || ErrorCodes::isNetworkTimeoutError(status) ||
- // Returned if findHost() is unable to target the recipient in 15 seconds, which may
- // happen after a failover.
- status == ErrorCodes::FailedToSatisfyReadPreference ||
- ErrorCodes::isInterruption(status));
-}
+/**
+ * Encapsulates the retry logic for sending the ForgetMigration command.
+ */
+class RecipientForgetMigrationRetryPolicy
+ : public async_rpc::RetryWithBackoffOnErrorCategories<ErrorCategory::RetriableError,
+ ErrorCategory::NetworkTimeoutError,
+ ErrorCategory::Interruption> {
+public:
+ using RetryWithBackoffOnErrorCategories::RetryWithBackoffOnErrorCategories;
+ bool recordAndEvaluateRetry(Status status) override {
+ if (status.isOK()) {
+ return false;
+ }
+ auto underlyingError = async_rpc::unpackRPCStatusIgnoringWriteConcernAndWriteErrors(status);
+ // Returned if findHost() is unable to target the recipient in 15 seconds, which may
+ // happen after a failover.
+ return RetryWithBackoffOnErrorCategories::recordAndEvaluateRetry(underlyingError) ||
+ underlyingError == ErrorCodes::FailedToSatisfyReadPreference;
+ }
+};
-bool shouldStopSendingRecipientSyncDataCommand(Status status, MigrationProtocolEnum protocol) {
- if (status.isOK() || protocol == MigrationProtocolEnum::kShardMerge) {
- return true;
+/**
+ * Encapsulates the retry logic for sending the SyncData command.
+ */
+class RecipientSyncDataRetryPolicy
+ : public async_rpc::RetryWithBackoffOnErrorCategories<ErrorCategory::RetriableError,
+ ErrorCategory::NetworkTimeoutError> {
+public:
+ RecipientSyncDataRetryPolicy(MigrationProtocolEnum p, Backoff b)
+ : RetryWithBackoffOnErrorCategories(b), _protocol{p} {}
+
+ /** Returns true if we should retry sending SyncData given the error */
+ bool recordAndEvaluateRetry(Status status) {
+ if (_protocol == MigrationProtocolEnum::kShardMerge || status.isOK()) {
+ return false;
+ }
+ auto underlyingError = async_rpc::unpackRPCStatusIgnoringWriteConcernAndWriteErrors(status);
+ return RetryWithBackoffOnErrorCategories::recordAndEvaluateRetry(status) ||
+ underlyingError == ErrorCodes::FailedToSatisfyReadPreference;
}
- return !(ErrorCodes::isRetriableError(status) || ErrorCodes::isNetworkTimeoutError(status) ||
- // Returned if findHost() is unable to target the recipient in 15 seconds, which may
- // happen after a failover.
- status == ErrorCodes::FailedToSatisfyReadPreference);
-}
+private:
+ MigrationProtocolEnum _protocol;
+};
bool shouldStopFetchingRecipientClusterTimeKeyDocs(Status status) {
return status.isOK() ||
@@ -744,90 +772,48 @@ ExecutorFuture<void> TenantMigrationDonorService::Instance::_waitForMajorityWrit
});
}
-ExecutorFuture<void> TenantMigrationDonorService::Instance::_sendCommandToRecipient(
- std::shared_ptr<executor::ScopedTaskExecutor> executor,
- std::shared_ptr<RemoteCommandTargeter> recipientTargeterRS,
- const BSONObj& cmdObj,
- const CancellationToken& token) {
- const bool isRecipientSyncDataCmd = cmdObj.hasField(RecipientSyncData::kCommandName);
- return AsyncTry(
- [this, self = shared_from_this(), executor, recipientTargeterRS, cmdObj, token] {
- return recipientTargeterRS->findHost(kPrimaryOnlyReadPreference, token)
- .thenRunOn(**executor)
- .then([this, self = shared_from_this(), executor, cmdObj, token](
- auto recipientHost) {
- executor::RemoteCommandRequest request(
- std::move(recipientHost),
- NamespaceString::kAdminDb.toString(),
- std::move(cmdObj),
- rpc::makeEmptyMetadata(),
- nullptr);
- request.sslMode = _sslMode;
-
- return (_recipientCmdExecutor)
- ->scheduleRemoteCommand(std::move(request), token)
- .then([this,
- self = shared_from_this()](const auto& response) -> Status {
- if (!response.isOK()) {
- return response.status;
- }
- auto commandStatus = getStatusFromCommandResult(response.data);
- commandStatus.addContext(
- "Tenant migration recipient command failed");
- return commandStatus;
- });
- });
- })
- .until([this, self = shared_from_this(), token, cmdObj, isRecipientSyncDataCmd](
- Status status) {
- if (isRecipientSyncDataCmd) {
- return shouldStopSendingRecipientSyncDataCommand(status, getProtocol());
- } else {
- // If the recipient command is not 'recipientSyncData', it must be
- // 'recipientForgetMigration'.
- invariant(cmdObj.hasField(RecipientForgetMigration::kCommandName));
- return shouldStopSendingRecipientForgetMigrationCommand(status);
- }
- })
- .withBackoffBetweenIterations(kExponentialBackoff)
- .on(**executor, token);
-}
-
ExecutorFuture<void> TenantMigrationDonorService::Instance::_sendRecipientSyncDataCommand(
- std::shared_ptr<executor::ScopedTaskExecutor> executor,
+ std::shared_ptr<executor::ScopedTaskExecutor> exec,
std::shared_ptr<RemoteCommandTargeter> recipientTargeterRS,
const CancellationToken& token) {
+ auto donorConnString =
+ repl::ReplicationCoordinator::get(_serviceContext)->getConfigConnectionString();
- const auto cmdObj = [&] {
- auto donorConnString =
- repl::ReplicationCoordinator::get(_serviceContext)->getConfigConnectionString();
-
- RecipientSyncData request;
- request.setDbName(NamespaceString::kAdminDb);
+ RecipientSyncData request;
+ request.setDbName(NamespaceString::kAdminDb);
- MigrationRecipientCommonData commonData(
- _migrationUuid, donorConnString.toString(), _readPreference);
- commonData.setRecipientCertificateForDonor(_recipientCertificateForDonor);
- if (_protocol == MigrationProtocolEnum::kMultitenantMigrations) {
- commonData.setTenantId(boost::optional<StringData>(_tenantId));
- }
+ MigrationRecipientCommonData commonData(
+ _migrationUuid, donorConnString.toString(), _readPreference);
+ commonData.setRecipientCertificateForDonor(_recipientCertificateForDonor);
+ if (_protocol == MigrationProtocolEnum::kMultitenantMigrations) {
+ commonData.setTenantId(boost::optional<StringData>(_tenantId));
+ }
- stdx::lock_guard<Latch> lg(_mutex);
- commonData.setProtocol(_protocol);
- request.setMigrationRecipientCommonData(commonData);
+ commonData.setProtocol(_protocol);
+ request.setMigrationRecipientCommonData(commonData);
+ {
+ stdx::lock_guard<Latch> lg(_mutex);
invariant(_stateDoc.getStartMigrationDonorTimestamp());
request.setStartMigrationDonorTimestamp(*_stateDoc.getStartMigrationDonorTimestamp());
request.setReturnAfterReachingDonorTimestamp(_stateDoc.getBlockTimestamp());
- return request.toBSON(BSONObj());
- }();
+ }
- return _sendCommandToRecipient(executor, recipientTargeterRS, cmdObj, token);
+ auto asyncTargeter = std::make_unique<async_rpc::AsyncRemoteCommandTargeterAdapter>(
+ kPrimaryOnlyReadPreference, recipientTargeterRS);
+ auto retryPolicy =
+ std::make_shared<RecipientSyncDataRetryPolicy>(getProtocol(), kExponentialBackoff);
+ auto cmdRes = async_rpc::sendCommand(
+ request, _serviceContext, std::move(asyncTargeter), **exec, token, retryPolicy);
+ return std::move(cmdRes).ignoreValue().onError([](Status status) {
+ return async_rpc::unpackRPCStatusIgnoringWriteConcernAndWriteErrors(status).addContext(
+ "Tenant migration recipient command failed");
+ });
}
ExecutorFuture<void> TenantMigrationDonorService::Instance::_sendRecipientForgetMigrationCommand(
- std::shared_ptr<executor::ScopedTaskExecutor> executor,
+ std::shared_ptr<executor::ScopedTaskExecutor> exec,
std::shared_ptr<RemoteCommandTargeter> recipientTargeterRS,
const CancellationToken& token) {
@@ -847,7 +833,15 @@ ExecutorFuture<void> TenantMigrationDonorService::Instance::_sendRecipientForget
commonData.setProtocol(_protocol);
request.setMigrationRecipientCommonData(commonData);
- return _sendCommandToRecipient(executor, recipientTargeterRS, request.toBSON(BSONObj()), token);
+ auto asyncTargeter = std::make_unique<async_rpc::AsyncRemoteCommandTargeterAdapter>(
+ kPrimaryOnlyReadPreference, recipientTargeterRS);
+ auto retryPolicy = std::make_shared<RecipientForgetMigrationRetryPolicy>(kExponentialBackoff);
+ auto cmdRes = async_rpc::sendCommand(
+ request, _serviceContext, std::move(asyncTargeter), **exec, token, retryPolicy);
+ return std::move(cmdRes).ignoreValue().onError([](Status status) {
+ return async_rpc::unpackRPCStatusIgnoringWriteConcernAndWriteErrors(status).addContext(
+ "Tenant migration recipient command failed");
+ });
}
CancellationToken TenantMigrationDonorService::Instance::_initAbortMigrationSource(
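Note: this file replaces the hand-rolled AsyncTry/until/withBackoffBetweenIterations loop with async_rpc::sendCommand plus per-command retry policies. A condensed sketch of the new send path, using only calls that appear in the hunks above (sendRecipientCommand is a hypothetical free-function wrapper; in the diff the same steps run inside the instance methods):

    // Condensed sketch of the send path introduced above.
    ExecutorFuture<void> sendRecipientCommand(RecipientSyncData request,
                                              ServiceContext* serviceContext,
                                              std::shared_ptr<executor::ScopedTaskExecutor> exec,
                                              std::shared_ptr<RemoteCommandTargeter> targeterRS,
                                              const CancellationToken& token) {
        auto asyncTargeter = std::make_unique<async_rpc::AsyncRemoteCommandTargeterAdapter>(
            kPrimaryOnlyReadPreference, targeterRS);
        auto retryPolicy = std::make_shared<RecipientSyncDataRetryPolicy>(
            MigrationProtocolEnum::kMultitenantMigrations, kExponentialBackoff);

        // sendCommand owns targeting, serialization, and retries; the policy's
        // recordAndEvaluateRetry() decides whether a failed attempt is tried again.
        auto cmdRes = async_rpc::sendCommand(
            request, serviceContext, std::move(asyncTargeter), **exec, token, retryPolicy);

        // Drop the reply value and surface the underlying remote error, as in the diff.
        return std::move(cmdRes).ignoreValue().onError([](Status status) {
            return async_rpc::unpackRPCStatusIgnoringWriteConcernAndWriteErrors(status).addContext(
                "Tenant migration recipient command failed");
        });
    }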
diff --git a/src/mongo/db/repl/tenant_migration_donor_service.h b/src/mongo/db/repl/tenant_migration_donor_service.h
index c9f28ddcc07..f8b0db86e17 100644
--- a/src/mongo/db/repl/tenant_migration_donor_service.h
+++ b/src/mongo/db/repl/tenant_migration_donor_service.h
@@ -255,15 +255,6 @@ public:
const CancellationToken& token);
/**
- * Sends the given command to the recipient replica set.
- */
- ExecutorFuture<void> _sendCommandToRecipient(
- std::shared_ptr<executor::ScopedTaskExecutor> executor,
- std::shared_ptr<RemoteCommandTargeter> recipientTargeterRS,
- const BSONObj& cmdObj,
- const CancellationToken& token);
-
- /**
* Sends the recipientSyncData command to the recipient replica set.
*/
ExecutorFuture<void> _sendRecipientSyncDataCommand(
diff --git a/src/mongo/db/repl/tenant_migration_recipient_service.cpp b/src/mongo/db/repl/tenant_migration_recipient_service.cpp
index 57ec62819a1..c6c4d4d2958 100644
--- a/src/mongo/db/repl/tenant_migration_recipient_service.cpp
+++ b/src/mongo/db/repl/tenant_migration_recipient_service.cpp
@@ -76,6 +76,7 @@
#include "mongo/db/transaction/transaction_participant.h"
#include "mongo/db/vector_clock_mutable.h"
#include "mongo/db/write_concern_options.h"
+#include "mongo/executor/task_executor.h"
#include "mongo/logv2/log.h"
#include "mongo/rpc/get_status_from_command_result.h"
#include "mongo/util/assert_util.h"
@@ -93,6 +94,8 @@ const std::string kTTLIndexName = "TenantMigrationRecipientTTLIndex";
const Backoff kExponentialBackoff(Seconds(1), Milliseconds::max());
constexpr StringData kOplogBufferPrefix = "repl.migration.oplog_"_sd;
constexpr int kBackupCursorFileFetcherRetryAttempts = 10;
+constexpr int kCheckpointTsBackupCursorErrorCode = 6929900;
+constexpr int kCloseCursorBeforeOpenErrorCode = 50886;
NamespaceString getOplogBufferNs(const UUID& migrationUUID) {
return NamespaceString(NamespaceString::kConfigDb,
@@ -979,25 +982,8 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_killBackupCursor()
nullptr);
request.sslMode = _donorUri.getSSLMode();
- auto scheduleResult =
- (_recipientService->getInstanceCleanupExecutor())
- ->scheduleRemoteCommand(
- request, [](const executor::TaskExecutor::RemoteCommandCallbackArgs& args) {
- if (!args.response.isOK()) {
- LOGV2_WARNING(6113005,
- "killCursors command task failed",
- "error"_attr = redact(args.response.status));
- return;
- }
- auto status = getStatusFromCommandResult(args.response.data);
- if (status.isOK()) {
- LOGV2_INFO(6113415, "Killed backup cursor");
- } else {
- LOGV2_WARNING(6113006,
- "killCursors command failed",
- "error"_attr = redact(status));
- }
- });
+ const auto scheduleResult = _scheduleKillBackupCursorWithLock(
+ lk, _recipientService->getInstanceCleanupExecutor());
if (!scheduleResult.isOK()) {
LOGV2_WARNING(6113004,
"Failed to run killCursors command on backup cursor",
@@ -1009,13 +995,8 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_killBackupCursor()
SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor(
const CancellationToken& token) {
- stdx::lock_guard lk(_mutex);
- LOGV2_DEBUG(6113000,
- 1,
- "Trying to open backup cursor on donor primary",
- "migrationId"_attr = _stateDoc.getId(),
- "donorConnectionString"_attr = _stateDoc.getDonorConnectionString());
- const auto cmdObj = [] {
+
+ const auto aggregateCommandRequestObj = [] {
AggregateCommandRequest aggRequest(
NamespaceString::makeCollectionlessAggregateNSS(NamespaceString::kAdminDb),
{BSON("$backupCursor" << BSONObj())});
@@ -1024,11 +1005,18 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor(
return aggRequest.toBSON(BSONObj());
}();
- auto startMigrationDonorTimestamp = _stateDoc.getStartMigrationDonorTimestamp();
+ stdx::lock_guard lk(_mutex);
+ LOGV2_DEBUG(6113000,
+ 1,
+ "Trying to open backup cursor on donor primary",
+ "migrationId"_attr = _stateDoc.getId(),
+ "donorConnectionString"_attr = _stateDoc.getDonorConnectionString());
+
+ const auto startMigrationDonorTimestamp = _stateDoc.getStartMigrationDonorTimestamp();
auto fetchStatus = std::make_shared<boost::optional<Status>>();
auto uniqueMetadataInfo = std::make_unique<boost::optional<shard_merge_utils::MetadataInfo>>();
- auto fetcherCallback =
+ const auto fetcherCallback =
[
this,
self = shared_from_this(),
@@ -1043,8 +1031,8 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor(
uassertStatusOK(dataStatus);
uassert(ErrorCodes::CallbackCanceled, "backup cursor interrupted", !token.isCanceled());
- auto uniqueOpCtx = cc().makeOperationContext();
- auto opCtx = uniqueOpCtx.get();
+ const auto uniqueOpCtx = cc().makeOperationContext();
+ const auto opCtx = uniqueOpCtx.get();
const auto& data = dataStatus.getValue();
for (const BSONObj& doc : data.documents) {
@@ -1059,14 +1047,6 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor(
"backupCursorId"_attr = data.cursorId,
"backupCursorCheckpointTimestamp"_attr = checkpointTimestamp);
- // This ensures that the recipient won’t receive any 2 phase index build donor
- // oplog entries during the migration. We also have a check in the tenant oplog
- // applier to detect such oplog entries. Adding a check here helps us to detect
- // the problem earlier.
- uassert(6929900,
- "backupCursorCheckpointTimestamp should be greater than or equal to "
- "startMigrationDonorTimestamp",
- checkpointTimestamp >= startMigrationDonorTimestamp);
{
stdx::lock_guard lk(_mutex);
stdx::lock_guard<TenantMigrationSharedData> sharedDatalk(*_sharedData);
@@ -1075,6 +1055,15 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor(
BackupCursorInfo{data.cursorId, data.nss, checkpointTimestamp});
}
+ // This ensures that the recipient won’t receive any 2 phase index build donor
+ // oplog entries during the migration. We also have a check in the tenant oplog
+ // applier to detect such oplog entries. Adding a check here helps us to detect
+ // the problem earlier.
+ uassert(kCheckpointTsBackupCursorErrorCode,
+ "backupCursorCheckpointTimestamp should be greater than or equal to "
+ "startMigrationDonorTimestamp",
+ checkpointTimestamp >= startMigrationDonorTimestamp);
+
invariant(metadataInfoPtr && !*metadataInfoPtr);
(*metadataInfoPtr) = shard_merge_utils::MetadataInfo::constructMetadataInfo(
getMigrationUUID(), _client->getServerAddress(), metadata);
@@ -1133,10 +1122,10 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor(
};
_donorFilenameBackupCursorFileFetcher = std::make_unique<Fetcher>(
- (**_scopedExecutor).get(),
+ _backupCursorExecutor.get(),
_client->getServerHostAndPort(),
NamespaceString::kAdminDb.toString(),
- cmdObj,
+ aggregateCommandRequestObj,
fetcherCallback,
ReadPreferenceSetting(ReadPreference::PrimaryPreferred).toContainingBSON(),
executor::RemoteCommandRequest::kNoTimeout, /* aggregateTimeout */
@@ -1160,6 +1149,35 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor(
.semi();
}
+StatusWith<executor::TaskExecutor::CallbackHandle>
+TenantMigrationRecipientService::Instance::_scheduleKillBackupCursorWithLock(
+ WithLock lk, std::shared_ptr<executor::TaskExecutor> executor) {
+ auto& donorBackupCursorInfo = _getDonorBackupCursorInfo(lk);
+ executor::RemoteCommandRequest killCursorsRequest(
+ _client->getServerHostAndPort(),
+ donorBackupCursorInfo.nss.db().toString(),
+ BSON("killCursors" << donorBackupCursorInfo.nss.coll().toString() << "cursors"
+ << BSON_ARRAY(donorBackupCursorInfo.cursorId)),
+ nullptr);
+ killCursorsRequest.sslMode = _donorUri.getSSLMode();
+
+ return executor->scheduleRemoteCommand(
+ killCursorsRequest, [](const executor::TaskExecutor::RemoteCommandCallbackArgs& args) {
+ if (!args.response.isOK()) {
+ LOGV2_WARNING(6113005,
+ "killCursors command task failed",
+ "error"_attr = redact(args.response.status));
+ return;
+ }
+ auto status = getStatusFromCommandResult(args.response.data);
+ if (status.isOK()) {
+ LOGV2_INFO(6113415, "Killed backup cursor");
+ } else {
+ LOGV2_WARNING(6113006, "killCursors command failed", "error"_attr = redact(status));
+ }
+ });
+}
+
SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursorWithRetry(
const CancellationToken& token) {
return AsyncTry([this, self = shared_from_this(), token] { return _openBackupCursor(token); })
@@ -1169,8 +1187,20 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursorWit
"Retrying backup cursor creation after transient error",
"migrationId"_attr = getMigrationUUID(),
"status"_attr = status);
- // A checkpoint took place while opening a backup cursor. We
- // should retry and *not* cancel migration.
+
+ return false;
+ } else if (status.code() == kCheckpointTsBackupCursorErrorCode ||
+ status.code() == kCloseCursorBeforeOpenErrorCode) {
+ LOGV2_INFO(6955100,
+ "Closing backup cursor and retrying after getting retryable error",
+ "migrationId"_attr = getMigrationUUID(),
+ "status"_attr = status);
+
+ stdx::lock_guard lk(_mutex);
+ const auto scheduleResult =
+ _scheduleKillBackupCursorWithLock(lk, _backupCursorExecutor);
+ uassertStatusOK(scheduleResult);
+
return false;
}
@@ -1199,7 +1229,7 @@ void TenantMigrationRecipientService::Instance::_keepBackupCursorAlive(
auto& donorBackupCursorInfo = _getDonorBackupCursorInfo(lk);
_backupCursorKeepAliveFuture =
shard_merge_utils::keepBackupCursorAlive(_backupCursorKeepAliveCancellation,
- **_scopedExecutor,
+ _backupCursorExecutor,
_client->getServerHostAndPort(),
donorBackupCursorInfo.cursorId,
donorBackupCursorInfo.nss);
@@ -2703,6 +2733,7 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::run(
std::shared_ptr<executor::ScopedTaskExecutor> executor,
const CancellationToken& token) noexcept {
_scopedExecutor = executor;
+ _backupCursorExecutor = **_scopedExecutor;
auto scopedOutstandingMigrationCounter =
TenantMigrationStatistics::get(_serviceContext)->getScopedOutstandingReceivingCount();
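Note: backup cursor traffic now runs on a dedicated _backupCursorExecutor (initialized to the scoped executor in run(), overridable in tests), and _scheduleKillBackupCursorWithLock() is shared between teardown and the new retry branch in _openBackupCursorWithRetry(). A minimal sketch of the retry classification, mirroring the constants in the diff (isRetryableBackupCursorError is a hypothetical helper name):

    #include "mongo/base/status.h"

    // Constants mirroring those added in the diff above.
    constexpr int kCheckpointTsBackupCursorErrorCode = 6929900;  // checkpointTs < startMigrationDonorTs
    constexpr int kCloseCursorBeforeOpenErrorCode = 50886;       // an earlier backup cursor is still open

    // Either error means the donor-side cursor should be killed and $backupCursor re-opened,
    // rather than the whole migration being failed.
    bool isRetryableBackupCursorError(const mongo::Status& status) {
        return status.code() == kCheckpointTsBackupCursorErrorCode ||
            status.code() == kCloseCursorBeforeOpenErrorCode;
    }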
diff --git a/src/mongo/db/repl/tenant_migration_recipient_service.h b/src/mongo/db/repl/tenant_migration_recipient_service.h
index 22de9a00fd1..a71f29139d9 100644
--- a/src/mongo/db/repl/tenant_migration_recipient_service.h
+++ b/src/mongo/db/repl/tenant_migration_recipient_service.h
@@ -213,6 +213,15 @@ public:
private:
friend class TenantMigrationRecipientServiceTest;
+ friend class TenantMigrationRecipientServiceShardMergeTest;
+
+ /**
+ * Only used for testing. Allows setting a custom task executor for backup cursor fetcher.
+ */
+ void setBackupCursorFetcherExecutor_forTest(
+ std::shared_ptr<executor::TaskExecutor> taskExecutor) {
+ _backupCursorExecutor = taskExecutor;
+ }
const NamespaceString _stateDocumentsNS =
NamespaceString::kTenantMigrationRecipientsNamespace;
@@ -605,6 +614,13 @@ public:
SemiFuture<TenantOplogApplier::OpTimePair> _migrateUsingShardMergeProtocol(
const CancellationToken& token);
+ /*
+     * Sends the killCursors command to the remote in order to close the backup cursor
+ * connection on the donor.
+ */
+ StatusWith<executor::TaskExecutor::CallbackHandle> _scheduleKillBackupCursorWithLock(
+ WithLock lk, std::shared_ptr<executor::TaskExecutor> executor);
+
mutable Mutex _mutex = MONGO_MAKE_LATCH("TenantMigrationRecipientService::_mutex");
// All member variables are labeled with one of the following codes indicating the
@@ -618,6 +634,7 @@ public:
ServiceContext* const _serviceContext;
const TenantMigrationRecipientService* const _recipientService; // (R) (not owned)
std::shared_ptr<executor::ScopedTaskExecutor> _scopedExecutor; // (M)
+ std::shared_ptr<executor::TaskExecutor> _backupCursorExecutor; // (M)
TenantMigrationRecipientDocument _stateDoc; // (M)
// This data is provided in the initial state doc and never changes. We keep copies to
diff --git a/src/mongo/db/repl/tenant_migration_recipient_service_shard_merge_test.cpp b/src/mongo/db/repl/tenant_migration_recipient_service_shard_merge_test.cpp
new file mode 100644
index 00000000000..be6f0bbc799
--- /dev/null
+++ b/src/mongo/db/repl/tenant_migration_recipient_service_shard_merge_test.cpp
@@ -0,0 +1,593 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include <fstream>
+#include <memory>
+
+#include "mongo/client/connpool.h"
+#include "mongo/client/replica_set_monitor.h"
+#include "mongo/client/replica_set_monitor_protocol_test_util.h"
+#include "mongo/client/streamable_replica_set_monitor_for_testing.h"
+#include "mongo/config.h"
+#include "mongo/db/client.h"
+#include "mongo/db/db_raii.h"
+#include "mongo/db/dbdirectclient.h"
+#include "mongo/db/feature_compatibility_version_document_gen.h"
+#include "mongo/db/op_observer/op_observer_impl.h"
+#include "mongo/db/op_observer/op_observer_registry.h"
+#include "mongo/db/op_observer/oplog_writer_impl.h"
+#include "mongo/db/repl/drop_pending_collection_reaper.h"
+#include "mongo/db/repl/oplog.h"
+#include "mongo/db/repl/oplog_buffer_collection.h"
+#include "mongo/db/repl/oplog_fetcher_mock.h"
+#include "mongo/db/repl/primary_only_service.h"
+#include "mongo/db/repl/primary_only_service_op_observer.h"
+#include "mongo/db/repl/replication_coordinator_mock.h"
+#include "mongo/db/repl/storage_interface_impl.h"
+#include "mongo/db/repl/tenant_migration_recipient_entry_helpers.h"
+#include "mongo/db/repl/tenant_migration_recipient_service.h"
+#include "mongo/db/repl/tenant_migration_state_machine_gen.h"
+#include "mongo/db/repl/wait_for_majority_service.h"
+#include "mongo/db/service_context_d_test_fixture.h"
+#include "mongo/db/session/session_txn_record_gen.h"
+#include "mongo/db/storage/backup_cursor_hooks.h"
+#include "mongo/dbtests/mock/mock_conn_registry.h"
+#include "mongo/dbtests/mock/mock_replica_set.h"
+#include "mongo/executor/mock_network_fixture.h"
+#include "mongo/executor/network_interface.h"
+#include "mongo/executor/network_interface_mock.h"
+#include "mongo/executor/thread_pool_mock.h"
+#include "mongo/executor/thread_pool_task_executor.h"
+#include "mongo/executor/thread_pool_task_executor_test_fixture.h"
+#include "mongo/idl/server_parameter_test_util.h"
+#include "mongo/logv2/log.h"
+#include "mongo/rpc/metadata/egress_metadata_hook_list.h"
+#include "mongo/transport/transport_layer_manager.h"
+#include "mongo/transport/transport_layer_mock.h"
+#include "mongo/unittest/log_test.h"
+#include "mongo/unittest/unittest.h"
+#include "mongo/util/clock_source_mock.h"
+#include "mongo/util/concurrency/thread_pool.h"
+#include "mongo/util/fail_point.h"
+#include "mongo/util/future.h"
+#include "mongo/util/net/ssl_util.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest
+
+
+namespace mongo {
+namespace repl {
+
+namespace {
+constexpr std::int32_t stopFailPointErrorCode = 4880402;
+const Timestamp kDefaultStartMigrationTimestamp(1, 1);
+
+OplogEntry makeOplogEntry(OpTime opTime,
+ OpTypeEnum opType,
+ NamespaceString nss,
+ const boost::optional<UUID>& uuid,
+ BSONObj o,
+ boost::optional<BSONObj> o2) {
+ return {DurableOplogEntry(opTime, // optime
+ opType, // opType
+ nss, // namespace
+ uuid, // uuid
+ boost::none, // fromMigrate
+ OplogEntry::kOplogVersion, // version
+ o, // o
+ o2, // o2
+ {}, // sessionInfo
+ boost::none, // upsert
+ Date_t(), // wall clock time
+ {}, // statement ids
+ boost::none, // optime of previous write within same transaction
+ boost::none, // pre-image optime
+ boost::none, // post-image optime
+ boost::none, // ShardId of resharding recipient
+ boost::none, // _id
+ boost::none)}; // needsRetryImage
+}
+
+} // namespace
+
+class TenantMigrationRecipientServiceShardMergeTest : public ServiceContextMongoDTest {
+public:
+ class stopFailPointEnableBlock : public FailPointEnableBlock {
+ public:
+ explicit stopFailPointEnableBlock(StringData failPointName,
+ std::int32_t error = stopFailPointErrorCode)
+ : FailPointEnableBlock(failPointName,
+ BSON("action"
+ << "stop"
+ << "stopErrorCode" << error)) {}
+ };
+
+ void setUp() override {
+ ServiceContextMongoDTest::setUp();
+ auto serviceContext = getServiceContext();
+
+ // Fake replSet just for creating consistent URI for monitor
+ MockReplicaSet replSet("donorSet", 1, true /* hasPrimary */, true /* dollarPrefixHosts */);
+ _rsmMonitor.setup(replSet.getURI());
+
+ ConnectionString::setConnectionHook(mongo::MockConnRegistry::get()->getConnStrHook());
+
+ WaitForMajorityService::get(serviceContext).startup(serviceContext);
+
+ // Automatically mark the state doc garbage collectable after data sync completion.
+ globalFailPointRegistry()
+ .find("autoRecipientForgetMigration")
+ ->setMode(FailPoint::alwaysOn);
+
+ {
+ auto opCtx = cc().makeOperationContext();
+ auto replCoord = std::make_unique<ReplicationCoordinatorMock>(serviceContext);
+ ReplicationCoordinator::set(serviceContext, std::move(replCoord));
+
+ repl::createOplog(opCtx.get());
+ {
+ Lock::GlobalWrite lk(opCtx.get());
+ OldClientContext ctx(opCtx.get(), NamespaceString::kRsOplogNamespace);
+ tenant_migration_util::createOplogViewForTenantMigrations(opCtx.get(), ctx.db());
+ }
+
+ // Need real (non-mock) storage for the oplog buffer.
+ StorageInterface::set(serviceContext, std::make_unique<StorageInterfaceImpl>());
+
+ // The DropPendingCollectionReaper is required to drop the oplog buffer collection.
+ repl::DropPendingCollectionReaper::set(
+ serviceContext,
+ std::make_unique<repl::DropPendingCollectionReaper>(
+ StorageInterface::get(serviceContext)));
+
+ // Set up OpObserver so that repl::logOp() will store the oplog entry's optime in
+ // ReplClientInfo.
+ OpObserverRegistry* opObserverRegistry =
+ dynamic_cast<OpObserverRegistry*>(serviceContext->getOpObserver());
+ opObserverRegistry->addObserver(
+ std::make_unique<OpObserverImpl>(std::make_unique<OplogWriterImpl>()));
+ opObserverRegistry->addObserver(
+ std::make_unique<PrimaryOnlyServiceOpObserver>(serviceContext));
+
+ _registry = repl::PrimaryOnlyServiceRegistry::get(getServiceContext());
+ std::unique_ptr<TenantMigrationRecipientService> service =
+ std::make_unique<TenantMigrationRecipientService>(getServiceContext());
+ _registry->registerService(std::move(service));
+ _registry->onStartup(opCtx.get());
+ }
+ stepUp();
+
+ _service = _registry->lookupServiceByName(
+ TenantMigrationRecipientService::kTenantMigrationRecipientServiceName);
+ ASSERT(_service);
+
+ // MockReplicaSet uses custom connection string which does not support auth.
+ auto authFp = globalFailPointRegistry().find("skipTenantMigrationRecipientAuth");
+ authFp->setMode(FailPoint::alwaysOn);
+
+ // Set the sslMode to allowSSL to avoid validation error.
+ sslGlobalParams.sslMode.store(SSLParams::SSLMode_allowSSL);
+ // Skipped unless tested explicitly, as we will not receive an FCV document from the donor
+ // in these unittests without (unsightly) intervention.
+ auto compFp = globalFailPointRegistry().find("skipComparingRecipientAndDonorFCV");
+ compFp->setMode(FailPoint::alwaysOn);
+
+ // Skip fetching retryable writes, as we will test this logic entirely in integration
+ // tests.
+ auto fetchRetryableWritesFp =
+ globalFailPointRegistry().find("skipFetchingRetryableWritesEntriesBeforeStartOpTime");
+ fetchRetryableWritesFp->setMode(FailPoint::alwaysOn);
+
+ // Skip fetching committed transactions, as we will test this logic entirely in integration
+ // tests.
+ auto fetchCommittedTransactionsFp =
+ globalFailPointRegistry().find("skipFetchingCommittedTransactions");
+ fetchCommittedTransactionsFp->setMode(FailPoint::alwaysOn);
+
+        // Set up mock networking that will be used to mock the backup cursor traffic.
+ auto net = std::make_unique<executor::NetworkInterfaceMock>();
+ _net = net.get();
+
+ executor::ThreadPoolMock::Options dbThreadPoolOptions;
+ dbThreadPoolOptions.onCreateThread = []() { Client::initThread("FetchMockTaskExecutor"); };
+
+ auto pool = std::make_unique<executor::ThreadPoolMock>(_net, 1, dbThreadPoolOptions);
+ _threadpoolTaskExecutor =
+ std::make_shared<executor::ThreadPoolTaskExecutor>(std::move(pool), std::move(net));
+ _threadpoolTaskExecutor->startup();
+ }
+
+ void tearDown() override {
+ _threadpoolTaskExecutor->shutdown();
+ _threadpoolTaskExecutor->join();
+
+ auto authFp = globalFailPointRegistry().find("skipTenantMigrationRecipientAuth");
+ authFp->setMode(FailPoint::off);
+
+ // Unset the sslMode.
+ sslGlobalParams.sslMode.store(SSLParams::SSLMode_disabled);
+
+ WaitForMajorityService::get(getServiceContext()).shutDown();
+
+ _registry->onShutdown();
+ _service = nullptr;
+
+ StorageInterface::set(getServiceContext(), {});
+
+ // Clearing the connection pool is necessary when doing tests which use the
+ // ReplicaSetMonitor. See src/mongo/dbtests/mock/mock_replica_set.h for details.
+ ScopedDbConnection::clearPool();
+ ReplicaSetMonitorProtocolTestUtil::resetRSMProtocol();
+ ServiceContextMongoDTest::tearDown();
+ }
+
+ void stepDown() {
+ ASSERT_OK(ReplicationCoordinator::get(getServiceContext())
+ ->setFollowerMode(MemberState::RS_SECONDARY));
+ _registry->onStepDown();
+ }
+
+ void stepUp() {
+ auto opCtx = cc().makeOperationContext();
+ auto replCoord = ReplicationCoordinator::get(getServiceContext());
+
+ // Advance term
+ _term++;
+
+ ASSERT_OK(replCoord->setFollowerMode(MemberState::RS_PRIMARY));
+ ASSERT_OK(replCoord->updateTerm(opCtx.get(), _term));
+ replCoord->setMyLastAppliedOpTimeAndWallTime(
+ OpTimeAndWallTime(OpTime(Timestamp(1, 1), _term), Date_t()));
+
+ _registry->onStepUpComplete(opCtx.get(), _term);
+ }
+
+protected:
+ TenantMigrationRecipientServiceShardMergeTest()
+ : ServiceContextMongoDTest(Options{}.useMockClock(true)) {}
+
+ PrimaryOnlyServiceRegistry* _registry;
+ PrimaryOnlyService* _service;
+ long long _term = 0;
+
+ bool _collCreated = false;
+ size_t _numSecondaryIndexesCreated{0};
+ size_t _numDocsInserted{0};
+
+ const TenantMigrationPEMPayload kRecipientPEMPayload = [&] {
+ std::ifstream infile("jstests/libs/client.pem");
+ std::string buf((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
+
+ auto swCertificateBlob =
+ ssl_util::findPEMBlob(buf, "CERTIFICATE"_sd, 0 /* position */, false /* allowEmpty */);
+ ASSERT_TRUE(swCertificateBlob.isOK());
+
+ auto swPrivateKeyBlob =
+ ssl_util::findPEMBlob(buf, "PRIVATE KEY"_sd, 0 /* position */, false /* allowEmpty */);
+ ASSERT_TRUE(swPrivateKeyBlob.isOK());
+
+ return TenantMigrationPEMPayload{swCertificateBlob.getValue().toString(),
+ swPrivateKeyBlob.getValue().toString()};
+ }();
+
+ void checkStateDocPersisted(OperationContext* opCtx,
+ const TenantMigrationRecipientService::Instance* instance) {
+ auto memoryStateDoc = getStateDoc(instance);
+ auto persistedStateDocWithStatus =
+ tenantMigrationRecipientEntryHelpers::getStateDoc(opCtx, memoryStateDoc.getId());
+ ASSERT_OK(persistedStateDocWithStatus.getStatus());
+ ASSERT_BSONOBJ_EQ(memoryStateDoc.toBSON(), persistedStateDocWithStatus.getValue().toBSON());
+ }
+ void insertToNodes(MockReplicaSet* replSet,
+ const std::string& nss,
+ BSONObj obj,
+ const std::vector<HostAndPort>& hosts) {
+ for (const auto& host : hosts) {
+ replSet->getNode(host.toString())->insert(nss, obj);
+ }
+ }
+
+ void clearCollection(MockReplicaSet* replSet,
+ const std::string& nss,
+ const std::vector<HostAndPort>& hosts) {
+ for (const auto& host : hosts) {
+ replSet->getNode(host.toString())->remove(nss, BSONObj{} /*filter*/);
+ }
+ }
+
+ void insertTopOfOplog(MockReplicaSet* replSet,
+ const OpTime& topOfOplogOpTime,
+ const std::vector<HostAndPort> hosts = {}) {
+ const auto targetHosts = hosts.empty() ? replSet->getHosts() : hosts;
+        // The MockRemoteDBServer does not actually implement the database, so to make our
+ // find work correctly we must make sure there's only one document to find.
+ clearCollection(replSet, NamespaceString::kRsOplogNamespace.ns(), targetHosts);
+ insertToNodes(replSet,
+ NamespaceString::kRsOplogNamespace.ns(),
+ makeOplogEntry(topOfOplogOpTime,
+ OpTypeEnum::kNoop,
+ {} /* namespace */,
+ boost::none /* uuid */,
+ BSONObj() /* o */,
+ boost::none /* o2 */)
+ .getEntry()
+ .toBSON(),
+ targetHosts);
+ }
+
+ // Accessors to class private members
+ DBClientConnection* getClient(const TenantMigrationRecipientService::Instance* instance) const {
+ return instance->_client.get();
+ }
+
+ const TenantMigrationRecipientDocument& getStateDoc(
+ const TenantMigrationRecipientService::Instance* instance) const {
+ return instance->_stateDoc;
+ }
+
+ sdam::MockTopologyManager* getTopologyManager() {
+ return _rsmMonitor.getTopologyManager();
+ }
+
+ ClockSource* clock() {
+ return &_clkSource;
+ }
+
+ executor::NetworkInterfaceMock* getNet() {
+ return _net;
+ }
+
+ executor::NetworkInterfaceMock* _net = nullptr;
+ std::shared_ptr<executor::TaskExecutor> _threadpoolTaskExecutor;
+
+ void setInstanceBackupCursorFetcherExecutor(
+ std::shared_ptr<TenantMigrationRecipientService::Instance> instance) {
+ instance->setBackupCursorFetcherExecutor_forTest(_threadpoolTaskExecutor);
+ }
+
+private:
+ ClockSourceMock _clkSource;
+
+ unittest::MinimumLoggedSeverityGuard _replicationSeverityGuard{
+ logv2::LogComponent::kReplication, logv2::LogSeverity::Debug(1)};
+ unittest::MinimumLoggedSeverityGuard _tenantMigrationSeverityGuard{
+ logv2::LogComponent::kTenantMigration, logv2::LogSeverity::Debug(1)};
+
+ StreamableReplicaSetMonitorForTesting _rsmMonitor;
+ RAIIServerParameterControllerForTest _findHostTimeout{"defaultFindReplicaSetHostTimeoutMS", 10};
+};
+
+#ifdef MONGO_CONFIG_SSL
+
+void waitForReadyRequest(executor::NetworkInterfaceMock* net) {
+ while (!net->hasReadyRequests()) {
+ net->advanceTime(net->now() + Milliseconds{1});
+ }
+}
+
+BSONObj createEmptyCursorResponse(const NamespaceString& nss, CursorId backupCursorId) {
+ return BSON(
+ "cursor" << BSON("nextBatch" << BSONArray() << "id" << backupCursorId << "ns" << nss.ns())
+ << "ok" << 1.0);
+}
+
+BSONObj createBackupCursorResponse(const Timestamp& checkpointTimestamp,
+ const NamespaceString& nss,
+ CursorId backupCursorId) {
+ const UUID backupId =
+ UUID(uassertStatusOK(UUID::parse(("2b068e03-5961-4d8e-b47a-d1c8cbd4b835"))));
+ StringData remoteDbPath = "/data/db/job0/mongorunner/test-1";
+ BSONObjBuilder cursor;
+ BSONArrayBuilder batch(cursor.subarrayStart("firstBatch"));
+ auto metaData = BSON("backupId" << backupId << "checkpointTimestamp" << checkpointTimestamp
+ << "dbpath" << remoteDbPath);
+ batch.append(BSON("metadata" << metaData));
+
+ batch.done();
+ cursor.append("id", backupCursorId);
+ cursor.append("ns", nss.ns());
+ BSONObjBuilder backupCursorReply;
+ backupCursorReply.append("cursor", cursor.obj());
+ backupCursorReply.append("ok", 1.0);
+ return backupCursorReply.obj();
+}
+
+void sendReponseToExpectedRequest(const BSONObj& backupCursorResponse,
+ const std::string& expectedRequestFieldName,
+ executor::NetworkInterfaceMock* net) {
+ auto noi = net->getNextReadyRequest();
+ auto request = noi->getRequest();
+ ASSERT_EQUALS(expectedRequestFieldName, request.cmdObj.firstElementFieldNameStringData());
+ net->scheduleSuccessfulResponse(
+ noi, executor::RemoteCommandResponse(backupCursorResponse, Milliseconds()));
+ net->runReadyNetworkOperations();
+}
+
+BSONObj createServerAggregateReply() {
+ return CursorResponse(
+ NamespaceString::makeCollectionlessAggregateNSS(NamespaceString::kAdminDb),
+ 0 /* cursorId */,
+ {BSON("byteOffset" << 0 << "endOfFile" << true << "data"
+ << BSONBinData(0, 0, BinDataGeneral))})
+ .toBSONAsInitialResponse();
+}
+
+TEST_F(TenantMigrationRecipientServiceShardMergeTest, OpenBackupCursorSuccessfully) {
+ stopFailPointEnableBlock fp("fpBeforeAdvancingStableTimestamp");
+ const UUID migrationUUID = UUID::gen();
+ const CursorId backupCursorId = 12345;
+ const NamespaceString aggregateNs = NamespaceString("admin.$cmd.aggregate");
+
+ auto taskFp = globalFailPointRegistry().find("hangBeforeTaskCompletion");
+ auto initialTimesEntered = taskFp->setMode(FailPoint::alwaysOn);
+
+ MockReplicaSet replSet("donorSet", 3, true /* hasPrimary */, true /* dollarPrefixHosts */);
+ getTopologyManager()->setTopologyDescription(replSet.getTopologyDescription(clock()));
+ insertTopOfOplog(&replSet, OpTime(Timestamp(5, 1), 1));
+
+ // Mock the aggregate response from the donor.
+ MockRemoteDBServer* const _donorServer =
+ mongo::MockConnRegistry::get()->getMockRemoteDBServer(replSet.getPrimary());
+ _donorServer->setCommandReply("aggregate", createServerAggregateReply());
+
+ TenantMigrationRecipientDocument initialStateDocument(
+ migrationUUID,
+ replSet.getConnectionString(),
+ "tenantA",
+ kDefaultStartMigrationTimestamp,
+ ReadPreferenceSetting(ReadPreference::PrimaryOnly));
+ initialStateDocument.setProtocol(MigrationProtocolEnum::kShardMerge);
+ initialStateDocument.setRecipientCertificateForDonor(kRecipientPEMPayload);
+
+ auto opCtx = makeOperationContext();
+ std::shared_ptr<TenantMigrationRecipientService::Instance> instance;
+ {
+ auto fp = globalFailPointRegistry().find("pauseBeforeRunTenantMigrationRecipientInstance");
+ auto initialTimesEntered = fp->setMode(FailPoint::alwaysOn);
+ instance = TenantMigrationRecipientService::Instance::getOrCreate(
+ opCtx.get(), _service, initialStateDocument.toBSON());
+ ASSERT(instance.get());
+ fp->waitForTimesEntered(initialTimesEntered + 1);
+ setInstanceBackupCursorFetcherExecutor(instance);
+ instance->setCreateOplogFetcherFn_forTest(std::make_unique<CreateOplogFetcherMockFn>());
+ fp->setMode(FailPoint::off);
+ }
+
+ {
+ auto net = getNet();
+ executor::NetworkInterfaceMock::InNetworkGuard guard(net);
+ waitForReadyRequest(net);
+ // Mocking the aggregate command network response of the backup cursor in order to have
+ // data to parse.
+ sendReponseToExpectedRequest(createBackupCursorResponse(kDefaultStartMigrationTimestamp,
+ aggregateNs,
+ backupCursorId),
+ "aggregate",
+ net);
+ sendReponseToExpectedRequest(
+ createEmptyCursorResponse(aggregateNs, backupCursorId), "getMore", net);
+ sendReponseToExpectedRequest(
+ createEmptyCursorResponse(aggregateNs, backupCursorId), "getMore", net);
+ }
+
+ taskFp->waitForTimesEntered(initialTimesEntered + 1);
+
+ checkStateDocPersisted(opCtx.get(), instance.get());
+
+ taskFp->setMode(FailPoint::off);
+
+ ASSERT_EQ(stopFailPointErrorCode, instance->getDataSyncCompletionFuture().getNoThrow().code());
+ ASSERT_OK(instance->getForgetMigrationDurableFuture().getNoThrow());
+}
+
+TEST_F(TenantMigrationRecipientServiceShardMergeTest, OpenBackupCursorAndRetriesDueToTs) {
+ stopFailPointEnableBlock fp("fpBeforeAdvancingStableTimestamp");
+ const UUID migrationUUID = UUID::gen();
+ const CursorId backupCursorId = 12345;
+ const NamespaceString aggregateNs = NamespaceString("admin.$cmd.aggregate");
+
+ auto taskFp = globalFailPointRegistry().find("hangBeforeTaskCompletion");
+ auto initialTimesEntered = taskFp->setMode(FailPoint::alwaysOn);
+
+ MockReplicaSet replSet("donorSet", 3, true /* hasPrimary */, true /* dollarPrefixHosts */);
+ getTopologyManager()->setTopologyDescription(replSet.getTopologyDescription(clock()));
+ insertTopOfOplog(&replSet, OpTime(Timestamp(5, 1), 1));
+
+ // Mock the aggregate response from the donor.
+ MockRemoteDBServer* const _donorServer =
+ mongo::MockConnRegistry::get()->getMockRemoteDBServer(replSet.getPrimary());
+ _donorServer->setCommandReply("aggregate", createServerAggregateReply());
+
+ TenantMigrationRecipientDocument initialStateDocument(
+ migrationUUID,
+ replSet.getConnectionString(),
+ "tenantA",
+ kDefaultStartMigrationTimestamp,
+ ReadPreferenceSetting(ReadPreference::PrimaryOnly));
+ initialStateDocument.setProtocol(MigrationProtocolEnum::kShardMerge);
+ initialStateDocument.setRecipientCertificateForDonor(kRecipientPEMPayload);
+
+ auto opCtx = makeOperationContext();
+ std::shared_ptr<TenantMigrationRecipientService::Instance> instance;
+ {
+ auto fp = globalFailPointRegistry().find("pauseBeforeRunTenantMigrationRecipientInstance");
+ auto initialTimesEntered = fp->setMode(FailPoint::alwaysOn);
+ instance = TenantMigrationRecipientService::Instance::getOrCreate(
+ opCtx.get(), _service, initialStateDocument.toBSON());
+ ASSERT(instance.get());
+ fp->waitForTimesEntered(initialTimesEntered + 1);
+ setInstanceBackupCursorFetcherExecutor(instance);
+ instance->setCreateOplogFetcherFn_forTest(std::make_unique<CreateOplogFetcherMockFn>());
+ fp->setMode(FailPoint::off);
+ }
+
+ {
+ auto net = getNet();
+ executor::NetworkInterfaceMock::InNetworkGuard guard(net);
+ waitForReadyRequest(net);
+
+ // Mocking the aggregate command network response of the backup cursor in order to have data
+        // to parse. In this case we pass a timestamp that is earlier than the
+        // startMigrationTimestamp, which will cause a retry. We then provide a correct timestamp in
+ // the next response and succeed.
+ sendReponseToExpectedRequest(
+ createBackupCursorResponse(Timestamp(0, 0), aggregateNs, backupCursorId),
+ "aggregate",
+ net);
+ sendReponseToExpectedRequest(createBackupCursorResponse(kDefaultStartMigrationTimestamp,
+ aggregateNs,
+ backupCursorId),
+ "killCursors",
+ net);
+ sendReponseToExpectedRequest(
+ createEmptyCursorResponse(aggregateNs, backupCursorId), "killCursors", net);
+ sendReponseToExpectedRequest(createBackupCursorResponse(kDefaultStartMigrationTimestamp,
+ aggregateNs,
+ backupCursorId),
+ "aggregate",
+ net);
+ sendReponseToExpectedRequest(
+ createEmptyCursorResponse(aggregateNs, backupCursorId), "getMore", net);
+ sendReponseToExpectedRequest(
+ createEmptyCursorResponse(aggregateNs, backupCursorId), "getMore", net);
+ }
+
+ taskFp->waitForTimesEntered(initialTimesEntered + 1);
+
+ checkStateDocPersisted(opCtx.get(), instance.get());
+
+ taskFp->setMode(FailPoint::off);
+
+ ASSERT_EQ(stopFailPointErrorCode, instance->getDataSyncCompletionFuture().getNoThrow().code());
+ ASSERT_OK(instance->getForgetMigrationDurableFuture().getNoThrow());
+}
+
+#endif
+} // namespace repl
+} // namespace mongo
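Note: the new shard-merge unit tests drive the backup cursor path entirely over a NetworkInterfaceMock; every donor request is answered by hand via the helpers above. A minimal sketch of answering a single mocked request, composed from those helpers ("aggregate" and the reply object are placeholders for a concrete case):

    // Sketch of answering one mocked donor request inside a test body.
    void answerOneAggregateRequest(executor::NetworkInterfaceMock* net, const BSONObj& reply) {
        executor::NetworkInterfaceMock::InNetworkGuard guard(net);  // enter the mock network
        while (!net->hasReadyRequests()) {
            net->advanceTime(net->now() + Milliseconds{1});  // let scheduled work become ready
        }
        auto noi = net->getNextReadyRequest();
        ASSERT_EQUALS("aggregate", noi->getRequest().cmdObj.firstElementFieldNameStringData());
        net->scheduleSuccessfulResponse(
            noi, executor::RemoteCommandResponse(reply, Milliseconds()));
        net->runReadyNetworkOperations();  // deliver the response to the waiting fetcher
    }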
diff --git a/src/mongo/db/s/SConscript b/src/mongo/db/s/SConscript
index fa54218910d..3248c1b296d 100644
--- a/src/mongo/db/s/SConscript
+++ b/src/mongo/db/s/SConscript
@@ -11,7 +11,6 @@ env = env.Clone()
env.Library(
target='sharding_api_d',
source=[
- 'balancer_stats_registry.cpp',
'collection_metadata.cpp',
'collection_sharding_state_factory_standalone.cpp',
'collection_sharding_state.cpp',
@@ -36,14 +35,25 @@ env.Library(
LIBDEPS_PRIVATE=[
'$BUILD_DIR/mongo/db/catalog/index_catalog',
'$BUILD_DIR/mongo/db/concurrency/lock_manager',
- '$BUILD_DIR/mongo/db/dbdirectclient',
- '$BUILD_DIR/mongo/db/repl/replica_set_aware_service',
'$BUILD_DIR/mongo/db/server_base',
'$BUILD_DIR/mongo/db/write_block_bypass',
],
)
env.Library(
+ target='balancer_stats_registry',
+ source=[
+ 'balancer_stats_registry.cpp',
+ ],
+ LIBDEPS_PRIVATE=[
+ '$BUILD_DIR/mongo/db/dbdirectclient',
+ '$BUILD_DIR/mongo/db/repl/replica_set_aware_service',
+ '$BUILD_DIR/mongo/db/shard_role',
+ '$BUILD_DIR/mongo/s/grid',
+ ],
+)
+
+env.Library(
target='sharding_catalog',
source=[
'global_index_ddl_util.cpp',
@@ -64,6 +74,22 @@ env.Library(
)
env.Library(
+ target="query_analysis_writer",
+ source=[
+ "query_analysis_writer.cpp",
+ ],
+ LIBDEPS_PRIVATE=[
+ "$BUILD_DIR/mongo/db/dbdirectclient",
+ '$BUILD_DIR/mongo/db/ops/write_ops_parsers',
+ '$BUILD_DIR/mongo/db/server_base',
+ '$BUILD_DIR/mongo/db/service_context',
+ '$BUILD_DIR/mongo/db/shard_role',
+ '$BUILD_DIR/mongo/idl/idl_parser',
+ '$BUILD_DIR/mongo/s/analyze_shard_key_common',
+ ],
+)
+
+env.Library(
target='sharding_runtime_d',
source=[
'active_migrations_registry.cpp',
@@ -228,6 +254,8 @@ env.Library(
'$BUILD_DIR/mongo/db/transaction/transaction_operations',
'$BUILD_DIR/mongo/s/common_s',
'$BUILD_DIR/mongo/util/future_util',
+ 'balancer_stats_registry',
+ 'query_analysis_writer',
'sharding_catalog',
'sharding_logging',
],
@@ -550,6 +578,7 @@ env.Library(
'$BUILD_DIR/mongo/s/sharding_initialization',
'$BUILD_DIR/mongo/s/sharding_router_api',
'$BUILD_DIR/mongo/s/startup_initialization',
+ 'balancer_stats_registry',
'forwardable_operation_metadata',
'sharding_catalog',
'sharding_logging',
@@ -653,6 +682,7 @@ env.CppUnitTest(
'op_observer_sharding_test.cpp',
'operation_sharding_state_test.cpp',
'persistent_task_queue_test.cpp',
+ 'query_analysis_writer_test.cpp',
'range_deleter_service_test.cpp',
'range_deleter_service_test_util.cpp',
'range_deleter_service_op_observer_test.cpp',
@@ -734,6 +764,7 @@ env.CppUnitTest(
'$BUILD_DIR/mongo/executor/thread_pool_task_executor_test_fixture',
'$BUILD_DIR/mongo/s/catalog/sharding_catalog_client_mock',
'$BUILD_DIR/mongo/s/sharding_router_test_fixture',
+ 'query_analysis_writer',
'shard_server_test_fixture',
'sharding_catalog',
'sharding_commands_d',
diff --git a/src/mongo/db/s/balancer/balancer.cpp b/src/mongo/db/s/balancer/balancer.cpp
index f1c96cc97cf..9103ec7f04c 100644
--- a/src/mongo/db/s/balancer/balancer.cpp
+++ b/src/mongo/db/s/balancer/balancer.cpp
@@ -281,7 +281,7 @@ Balancer::Balancer()
std::make_unique<BalancerChunkSelectionPolicyImpl>(_clusterStats.get(), _random)),
_commandScheduler(std::make_unique<BalancerCommandsSchedulerImpl>()),
_defragmentationPolicy(std::make_unique<BalancerDefragmentationPolicyImpl>(
- _clusterStats.get(), _random, [this]() { _onActionsStreamPolicyStateUpdate(); })),
+ _clusterStats.get(), [this]() { _onActionsStreamPolicyStateUpdate(); })),
_clusterChunksResizePolicy(std::make_unique<ClusterChunksResizePolicyImpl>(
[this] { _onActionsStreamPolicyStateUpdate(); })) {}
@@ -1085,7 +1085,7 @@ int Balancer::_moveChunks(OperationContext* opCtx,
opCtx, migrateInfo.uuid, repl::ReadConcernLevel::kMajorityReadConcern);
ShardingCatalogManager::get(opCtx)->splitOrMarkJumbo(
- opCtx, collection.getNss(), migrateInfo.minKey);
+ opCtx, collection.getNss(), migrateInfo.minKey, migrateInfo.getMaxChunkSizeBytes());
continue;
}
@@ -1151,7 +1151,12 @@ BalancerCollectionStatusResponse Balancer::getBalancerStatusForNs(OperationConte
uasserted(ErrorCodes::NamespaceNotSharded, "Collection unsharded or undefined");
}
- const auto maxChunkSizeMB = getMaxChunkSizeMB(opCtx, coll);
+
+ const auto maxChunkSizeBytes = getMaxChunkSizeBytes(opCtx, coll);
+ double maxChunkSizeMB = (double)maxChunkSizeBytes / (1024 * 1024);
+    // Round up to two decimal places to return a readable value
+ maxChunkSizeMB = std::ceil(maxChunkSizeMB * 100.0) / 100.0;
+
BalancerCollectionStatusResponse response(maxChunkSizeMB, true /*balancerCompliant*/);
auto setViolationOnResponse = [&response](const StringData& reason,
const boost::optional<BSONObj>& details =
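Note: getBalancerStatusForNs now reports the max chunk size as a fractional number of MB instead of a truncated integer; std::ceil keeps two decimal places and rounds up. A worked example of the conversion (standalone, illustrative values):

    #include <cmath>
    #include <cstdint>
    #include <iostream>

    int main() {
        const std::uint64_t maxChunkSizeBytes = 100'000'000;
        double maxChunkSizeMB = static_cast<double>(maxChunkSizeBytes) / (1024 * 1024);  // 95.3674...
        maxChunkSizeMB = std::ceil(maxChunkSizeMB * 100.0) / 100.0;                      // 95.37
        std::cout << maxChunkSizeMB << " MB\n";  // prints "95.37 MB"
        return 0;
    }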
diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp
index 1c5a88a6100..c8d51184bc3 100644
--- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp
+++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp
@@ -398,7 +398,8 @@ public:
std::move(collectionChunks),
std::move(shardInfos),
std::move(collectionZones),
- smallChunkSizeThresholdBytes));
+ smallChunkSizeThresholdBytes,
+ maxChunkSizeBytes));
}
DefragmentationPhaseEnum getType() const override {
@@ -462,7 +463,7 @@ public:
auto smallChunkVersion = getShardVersion(opCtx, nextSmallChunk->shard, _nss);
_outstandingMigrations.emplace_back(nextSmallChunk, targetSibling);
return _outstandingMigrations.back().asMigrateInfo(
- _uuid, _nss, smallChunkVersion.placementVersion());
+ _uuid, _nss, smallChunkVersion.placementVersion(), _maxChunkSizeBytes);
}
return boost::none;
@@ -701,7 +702,8 @@ private:
MigrateInfo asMigrateInfo(const UUID& collUuid,
const NamespaceString& nss,
- const ChunkVersion& version) const {
+ const ChunkVersion& version,
+ uint64_t maxChunkSizeBytes) const {
return MigrateInfo(chunkToMergeWith->shard,
chunkToMove->shard,
nss,
@@ -709,7 +711,8 @@ private:
chunkToMove->range.getMin(),
chunkToMove->range.getMax(),
version,
- ForceJumbo::kForceBalancer);
+ ForceJumbo::kForceBalancer,
+ maxChunkSizeBytes);
}
ChunkRange asMergedRange() const {
@@ -774,6 +777,8 @@ private:
const int64_t _smallChunkSizeThresholdBytes;
+ const uint64_t _maxChunkSizeBytes;
+
bool _aborted{false};
DefragmentationPhaseEnum _nextPhase{DefragmentationPhaseEnum::kMergeChunks};
@@ -783,7 +788,8 @@ private:
std::vector<ChunkType>&& collectionChunks,
stdx::unordered_map<ShardId, ShardInfo>&& shardInfos,
ZoneInfo&& collectionZones,
- uint64_t smallChunkSizeThresholdBytes)
+ uint64_t smallChunkSizeThresholdBytes,
+ uint64_t maxChunkSizeBytes)
: _nss(nss),
_uuid(uuid),
_collectionChunks(),
@@ -794,7 +800,8 @@ private:
_actionableMerges(),
_outstandingMerges(),
_zoneInfo(std::move(collectionZones)),
- _smallChunkSizeThresholdBytes(smallChunkSizeThresholdBytes) {
+ _smallChunkSizeThresholdBytes(smallChunkSizeThresholdBytes),
+ _maxChunkSizeBytes(maxChunkSizeBytes) {
// Load the collection routing table in a std::list to ease later manipulation
for (auto&& chunk : collectionChunks) {
diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h
index bc41346ca7f..bab4a6c58cf 100644
--- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h
+++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h
@@ -72,9 +72,10 @@ class BalancerDefragmentationPolicyImpl : public BalancerDefragmentationPolicy {
public:
BalancerDefragmentationPolicyImpl(ClusterStatistics* clusterStats,
- BalancerRandomSource& random,
const std::function<void()>& onStateUpdated)
- : _clusterStats(clusterStats), _random(random), _onStateUpdated(onStateUpdated) {}
+ : _clusterStats(clusterStats),
+ _random(std::random_device{}()),
+ _onStateUpdated(onStateUpdated) {}
~BalancerDefragmentationPolicyImpl() {}
@@ -145,7 +146,7 @@ private:
ClusterStatistics* const _clusterStats;
- BalancerRandomSource& _random;
+ BalancerRandomSource _random;
const std::function<void()> _onStateUpdated;
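Note: the defragmentation policy now owns its random source, seeding it from std::random_device at construction, instead of borrowing the balancer's engine by reference. A minimal sketch of the ownership change, assuming BalancerRandomSource is a standard random engine (std::mt19937_64 is used here as an illustrative stand-in):

    #include <random>

    using RandomSource = std::mt19937_64;  // stand-in for BalancerRandomSource

    class DefragmentationPolicySketch {
    public:
        // Before: the constructor took RandomSource& and stored a reference, so the policy
        // depended on the caller's engine outliving it. After: it seeds and owns its own engine.
        DefragmentationPolicySketch() : _random(std::random_device{}()) {}

    private:
        RandomSource _random;
    };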
diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp b/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp
index 8646b5d965a..e9370bbb4d3 100644
--- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp
+++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp
@@ -29,7 +29,6 @@
#include "mongo/db/dbdirectclient.h"
#include "mongo/db/s/balancer/balancer_defragmentation_policy_impl.h"
-#include "mongo/db/s/balancer/balancer_random.h"
#include "mongo/db/s/balancer/cluster_statistics_mock.h"
#include "mongo/db/s/config/config_server_test_fixture.h"
#include "mongo/idl/server_parameter_test_util.h"
@@ -75,9 +74,7 @@ protected:
const std::function<void()> onDefragmentationStateUpdated = [] {};
BalancerDefragmentationPolicyTest()
- : _clusterStats(),
- _random(std::random_device{}()),
- _defragmentationPolicy(&_clusterStats, _random, onDefragmentationStateUpdated) {}
+ : _clusterStats(), _defragmentationPolicy(&_clusterStats, onDefragmentationStateUpdated) {}
CollectionType setupCollectionWithPhase(
const std::vector<ChunkType>& chunkList,
@@ -137,7 +134,6 @@ protected:
}
ClusterStatisticsMock _clusterStats;
- BalancerRandomSource _random;
BalancerDefragmentationPolicyImpl _defragmentationPolicy;
ShardStatistics buildShardStats(ShardId id,
diff --git a/src/mongo/db/s/balancer/balancer_policy.cpp b/src/mongo/db/s/balancer/balancer_policy.cpp
index 4e0ff4c3ccb..a8ad96ab2c7 100644
--- a/src/mongo/db/s/balancer/balancer_policy.cpp
+++ b/src/mongo/db/s/balancer/balancer_policy.cpp
@@ -36,6 +36,7 @@
#include "mongo/db/s/balancer/type_migration.h"
#include "mongo/logv2/log.h"
+#include "mongo/s/balancer_configuration.h"
#include "mongo/s/catalog/type_shard.h"
#include "mongo/s/catalog/type_tags.h"
#include "mongo/s/grid.h"
@@ -450,7 +451,16 @@ MigrateInfosWithReason BalancerPolicy::balance(
}
invariant(to != stat.shardId);
- migrations.emplace_back(to, distribution.nss(), chunk, ForceJumbo::kForceBalancer);
+
+ auto maxChunkSizeBytes = [&]() -> boost::optional<int64_t> {
+ if (collDataSizeInfo.has_value()) {
+ return collDataSizeInfo->maxChunkSizeBytes;
+ }
+ return boost::none;
+ }();
+
+ migrations.emplace_back(
+ to, distribution.nss(), chunk, ForceJumbo::kForceBalancer, maxChunkSizeBytes);
if (firstReason == MigrationReason::none) {
firstReason = MigrationReason::drain;
}
@@ -513,11 +523,20 @@ MigrateInfosWithReason BalancerPolicy::balance(
}
invariant(to != stat.shardId);
+
+ auto maxChunkSizeBytes = [&]() -> boost::optional<int64_t> {
+ if (collDataSizeInfo.has_value()) {
+ return collDataSizeInfo->maxChunkSizeBytes;
+ }
+ return boost::none;
+ }();
+
migrations.emplace_back(to,
distribution.nss(),
chunk,
forceJumbo ? ForceJumbo::kForceBalancer
- : ForceJumbo::kDoNotForce);
+ : ForceJumbo::kDoNotForce,
+ maxChunkSizeBytes);
if (firstReason == MigrationReason::none) {
firstReason = MigrationReason::zoneViolation;
}
@@ -796,7 +815,8 @@ string ZoneRange::toString() const {
MigrateInfo::MigrateInfo(const ShardId& a_to,
const NamespaceString& a_nss,
const ChunkType& a_chunk,
- const ForceJumbo a_forceJumbo)
+ const ForceJumbo a_forceJumbo,
+ boost::optional<int64_t> maxChunkSizeBytes)
: nss(a_nss), uuid(a_chunk.getCollectionUUID()) {
invariant(a_to.isValid());
@@ -807,6 +827,7 @@ MigrateInfo::MigrateInfo(const ShardId& a_to,
maxKey = a_chunk.getMax();
version = a_chunk.getVersion();
forceJumbo = a_forceJumbo;
+ optMaxChunkSizeBytes = maxChunkSizeBytes;
}
MigrateInfo::MigrateInfo(const ShardId& a_to,
@@ -858,6 +879,10 @@ string MigrateInfo::toString() const {
<< ", to " << to;
}
+boost::optional<int64_t> MigrateInfo::getMaxChunkSizeBytes() const {
+ return optMaxChunkSizeBytes;
+}
+
SplitInfo::SplitInfo(const ShardId& inShardId,
const NamespaceString& inNss,
const ChunkVersion& inCollectionVersion,
diff --git a/src/mongo/db/s/balancer/balancer_policy.h b/src/mongo/db/s/balancer/balancer_policy.h
index bd464ca9499..a9a1f307310 100644
--- a/src/mongo/db/s/balancer/balancer_policy.h
+++ b/src/mongo/db/s/balancer/balancer_policy.h
@@ -60,7 +60,8 @@ struct MigrateInfo {
MigrateInfo(const ShardId& a_to,
const NamespaceString& a_nss,
const ChunkType& a_chunk,
- ForceJumbo a_forceJumbo);
+ ForceJumbo a_forceJumbo,
+ boost::optional<int64_t> maxChunkSizeBytes = boost::none);
MigrateInfo(const ShardId& a_to,
const ShardId& a_from,
@@ -78,6 +79,8 @@ struct MigrateInfo {
std::string toString() const;
+ boost::optional<int64_t> getMaxChunkSizeBytes() const;
+
NamespaceString nss;
UUID uuid;
ShardId to;
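
MigrateInfo now optionally carries a per-migration maximum chunk size, defaulting to boost::none; when it is absent, callers resolve the value from a collection-level setting or the balancer-wide default (see the splitOrMarkJumbo change later in this diff). A hedged standalone sketch of that fallback chain, with std::optional standing in for boost::optional and illustrative names:

    #include <cstdint>
    #include <optional>

    // Hypothetical per-collection settings holder.
    struct CollectionSettingsSketch {
        std::optional<int64_t> maxChunkSizeBytes;  // per-collection override, if configured
    };

    // Illustrative only: resolve the effective max chunk size -- explicit per-migration
    // value first, then the collection setting, then the balancer-wide default.
    int64_t resolveMaxChunkSizeBytes(std::optional<int64_t> optMaxChunkSizeBytes,
                                     const CollectionSettingsSketch& coll,
                                     int64_t balancerDefaultBytes) {
        if (optMaxChunkSizeBytes) {
            return *optMaxChunkSizeBytes;
        }
        return coll.maxChunkSizeBytes.value_or(balancerDefaultBytes);
    }
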
diff --git a/src/mongo/db/s/balancer_stats_registry.cpp b/src/mongo/db/s/balancer_stats_registry.cpp
index bfd789acb52..0f11a578c23 100644
--- a/src/mongo/db/s/balancer_stats_registry.cpp
+++ b/src/mongo/db/s/balancer_stats_registry.cpp
@@ -29,6 +29,7 @@
#include "mongo/db/s/balancer_stats_registry.h"
+#include "mongo/db/catalog_raii.h"
#include "mongo/db/dbdirectclient.h"
#include "mongo/db/pipeline/aggregate_command_gen.h"
#include "mongo/db/repl/replication_coordinator.h"
@@ -58,23 +59,6 @@ ThreadPool::Options makeDefaultThreadPoolOptions() {
}
} // namespace
-ScopedRangeDeleterLock::ScopedRangeDeleterLock(OperationContext* opCtx)
- // TODO SERVER-62491 Use system tenantId for DBLock
- : _configLock(opCtx, DatabaseName(boost::none, NamespaceString::kConfigDb), MODE_IX),
- _rangeDeletionLock(opCtx, NamespaceString::kRangeDeletionNamespace, MODE_X) {}
-
-// Take DB and Collection lock in mode IX as well as collection UUID lock to serialize with
-// operations that take the above version of the ScopedRangeDeleterLock such as FCV downgrade and
-// BalancerStatsRegistry initialization.
-ScopedRangeDeleterLock::ScopedRangeDeleterLock(OperationContext* opCtx, const UUID& collectionUuid)
- // TODO SERVER-62491 Use system tenantId for DBLock
- : _configLock(opCtx, DatabaseName(boost::none, NamespaceString::kConfigDb), MODE_IX),
- _rangeDeletionLock(opCtx, NamespaceString::kRangeDeletionNamespace, MODE_IX),
- _collectionUuidLock(Lock::ResourceLock(
- opCtx,
- ResourceId(RESOURCE_MUTEX, "RangeDeleterCollLock::" + collectionUuid.toString()),
- MODE_X)) {}
-
const ReplicaSetAwareServiceRegistry::Registerer<BalancerStatsRegistry>
balancerStatsRegistryRegisterer("BalancerStatsRegistry");
@@ -131,9 +115,13 @@ void BalancerStatsRegistry::initializeAsync(OperationContext* opCtx) {
LOGV2_DEBUG(6419601, 2, "Initializing BalancerStatsRegistry");
try {
- // Lock the range deleter to prevent
- // concurrent modifications of orphans count
- ScopedRangeDeleterLock rangeDeleterLock(opCtx);
+ // Lock the range deleter to prevent concurrent modifications of orphans count
+ ScopedRangeDeleterLock rangeDeleterLock(opCtx, LockMode::MODE_S);
+ // The collection lock is needed to serialize with direct writes to
+ // config.rangeDeletions
+ AutoGetCollection rangeDeletionLock(
+ opCtx, NamespaceString::kRangeDeletionNamespace, MODE_S);
+
// Load current ophans count from disk
_loadOrphansCount(opCtx);
LOGV2_DEBUG(6419602, 2, "Completed BalancerStatsRegistry initialization");
diff --git a/src/mongo/db/s/balancer_stats_registry.h b/src/mongo/db/s/balancer_stats_registry.h
index 473285b627c..fdc68ab1021 100644
--- a/src/mongo/db/s/balancer_stats_registry.h
+++ b/src/mongo/db/s/balancer_stats_registry.h
@@ -38,18 +38,19 @@
namespace mongo {
/**
- * Acquires the config db lock in IX mode and the collection lock for config.rangeDeletions in X
- * mode.
+ * Scoped lock to synchronize with the execution of range deletions.
+ * The range-deleter acquires a scoped lock in IX mode while orphans are being deleted.
+ * Acquiring the scoped lock in MODE_X therefore ensures that no orphan counter in the
+ * `config.rangeDeletions` entries is updated concurrently.
*/
class ScopedRangeDeleterLock {
public:
- ScopedRangeDeleterLock(OperationContext* opCtx);
- ScopedRangeDeleterLock(OperationContext* opCtx, const UUID& collectionUuid);
+ ScopedRangeDeleterLock(OperationContext* opCtx, LockMode mode)
+ : _resourceLock(opCtx, _mutex.getRid(), mode) {}
private:
- Lock::DBLock _configLock;
- Lock::CollectionLock _rangeDeletionLock;
- boost::optional<Lock::ResourceLock> _collectionUuidLock;
+ const Lock::ResourceLock _resourceLock;
+ static inline const Lock::ResourceMutex _mutex{"ScopedRangeDeleterLock"};
};
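
The rewritten ScopedRangeDeleterLock reduces the old DB/collection/UUID lock combination to a single ResourceMutex: each range deletion holds it in an intent mode while bumping its own orphan counter, and a consumer that needs a stable count acquires it in a stronger mode. A rough standalone analogy using std::shared_mutex (only an analogy; MODE_IX, MODE_S and MODE_X have richer compatibility rules than shared/exclusive):

    #include <mutex>
    #include <shared_mutex>
    #include <unordered_map>

    class OrphanCountsSketch {
    public:
        // Range-deleter path: many deletions may update their own counters concurrently.
        void recordDeletedDocs(int taskId, long long numDocs) {
            std::shared_lock lk(_rangeDeleterMutex);  // analogue of the IX acquisition
            std::lock_guard mapLk(_mapMutex);         // protects the map itself
            _counts[taskId] -= numDocs;
        }

        // Initialization path: excludes all updaters to read a stable snapshot.
        long long totalOrphans() const {
            std::unique_lock lk(_rangeDeleterMutex);  // analogue of the MODE_S/MODE_X acquisition
            long long total = 0;
            for (const auto& entry : _counts) {
                total += entry.second;
            }
            return total;
        }

    private:
        mutable std::shared_mutex _rangeDeleterMutex;
        std::mutex _mapMutex;
        std::unordered_map<int, long long> _counts;
    };
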
/**
diff --git a/src/mongo/db/s/cluster_count_cmd_d.cpp b/src/mongo/db/s/cluster_count_cmd_d.cpp
index 27c64148cc9..a593a299b4f 100644
--- a/src/mongo/db/s/cluster_count_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_count_cmd_d.cpp
@@ -62,6 +62,11 @@ struct ClusterCountCmdD {
// which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster count command on a mongod");
+ }
};
ClusterCountCmdBase<ClusterCountCmdD> clusterCountCmdD;
diff --git a/src/mongo/db/s/cluster_find_cmd_d.cpp b/src/mongo/db/s/cluster_find_cmd_d.cpp
index e97f07b70e7..e98315c06ad 100644
--- a/src/mongo/db/s/cluster_find_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_find_cmd_d.cpp
@@ -61,6 +61,11 @@ struct ClusterFindCmdD {
// which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster find command on a mongod");
+ }
};
ClusterFindCmdBase<ClusterFindCmdD> clusterFindCmdD;
diff --git a/src/mongo/db/s/cluster_pipeline_cmd_d.cpp b/src/mongo/db/s/cluster_pipeline_cmd_d.cpp
index 437e7418355..1b3eed4242a 100644
--- a/src/mongo/db/s/cluster_pipeline_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_pipeline_cmd_d.cpp
@@ -61,6 +61,11 @@ struct ClusterPipelineCommandD {
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster aggregate command on a mongod");
+ }
+
static AggregateCommandRequest parseAggregationRequest(
OperationContext* opCtx,
const OpMsgRequest& opMsgRequest,
diff --git a/src/mongo/db/s/cluster_write_cmd_d.cpp b/src/mongo/db/s/cluster_write_cmd_d.cpp
index c8f75d69e15..66167ab2b30 100644
--- a/src/mongo/db/s/cluster_write_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_write_cmd_d.cpp
@@ -57,6 +57,11 @@ struct ClusterInsertCmdD {
// which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster insert command on a mongod");
+ }
};
ClusterInsertCmdBase<ClusterInsertCmdD> clusterInsertCmdD;
@@ -83,6 +88,10 @@ struct ClusterUpdateCmdD {
// which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported, "Cannot explain a cluster update command on a mongod");
+ }
};
ClusterUpdateCmdBase<ClusterUpdateCmdD> clusterUpdateCmdD;
@@ -109,6 +118,11 @@ struct ClusterDeleteCmdD {
// which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
}
+
+ static void checkCanExplainHere(OperationContext* opCtx) {
+ uasserted(ErrorCodes::CommandNotSupported,
+ "Cannot explain a cluster delete command on a mongod");
+ }
};
ClusterDeleteCmdBase<ClusterDeleteCmdD> clusterDeleteCmdD;
diff --git a/src/mongo/db/s/collection_sharding_runtime.cpp b/src/mongo/db/s/collection_sharding_runtime.cpp
index 89ceb0cfd4c..ee530f6ea04 100644
--- a/src/mongo/db/s/collection_sharding_runtime.cpp
+++ b/src/mongo/db/s/collection_sharding_runtime.cpp
@@ -486,8 +486,10 @@ void CollectionShardingRuntime::resetShardVersionRecoverRefreshFuture() {
_shardVersionInRecoverOrRefresh = boost::none;
}
-boost::optional<Timestamp> CollectionShardingRuntime::getIndexVersion(OperationContext* opCtx) {
- return _globalIndexesInfo ? _globalIndexesInfo->getVersion() : boost::none;
+boost::optional<CollectionIndexes> CollectionShardingRuntime::getCollectionIndexes(
+ OperationContext* opCtx) {
+ return _globalIndexesInfo ? boost::make_optional(_globalIndexesInfo->getCollectionIndexes())
+ : boost::none;
}
boost::optional<GlobalIndexesCache>& CollectionShardingRuntime::getIndexes(
@@ -497,22 +499,22 @@ boost::optional<GlobalIndexesCache>& CollectionShardingRuntime::getIndexes(
void CollectionShardingRuntime::addIndex(OperationContext* opCtx,
const IndexCatalogType& index,
- const Timestamp& indexVersion) {
+ const CollectionIndexes& collectionIndexes) {
if (_globalIndexesInfo) {
- _globalIndexesInfo->add(index, indexVersion);
+ _globalIndexesInfo->add(index, collectionIndexes);
} else {
IndexCatalogTypeMap indexMap;
indexMap.emplace(index.getName(), index);
- _globalIndexesInfo.emplace(indexVersion, std::move(indexMap));
+ _globalIndexesInfo.emplace(collectionIndexes, std::move(indexMap));
}
}
void CollectionShardingRuntime::removeIndex(OperationContext* opCtx,
const std::string& name,
- const Timestamp& indexVersion) {
+ const CollectionIndexes& collectionIndexes) {
tassert(
7019500, "Index information does not exist on CSR", _globalIndexesInfo.is_initialized());
- _globalIndexesInfo->remove(name, indexVersion);
+ _globalIndexesInfo->remove(name, collectionIndexes);
}
void CollectionShardingRuntime::clearIndexes(OperationContext* opCtx) {
diff --git a/src/mongo/db/s/collection_sharding_runtime.h b/src/mongo/db/s/collection_sharding_runtime.h
index 44e46d0bc4a..786425d8e92 100644
--- a/src/mongo/db/s/collection_sharding_runtime.h
+++ b/src/mongo/db/s/collection_sharding_runtime.h
@@ -229,7 +229,7 @@ public:
/**
* Gets an index version under a lock.
*/
- boost::optional<Timestamp> getIndexVersion(OperationContext* opCtx);
+ boost::optional<CollectionIndexes> getCollectionIndexes(OperationContext* opCtx);
/**
* Gets the index list under a lock.
@@ -241,17 +241,17 @@ public:
*/
void addIndex(OperationContext* opCtx,
const IndexCatalogType& index,
- const Timestamp& indexVersion);
+ const CollectionIndexes& collectionIndexes);
/**
* Removes an index from the shard-role index info under a lock.
*/
void removeIndex(OperationContext* opCtx,
const std::string& name,
- const Timestamp& indexVersion);
+ const CollectionIndexes& collectionIndexes);
/**
- * Clears the shard-role index info and set the indexVersion to boost::none.
+ * Clears the shard-role index info and sets the collectionIndexes to boost::none.
*/
void clearIndexes(OperationContext* opCtx);
diff --git a/src/mongo/db/s/collection_sharding_runtime_test.cpp b/src/mongo/db/s/collection_sharding_runtime_test.cpp
index 6e792a8db0e..334225663e3 100644
--- a/src/mongo/db/s/collection_sharding_runtime_test.cpp
+++ b/src/mongo/db/s/collection_sharding_runtime_test.cpp
@@ -377,8 +377,10 @@ public:
AutoGetCollection autoColl(operationContext(), kTestNss, MODE_IX);
_uuid = autoColl.getCollection()->uuid();
- RangeDeleterService::get(operationContext())->onStepUpComplete(operationContext(), 0L);
- RangeDeleterService::get(operationContext())->_waitForRangeDeleterServiceUp_FOR_TESTING();
+ auto opCtx = operationContext();
+ RangeDeleterService::get(opCtx)->onStartup(opCtx);
+ RangeDeleterService::get(opCtx)->onStepUpComplete(opCtx, 0L);
+ RangeDeleterService::get(opCtx)->_waitForRangeDeleterServiceUp_FOR_TESTING();
}
void tearDown() override {
@@ -611,8 +613,7 @@ TEST_F(CollectionShardingRuntimeWithCatalogTest, TestGlobalIndexesCache) {
opCtx, kTestNss, "x_1", BSON("x" << 1), BSONObj(), uuid(), indexVersion, boost::none);
ASSERT_EQ(true, csr()->getIndexes(opCtx).is_initialized());
- ASSERT_EQ(indexVersion, *csr()->getIndexes(opCtx)->getVersion());
- ASSERT_EQ(indexVersion, *csr()->getIndexVersion(opCtx));
+ ASSERT_EQ(CollectionIndexes(uuid(), indexVersion), *csr()->getCollectionIndexes(opCtx));
}
} // namespace
} // namespace mongo
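
The hunks above replace the bare Timestamp index version with a CollectionIndexes value that, as the updated test shows, is constructed from the collection UUID plus the index version. A small sketch of one plausible motivation for carrying both pieces together (hypothetical stand-in types, not the server's CollectionIndexes): a bare timestamp comparison can silently mix versions from two incarnations of a namespace, while a combined value makes the collection identity part of the comparison.

    #include <string>
    #include <tuple>

    // Hypothetical stand-ins; the real types are mongo::UUID and mongo::Timestamp.
    using UuidSketch = std::string;

    struct TimestampSketch {
        unsigned long long secs = 0;
        unsigned int inc = 0;
    };

    inline bool operator==(const TimestampSketch& a, const TimestampSketch& b) {
        return std::tie(a.secs, a.inc) == std::tie(b.secs, b.inc);
    }

    // The identity travels with the version, so equality is only meaningful for the
    // same incarnation of the collection.
    struct CollectionIndexesSketch {
        UuidSketch uuid;
        TimestampSketch indexVersion;
    };

    inline bool operator==(const CollectionIndexesSketch& a, const CollectionIndexesSketch& b) {
        return a.uuid == b.uuid && a.indexVersion == b.indexVersion;
    }
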
diff --git a/src/mongo/db/s/config/sharding_catalog_manager.h b/src/mongo/db/s/config/sharding_catalog_manager.h
index 4e8a64f59a6..fba732dce7a 100644
--- a/src/mongo/db/s/config/sharding_catalog_manager.h
+++ b/src/mongo/db/s/config/sharding_catalog_manager.h
@@ -359,7 +359,8 @@ public:
*/
void splitOrMarkJumbo(OperationContext* opCtx,
const NamespaceString& nss,
- const BSONObj& minKey);
+ const BSONObj& minKey,
+ boost::optional<int64_t> optMaxChunkSizeBytes);
/**
* In a transaction, sets the 'allowMigrations' to the requested state and bumps the collection
diff --git a/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp b/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp
index 2b0692acb5d..93ac3b56099 100644
--- a/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp
+++ b/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp
@@ -1775,19 +1775,31 @@ void ShardingCatalogManager::bumpMultipleCollectionVersionsAndChangeMetadataInTx
void ShardingCatalogManager::splitOrMarkJumbo(OperationContext* opCtx,
const NamespaceString& nss,
- const BSONObj& minKey) {
+ const BSONObj& minKey,
+ boost::optional<int64_t> optMaxChunkSizeBytes) {
const auto cm = uassertStatusOK(
Grid::get(opCtx)->catalogCache()->getShardedCollectionRoutingInfoWithRefresh(opCtx, nss));
auto chunk = cm.findIntersectingChunkWithSimpleCollation(minKey);
try {
- const auto splitPoints = uassertStatusOK(shardutil::selectChunkSplitPoints(
- opCtx,
- chunk.getShardId(),
- nss,
- cm.getShardKeyPattern(),
- ChunkRange(chunk.getMin(), chunk.getMax()),
- Grid::get(opCtx)->getBalancerConfiguration()->getMaxChunkSizeBytes()));
+ const auto maxChunkSizeBytes = [&]() -> int64_t {
+ if (optMaxChunkSizeBytes.has_value()) {
+ return *optMaxChunkSizeBytes;
+ }
+
+ auto coll = Grid::get(opCtx)->catalogClient()->getCollection(
+ opCtx, nss, repl::ReadConcernLevel::kMajorityReadConcern);
+ return coll.getMaxChunkSizeBytes().value_or(
+ Grid::get(opCtx)->getBalancerConfiguration()->getMaxChunkSizeBytes());
+ }();
+
+ const auto splitPoints = uassertStatusOK(
+ shardutil::selectChunkSplitPoints(opCtx,
+ chunk.getShardId(),
+ nss,
+ cm.getShardKeyPattern(),
+ ChunkRange(chunk.getMin(), chunk.getMax()),
+ maxChunkSizeBytes));
if (splitPoints.empty()) {
LOGV2(21873,
diff --git a/src/mongo/db/s/query_analysis_op_observer.cpp b/src/mongo/db/s/query_analysis_op_observer.cpp
index 84606c7f4c7..658c50c992d 100644
--- a/src/mongo/db/s/query_analysis_op_observer.cpp
+++ b/src/mongo/db/s/query_analysis_op_observer.cpp
@@ -31,6 +31,7 @@
#include "mongo/db/s/query_analysis_coordinator.h"
#include "mongo/db/s/query_analysis_op_observer.h"
+#include "mongo/db/s/query_analysis_writer.h"
#include "mongo/logv2/log.h"
#include "mongo/s/analyze_shard_key_util.h"
#include "mongo/s/catalog/type_mongos.h"
@@ -88,6 +89,17 @@ void QueryAnalysisOpObserver::onUpdate(OperationContext* opCtx, const OplogUpdat
});
}
}
+
+ if (analyze_shard_key::supportsPersistingSampledQueries() && args.updateArgs->sampleId &&
+ args.updateArgs->preImageDoc && opCtx->writesAreReplicated()) {
+ analyze_shard_key::QueryAnalysisWriter::get(opCtx)
+ .addDiff(*args.updateArgs->sampleId,
+ args.coll->ns(),
+ args.coll->uuid(),
+ *args.updateArgs->preImageDoc,
+ args.updateArgs->updatedDoc)
+ .getAsync([](auto) {});
+ }
}
void QueryAnalysisOpObserver::aboutToDelete(OperationContext* opCtx,
diff --git a/src/mongo/db/s/query_analysis_writer.cpp b/src/mongo/db/s/query_analysis_writer.cpp
new file mode 100644
index 00000000000..19b942e1775
--- /dev/null
+++ b/src/mongo/db/s/query_analysis_writer.cpp
@@ -0,0 +1,695 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/s/query_analysis_writer.h"
+
+#include "mongo/client/connpool.h"
+#include "mongo/db/catalog/collection_catalog.h"
+#include "mongo/db/dbdirectclient.h"
+#include "mongo/db/ops/write_ops.h"
+#include "mongo/db/server_options.h"
+#include "mongo/db/update/document_diff_calculator.h"
+#include "mongo/executor/network_interface_factory.h"
+#include "mongo/executor/thread_pool_task_executor.h"
+#include "mongo/logv2/log.h"
+#include "mongo/s/analyze_shard_key_documents_gen.h"
+#include "mongo/s/analyze_shard_key_server_parameters_gen.h"
+#include "mongo/s/write_ops/batched_command_response.h"
+#include "mongo/util/concurrency/thread_pool.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kDefault
+
+namespace mongo {
+namespace analyze_shard_key {
+
+namespace {
+
+MONGO_FAIL_POINT_DEFINE(disableQueryAnalysisWriter);
+MONGO_FAIL_POINT_DEFINE(hangQueryAnalysisWriterBeforeWritingLocally);
+MONGO_FAIL_POINT_DEFINE(hangQueryAnalysisWriterBeforeWritingRemotely);
+
+const auto getQueryAnalysisWriter = ServiceContext::declareDecoration<QueryAnalysisWriter>();
+
+constexpr int kMaxRetriesOnRetryableErrors = 5;
+const WriteConcernOptions kMajorityWriteConcern{WriteConcernOptions::kMajority,
+ WriteConcernOptions::SyncMode::UNSET,
+ WriteConcernOptions::kWriteConcernTimeoutSystem};
+
+// The size limit for the documents to insert in a single batch. Leave some padding for the other
+// fields in the insert command.
+constexpr int kMaxBSONObjSizeForInsert = BSONObjMaxUserSize - 500 * 1024;
+
+/*
+ * Returns true if this mongod can accept writes to the given collection. Unless the collection is
+ * in the "local" database, this will only return true if this mongod is a primary (or a
+ * standalone).
+ */
+bool canAcceptWrites(OperationContext* opCtx, const NamespaceString& ns) {
+ ShouldNotConflictWithSecondaryBatchApplicationBlock noPBWMBlock(opCtx->lockState());
+ Lock::DBLock lk(opCtx, ns.dbName(), MODE_IS);
+ Lock::CollectionLock lock(opCtx, ns, MODE_IS);
+ return mongo::repl::ReplicationCoordinator::get(opCtx)->canAcceptWritesForDatabase(opCtx,
+ ns.db());
+}
+
+/*
+ * Runs the given write command against the given database locally, asserts that the top-level
+ * command is OK, then asserts the write status using the given 'uassertWriteStatusCb' callback.
+ * Returns the command response.
+ */
+BSONObj executeWriteCommandLocal(OperationContext* opCtx,
+ const std::string dbName,
+ const BSONObj& cmdObj,
+ const std::function<void(const BSONObj&)>& uassertWriteStatusCb) {
+ DBDirectClient client(opCtx);
+ BSONObj resObj;
+
+ if (!client.runCommand(dbName, cmdObj, resObj)) {
+ uassertStatusOK(getStatusFromCommandResult(resObj));
+ }
+ uassertWriteStatusCb(resObj);
+
+ return resObj;
+}
+
+/*
+ * Runs the given write command against the given database on the (remote) primary, asserts that the
+ * top-level command is OK, then asserts the write status using the given 'uassertWriteStatusCb'
+ * callback. Throws a PrimarySteppedDown error if no primary is found. Returns the command response.
+ */
+BSONObj executeWriteCommandRemote(OperationContext* opCtx,
+ const std::string dbName,
+ const BSONObj& cmdObj,
+ const std::function<void(const BSONObj&)>& uassertWriteStatusCb) {
+ auto hostAndPort = repl::ReplicationCoordinator::get(opCtx)->getCurrentPrimaryHostAndPort();
+
+ if (hostAndPort.empty()) {
+ uasserted(ErrorCodes::PrimarySteppedDown, "No primary exists currently");
+ }
+
+ auto conn = std::make_unique<ScopedDbConnection>(hostAndPort.toString());
+
+ if (auth::isInternalAuthSet()) {
+ uassertStatusOK(conn->get()->authenticateInternalUser());
+ }
+
+ DBClientBase* client = conn->get();
+ ScopeGuard guard([&] { conn->done(); });
+ try {
+ BSONObj resObj;
+
+ if (!client->runCommand(dbName, cmdObj, resObj)) {
+ uassertStatusOK(getStatusFromCommandResult(resObj));
+ }
+ uassertWriteStatusCb(resObj);
+
+ return resObj;
+ } catch (...) {
+ guard.dismiss();
+ conn->kill();
+ throw;
+ }
+}
+
+/*
+ * Runs the given write command against the given collection. If this mongod is currently the
+ * primary, runs the write command locally. Otherwise, runs the command on the remote primary.
+ * Internally asserts that the top-level command is OK, then asserts the write status using the
+ * given 'uassertWriteStatusCb' callback. Internally retries the write command on retryable errors
+ * (up to kMaxRetriesOnRetryableErrors times), so the writes must be idempotent. Returns the
+ * command response.
+ */
+BSONObj executeWriteCommand(OperationContext* opCtx,
+ const NamespaceString& ns,
+ const BSONObj& cmdObj,
+ const std::function<void(const BSONObj&)>& uassertWriteStatusCb) {
+ const auto dbName = ns.db().toString();
+ auto numRetries = 0;
+
+ while (true) {
+ try {
+ if (canAcceptWrites(opCtx, ns)) {
+ // There is a window here where this mongod may step down after the check above. In this
+ // case, a NotWritablePrimary error would be thrown. However, this is preferable to
+ // running the command while holding locks.
+ hangQueryAnalysisWriterBeforeWritingLocally.pauseWhileSet(opCtx);
+ return executeWriteCommandLocal(opCtx, dbName, cmdObj, uassertWriteStatusCb);
+ }
+
+ hangQueryAnalysisWriterBeforeWritingRemotely.pauseWhileSet(opCtx);
+ return executeWriteCommandRemote(opCtx, dbName, cmdObj, uassertWriteStatusCb);
+ } catch (DBException& ex) {
+ if (ErrorCodes::isRetriableError(ex) && numRetries < kMaxRetriesOnRetryableErrors) {
+ numRetries++;
+ continue;
+ }
+ throw;
+ }
+ }
+
+ return {};
+}
+
+struct SampledWriteCommandRequest {
+ UUID sampleId;
+ NamespaceString nss;
+ BSONObj cmd; // the BSON for a {Update,Delete,FindAndModify}CommandRequest
+};
+
+/*
+ * Returns a sampled update command for the update at 'opIndex' in the given update command.
+ */
+SampledWriteCommandRequest makeSampledUpdateCommandRequest(
+ const write_ops::UpdateCommandRequest& originalCmd, int opIndex) {
+ auto op = originalCmd.getUpdates()[opIndex];
+ auto sampleId = op.getSampleId();
+ invariant(sampleId);
+
+ write_ops::UpdateCommandRequest sampledCmd(originalCmd.getNamespace(), {std::move(op)});
+ sampledCmd.setLet(originalCmd.getLet());
+
+ return {*sampleId,
+ sampledCmd.getNamespace(),
+ sampledCmd.toBSON(BSON("$db" << sampledCmd.getNamespace().db().toString()))};
+}
+
+/*
+ * Returns a sampled delete command for the delete at 'opIndex' in the given delete command.
+ */
+SampledWriteCommandRequest makeSampledDeleteCommandRequest(
+ const write_ops::DeleteCommandRequest& originalCmd, int opIndex) {
+ auto op = originalCmd.getDeletes()[opIndex];
+ auto sampleId = op.getSampleId();
+ invariant(sampleId);
+
+ write_ops::DeleteCommandRequest sampledCmd(originalCmd.getNamespace(), {std::move(op)});
+ sampledCmd.setLet(originalCmd.getLet());
+
+ return {*sampleId,
+ sampledCmd.getNamespace(),
+ sampledCmd.toBSON(BSON("$db" << sampledCmd.getNamespace().db().toString()))};
+}
+
+/*
+ * Returns a sampled findAndModify command for the given findAndModify command.
+ */
+SampledWriteCommandRequest makeSampledFindAndModifyCommandRequest(
+ const write_ops::FindAndModifyCommandRequest& originalCmd) {
+ invariant(originalCmd.getSampleId());
+
+ write_ops::FindAndModifyCommandRequest sampledCmd(originalCmd.getNamespace());
+ sampledCmd.setQuery(originalCmd.getQuery());
+ sampledCmd.setUpdate(originalCmd.getUpdate());
+ sampledCmd.setRemove(originalCmd.getRemove());
+ sampledCmd.setUpsert(originalCmd.getUpsert());
+ sampledCmd.setNew(originalCmd.getNew());
+ sampledCmd.setSort(originalCmd.getSort());
+ sampledCmd.setCollation(originalCmd.getCollation());
+ sampledCmd.setArrayFilters(originalCmd.getArrayFilters());
+ sampledCmd.setLet(originalCmd.getLet());
+ sampledCmd.setSampleId(originalCmd.getSampleId());
+
+ return {*sampledCmd.getSampleId(),
+ sampledCmd.getNamespace(),
+ sampledCmd.toBSON(BSON("$db" << sampledCmd.getNamespace().db().toString()))};
+}
+
+} // namespace
+
+QueryAnalysisWriter& QueryAnalysisWriter::get(OperationContext* opCtx) {
+ return get(opCtx->getServiceContext());
+}
+
+QueryAnalysisWriter& QueryAnalysisWriter::get(ServiceContext* serviceContext) {
+ invariant(analyze_shard_key::isFeatureFlagEnabledIgnoreFCV(),
+ "Only support analyzing queries when the feature flag is enabled");
+ invariant(serverGlobalParams.clusterRole == ClusterRole::ShardServer,
+ "Only support analyzing queries on a sharded cluster");
+ return getQueryAnalysisWriter(serviceContext);
+}
+
+void QueryAnalysisWriter::onStartup() {
+ auto serviceContext = getQueryAnalysisWriter.owner(this);
+ auto periodicRunner = serviceContext->getPeriodicRunner();
+ invariant(periodicRunner);
+
+ stdx::lock_guard<Latch> lk(_mutex);
+
+ PeriodicRunner::PeriodicJob QueryWriterJob(
+ "QueryAnalysisQueryWriter",
+ [this](Client* client) {
+ if (MONGO_unlikely(disableQueryAnalysisWriter.shouldFail())) {
+ return;
+ }
+ auto opCtx = client->makeOperationContext();
+ _flushQueries(opCtx.get());
+ },
+ Seconds(gQueryAnalysisWriterIntervalSecs));
+ _periodicQueryWriter = periodicRunner->makeJob(std::move(QueryWriterJob));
+ _periodicQueryWriter.start();
+
+ PeriodicRunner::PeriodicJob diffWriterJob(
+ "QueryAnalysisDiffWriter",
+ [this](Client* client) {
+ if (MONGO_unlikely(disableQueryAnalysisWriter.shouldFail())) {
+ return;
+ }
+ auto opCtx = client->makeOperationContext();
+ _flushDiffs(opCtx.get());
+ },
+ Seconds(gQueryAnalysisWriterIntervalSecs));
+ _periodicDiffWriter = periodicRunner->makeJob(std::move(diffWriterJob));
+ _periodicDiffWriter.start();
+
+ ThreadPool::Options threadPoolOptions;
+ threadPoolOptions.maxThreads = gQueryAnalysisWriterMaxThreadPoolSize;
+ threadPoolOptions.minThreads = gQueryAnalysisWriterMinThreadPoolSize;
+ threadPoolOptions.threadNamePrefix = "QueryAnalysisWriter-";
+ threadPoolOptions.poolName = "QueryAnalysisWriterThreadPool";
+ threadPoolOptions.onCreateThread = [](const std::string& threadName) {
+ Client::initThread(threadName.c_str());
+ };
+ _executor = std::make_shared<executor::ThreadPoolTaskExecutor>(
+ std::make_unique<ThreadPool>(threadPoolOptions),
+ executor::makeNetworkInterface("QueryAnalysisWriterNetwork"));
+ _executor->startup();
+}
+
+void QueryAnalysisWriter::onShutdown() {
+ if (_executor) {
+ _executor->shutdown();
+ _executor->join();
+ }
+ if (_periodicQueryWriter.isValid()) {
+ _periodicQueryWriter.stop();
+ }
+ if (_periodicDiffWriter.isValid()) {
+ _periodicDiffWriter.stop();
+ }
+}
+
+void QueryAnalysisWriter::_flushQueries(OperationContext* opCtx) {
+ try {
+ _flush(opCtx, NamespaceString::kConfigSampledQueriesNamespace, &_queries);
+ } catch (DBException& ex) {
+ LOGV2(7047300,
+ "Failed to flush queries, will try again at the next interval",
+ "error"_attr = redact(ex));
+ }
+}
+
+void QueryAnalysisWriter::_flushDiffs(OperationContext* opCtx) {
+ try {
+ _flush(opCtx, NamespaceString::kConfigSampledQueriesDiffNamespace, &_diffs);
+ } catch (DBException& ex) {
+ LOGV2(7075400,
+ "Failed to flush diffs, will try again at the next interval",
+ "error"_attr = redact(ex));
+ }
+}
+
+void QueryAnalysisWriter::_flush(OperationContext* opCtx,
+ const NamespaceString& ns,
+ Buffer* buffer) {
+ Buffer tmpBuffer;
+ // The indices of invalid documents, e.g. documents that fail to insert with DuplicateKey errors
+ // (i.e. duplicates) and BadValue errors. Such documents should not get added back to the buffer
+ // when the inserts below fail.
+ std::set<int> invalid;
+ {
+ stdx::lock_guard<Latch> lk(_mutex);
+ if (buffer->isEmpty()) {
+ return;
+ }
+ std::swap(tmpBuffer, *buffer);
+ }
+ ScopeGuard backSwapper([&] {
+ stdx::lock_guard<Latch> lk(_mutex);
+ for (int i = 0; i < tmpBuffer.getCount(); i++) {
+ if (invalid.find(i) == invalid.end()) {
+ buffer->add(tmpBuffer.at(i));
+ }
+ }
+ });
+
+ // Insert the documents in batches from the back of the buffer so that we don't need to move the
+ // documents forward after each batch.
+ size_t baseIndex = tmpBuffer.getCount() - 1;
+ size_t maxBatchSize = gQueryAnalysisWriterMaxBatchSize.load();
+
+ while (!tmpBuffer.isEmpty()) {
+ std::vector<BSONObj> docsToInsert;
+ long long objSize = 0;
+
+ size_t lastIndex = tmpBuffer.getCount(); // inclusive
+ while (lastIndex > 0 && docsToInsert.size() < maxBatchSize) {
+ // Check if the next document can fit in the batch.
+ auto doc = tmpBuffer.at(lastIndex - 1);
+ if (doc.objsize() + objSize >= kMaxBSONObjSizeForInsert) {
+ break;
+ }
+ lastIndex--;
+ objSize += doc.objsize();
+ docsToInsert.push_back(std::move(doc));
+ }
+ // We don't add a document that is above the size limit to the buffer, so we should have
+ // added at least one document to 'docsToInsert'.
+ invariant(!docsToInsert.empty());
+
+ write_ops::InsertCommandRequest insertCmd(ns);
+ insertCmd.setDocuments(std::move(docsToInsert));
+ insertCmd.setWriteCommandRequestBase([&] {
+ write_ops::WriteCommandRequestBase wcb;
+ wcb.setOrdered(false);
+ wcb.setBypassDocumentValidation(false);
+ return wcb;
+ }());
+ auto insertCmdBson = insertCmd.toBSON(
+ {BSON(WriteConcernOptions::kWriteConcernField << kMajorityWriteConcern.toBSON())});
+
+ executeWriteCommand(opCtx, ns, std::move(insertCmdBson), [&](const BSONObj& resObj) {
+ BatchedCommandResponse response;
+ std::string errMsg;
+
+ if (!response.parseBSON(resObj, &errMsg)) {
+ uasserted(ErrorCodes::FailedToParse, errMsg);
+ }
+
+ if (response.isErrDetailsSet() && response.sizeErrDetails() > 0) {
+ boost::optional<write_ops::WriteError> firstWriteErr;
+
+ for (const auto& err : response.getErrDetails()) {
+ if (err.getStatus() == ErrorCodes::DuplicateKey ||
+ err.getStatus() == ErrorCodes::BadValue) {
+ LOGV2(7075402,
+ "Ignoring insert error",
+ "error"_attr = redact(err.getStatus()));
+ invalid.insert(baseIndex - err.getIndex());
+ continue;
+ }
+ if (!firstWriteErr) {
+ // Save the error for later. Go through the rest of the errors to see if
+ // there are any invalid documents so they can be discarded from the buffer.
+ firstWriteErr.emplace(err);
+ }
+ }
+ if (firstWriteErr) {
+ uassertStatusOK(firstWriteErr->getStatus());
+ }
+ } else {
+ uassertStatusOK(response.toStatus());
+ }
+ });
+
+ tmpBuffer.truncate(lastIndex, objSize);
+ baseIndex -= lastIndex;
+ }
+
+ backSwapper.dismiss();
+}
+
+void QueryAnalysisWriter::Buffer::add(BSONObj doc) {
+ if (doc.objsize() > kMaxBSONObjSizeForInsert) {
+ return;
+ }
+ _docs.push_back(std::move(doc));
+ _numBytes += _docs.back().objsize();
+}
+
+void QueryAnalysisWriter::Buffer::truncate(size_t index, long long numBytes) {
+ invariant(index >= 0);
+ invariant(index < _docs.size());
+ invariant(numBytes > 0);
+ invariant(numBytes <= _numBytes);
+ _docs.resize(index);
+ _numBytes -= numBytes;
+}
+
+bool QueryAnalysisWriter::_exceedsMaxSizeBytes() {
+ stdx::lock_guard<Latch> lk(_mutex);
+ return _queries.getSize() + _diffs.getSize() >= gQueryAnalysisWriterMaxMemoryUsageBytes.load();
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addFindQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation) {
+ return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kFind, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addCountQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation) {
+ return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kCount, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addDistinctQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation) {
+ return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kDistinct, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addAggregateQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation) {
+ return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kAggregate, filter, collation);
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::_addReadQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ SampledReadCommandNameEnum cmdName,
+ const BSONObj& filter,
+ const BSONObj& collation) {
+ invariant(_executor);
+ return ExecutorFuture<void>(_executor)
+ .then([this,
+ sampleId,
+ nss,
+ cmdName,
+ filter = filter.getOwned(),
+ collation = collation.getOwned()] {
+ auto opCtxHolder = cc().makeOperationContext();
+ auto opCtx = opCtxHolder.get();
+
+ auto collUuid = CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, nss);
+
+ if (!collUuid) {
+ LOGV2(7047301, "Found a sampled read query for a non-existing collection");
+ return;
+ }
+
+ auto cmd = SampledReadCommand{filter.getOwned(), collation.getOwned()};
+ auto doc = SampledReadQueryDocument{sampleId, nss, *collUuid, cmdName, cmd.toBSON()};
+ stdx::lock_guard<Latch> lk(_mutex);
+ _queries.add(doc.toBSON());
+ })
+ .then([this] {
+ if (_exceedsMaxSizeBytes()) {
+ auto opCtxHolder = cc().makeOperationContext();
+ auto opCtx = opCtxHolder.get();
+ _flushQueries(opCtx);
+ }
+ })
+ .onError([this, nss](Status status) {
+ LOGV2(7047302,
+ "Failed to add read query",
+ "ns"_attr = nss,
+ "error"_attr = redact(status));
+ });
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addUpdateQuery(
+ const write_ops::UpdateCommandRequest& updateCmd, int opIndex) {
+ invariant(updateCmd.getUpdates()[opIndex].getSampleId());
+ invariant(_executor);
+
+ return ExecutorFuture<void>(_executor)
+ .then([this, sampledUpdateCmd = makeSampledUpdateCommandRequest(updateCmd, opIndex)]() {
+ auto opCtxHolder = cc().makeOperationContext();
+ auto opCtx = opCtxHolder.get();
+
+ auto collUuid =
+ CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledUpdateCmd.nss);
+
+ if (!collUuid) {
+ LOGV2_WARNING(7075300,
+ "Found a sampled update query for a non-existing collection");
+ return;
+ }
+
+ auto doc = SampledWriteQueryDocument{sampledUpdateCmd.sampleId,
+ sampledUpdateCmd.nss,
+ *collUuid,
+ SampledWriteCommandNameEnum::kUpdate,
+ std::move(sampledUpdateCmd.cmd)};
+ stdx::lock_guard<Latch> lk(_mutex);
+ _queries.add(doc.toBSON());
+ })
+ .then([this] {
+ if (_exceedsMaxSizeBytes()) {
+ auto opCtxHolder = cc().makeOperationContext();
+ auto opCtx = opCtxHolder.get();
+ _flushQueries(opCtx);
+ }
+ })
+ .onError([this, nss = updateCmd.getNamespace()](Status status) {
+ LOGV2(7075301,
+ "Failed to add update query",
+ "ns"_attr = nss,
+ "error"_attr = redact(status));
+ });
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addDeleteQuery(
+ const write_ops::DeleteCommandRequest& deleteCmd, int opIndex) {
+ invariant(deleteCmd.getDeletes()[opIndex].getSampleId());
+ invariant(_executor);
+
+ return ExecutorFuture<void>(_executor)
+ .then([this, sampledDeleteCmd = makeSampledDeleteCommandRequest(deleteCmd, opIndex)]() {
+ auto opCtxHolder = cc().makeOperationContext();
+ auto opCtx = opCtxHolder.get();
+
+ auto collUuid =
+ CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledDeleteCmd.nss);
+
+ if (!collUuid) {
+ LOGV2_WARNING(7075302,
+ "Found a sampled delete query for a non-existing collection");
+ return;
+ }
+
+ auto doc = SampledWriteQueryDocument{sampledDeleteCmd.sampleId,
+ sampledDeleteCmd.nss,
+ *collUuid,
+ SampledWriteCommandNameEnum::kDelete,
+ std::move(sampledDeleteCmd.cmd)};
+ stdx::lock_guard<Latch> lk(_mutex);
+ _queries.add(doc.toBSON());
+ })
+ .then([this] {
+ if (_exceedsMaxSizeBytes()) {
+ auto opCtxHolder = cc().makeOperationContext();
+ auto opCtx = opCtxHolder.get();
+ _flushQueries(opCtx);
+ }
+ })
+ .onError([this, nss = deleteCmd.getNamespace()](Status status) {
+ LOGV2(7075303,
+ "Failed to add delete query",
+ "ns"_attr = nss,
+ "error"_attr = redact(status));
+ });
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addFindAndModifyQuery(
+ const write_ops::FindAndModifyCommandRequest& findAndModifyCmd) {
+ invariant(findAndModifyCmd.getSampleId());
+ invariant(_executor);
+
+ return ExecutorFuture<void>(_executor)
+ .then([this,
+ sampledFindAndModifyCmd =
+ makeSampledFindAndModifyCommandRequest(findAndModifyCmd)]() {
+ auto opCtxHolder = cc().makeOperationContext();
+ auto opCtx = opCtxHolder.get();
+
+ auto collUuid =
+ CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledFindAndModifyCmd.nss);
+
+ if (!collUuid) {
+ LOGV2_WARNING(7075304,
+ "Found a sampled findAndModify query for a non-existing collection");
+ return;
+ }
+
+ auto doc = SampledWriteQueryDocument{sampledFindAndModifyCmd.sampleId,
+ sampledFindAndModifyCmd.nss,
+ *collUuid,
+ SampledWriteCommandNameEnum::kFindAndModify,
+ std::move(sampledFindAndModifyCmd.cmd)};
+ stdx::lock_guard<Latch> lk(_mutex);
+ _queries.add(doc.toBSON());
+ })
+ .then([this] {
+ if (_exceedsMaxSizeBytes()) {
+ auto opCtxHolder = cc().makeOperationContext();
+ auto opCtx = opCtxHolder.get();
+ _flushQueries(opCtx);
+ }
+ })
+ .onError([this, nss = findAndModifyCmd.getNamespace()](Status status) {
+ LOGV2(7075305,
+ "Failed to add findAndModify query",
+ "ns"_attr = nss,
+ "error"_attr = redact(status));
+ });
+}
+
+ExecutorFuture<void> QueryAnalysisWriter::addDiff(const UUID& sampleId,
+ const NamespaceString& nss,
+ const UUID& collUuid,
+ const BSONObj& preImage,
+ const BSONObj& postImage) {
+ invariant(_executor);
+ return ExecutorFuture<void>(_executor)
+ .then([this,
+ sampleId,
+ nss,
+ collUuid,
+ preImage = preImage.getOwned(),
+ postImage = postImage.getOwned()]() {
+ auto diff = doc_diff::computeInlineDiff(preImage, postImage);
+
+ if (!diff || diff->isEmpty()) {
+ return;
+ }
+
+ auto doc = SampledQueryDiffDocument{sampleId, nss, collUuid, std::move(*diff)};
+ stdx::lock_guard<Latch> lk(_mutex);
+ _diffs.add(doc.toBSON());
+ })
+ .then([this] {
+ if (_exceedsMaxSizeBytes()) {
+ auto opCtxHolder = cc().makeOperationContext();
+ auto opCtx = opCtxHolder.get();
+ _flushDiffs(opCtx);
+ }
+ })
+ .onError([this, nss](Status status) {
+ LOGV2(7075401, "Failed to add diff", "ns"_attr = nss, "error"_attr = redact(status));
+ });
+}
+
+} // namespace analyze_shard_key
+} // namespace mongo
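
The _flush implementation above swaps the buffer out under the mutex, then builds insert batches from the back so the remaining documents never need to shift, truncating the buffer after each batch. A standalone sketch of that batching strategy, assuming illustrative byte and count limits rather than the real server parameters:

    #include <cstddef>
    #include <string>
    #include <vector>

    // Illustrative limits; the server derives these from BSONObjMaxUserSize and the
    // batch-size server parameter.
    constexpr std::size_t kMaxBatchBytes = 15 * 1024 * 1024;
    constexpr std::size_t kMaxBatchDocs = 100;

    // Consume 'docs' from the back, forming batches that respect both limits, and hand
    // each batch to 'insertBatch'. Dropping the consumed tail with resize() mirrors
    // Buffer::truncate(lastIndex, objSize) in the code above.
    template <typename InsertFn>
    void flushFromBack(std::vector<std::string>& docs, InsertFn insertBatch) {
        while (!docs.empty()) {
            std::vector<std::string> batch;
            std::size_t batchBytes = 0;

            std::size_t lastIndex = docs.size();  // everything at or after this index joins the batch
            while (lastIndex > 0 && batch.size() < kMaxBatchDocs) {
                const std::string& doc = docs[lastIndex - 1];
                // Unlike the server code, which rejects oversized documents at add()
                // time, always take at least one document so the loop makes progress.
                if (!batch.empty() && batchBytes + doc.size() >= kMaxBatchBytes) {
                    break;
                }
                --lastIndex;
                batchBytes += doc.size();
                batch.push_back(doc);
            }

            insertBatch(batch);      // may throw; the caller decides what to re-buffer
            docs.resize(lastIndex);  // drop the documents that were just handed off
        }
    }

For example, flushFromBack(buffered, [](const std::vector<std::string>& batch) { /* issue one insert per batch */ }); walks the vector from the back exactly once per batch.
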
diff --git a/src/mongo/db/s/query_analysis_writer.h b/src/mongo/db/s/query_analysis_writer.h
new file mode 100644
index 00000000000..508d4903b65
--- /dev/null
+++ b/src/mongo/db/s/query_analysis_writer.h
@@ -0,0 +1,214 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/operation_context.h"
+#include "mongo/db/service_context.h"
+#include "mongo/executor/task_executor.h"
+#include "mongo/s/analyze_shard_key_common_gen.h"
+#include "mongo/s/analyze_shard_key_util.h"
+#include "mongo/s/write_ops/batched_command_request.h"
+#include "mongo/util/periodic_runner.h"
+
+namespace mongo {
+namespace analyze_shard_key {
+
+/**
+ * Owns the machinery for persisting sampled queries. That consists of the following:
+ * - The buffer that stores sampled queries and the periodic background job that inserts those
+ * queries into the local config.sampledQueries collection.
+ * - The buffer that stores diffs for sampled update queries and the periodic background job that
+ * inserts those diffs into the local config.sampledQueriesDiff collection.
+ *
+ * Currently, query sampling is only supported on a sharded cluster. So a writer must be a shardsvr
+ * mongod. If the mongod is a primary, it will execute the insert commands locally. If it is a
+ * secondary, it will perform the insert commands against the primary.
+ *
+ * The memory usage of the buffers is controlled by the 'queryAnalysisWriterMaxMemoryUsageBytes'
+ * server parameter. Upon adding a query or a diff that causes the total size of buffers to exceed
+ * the limit, the writer will flush the corresponding buffer immediately instead of waiting for it
+ * to get flushed later by the periodic job.
+ */
+class QueryAnalysisWriter final : public std::enable_shared_from_this<QueryAnalysisWriter> {
+ QueryAnalysisWriter(const QueryAnalysisWriter&) = delete;
+ QueryAnalysisWriter& operator=(const QueryAnalysisWriter&) = delete;
+
+public:
+ /**
+ * Temporarily stores documents to be written to disk.
+ */
+ struct Buffer {
+ public:
+ /**
+ * Adds the given document to the buffer if its size is below the limit (i.e.
+ * BSONObjMaxUserSize - some padding) and increments the total number of bytes accordingly.
+ */
+ void add(BSONObj doc);
+
+ /**
+ * Removes the documents from 'index' onwards from the buffer and decrements the total number
+ * of bytes by 'numBytes'. The caller must ensure that 'numBytes' is indeed the
+ * total size of the documents being removed.
+ */
+ void truncate(size_t index, long long numBytes);
+
+ bool isEmpty() const {
+ return _docs.empty();
+ }
+
+ int getCount() const {
+ return _docs.size();
+ }
+
+ long long getSize() const {
+ return _numBytes;
+ }
+
+ BSONObj at(size_t index) const {
+ return _docs[index];
+ }
+
+ private:
+ std::vector<BSONObj> _docs;
+ long long _numBytes = 0;
+ };
+
+ QueryAnalysisWriter() = default;
+ ~QueryAnalysisWriter() = default;
+
+ QueryAnalysisWriter(QueryAnalysisWriter&& source) = delete;
+ QueryAnalysisWriter& operator=(QueryAnalysisWriter&& other) = delete;
+
+ /**
+ * Obtains the service-wide QueryAnalysisWriter instance.
+ */
+ static QueryAnalysisWriter& get(OperationContext* opCtx);
+ static QueryAnalysisWriter& get(ServiceContext* serviceContext);
+
+ void onStartup();
+
+ void onShutdown();
+
+ ExecutorFuture<void> addFindQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation);
+
+ ExecutorFuture<void> addCountQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation);
+
+ ExecutorFuture<void> addDistinctQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation);
+
+ ExecutorFuture<void> addAggregateQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& filter,
+ const BSONObj& collation);
+
+ ExecutorFuture<void> addUpdateQuery(const write_ops::UpdateCommandRequest& updateCmd,
+ int opIndex);
+ ExecutorFuture<void> addDeleteQuery(const write_ops::DeleteCommandRequest& deleteCmd,
+ int opIndex);
+
+ ExecutorFuture<void> addFindAndModifyQuery(
+ const write_ops::FindAndModifyCommandRequest& findAndModifyCmd);
+
+ ExecutorFuture<void> addDiff(const UUID& sampleId,
+ const NamespaceString& nss,
+ const UUID& collUuid,
+ const BSONObj& preImage,
+ const BSONObj& postImage);
+
+ int getQueriesCountForTest() const {
+ stdx::lock_guard<Latch> lk(_mutex);
+ return _queries.getCount();
+ }
+
+ void flushQueriesForTest(OperationContext* opCtx) {
+ _flushQueries(opCtx);
+ }
+
+ int getDiffsCountForTest() const {
+ stdx::lock_guard<Latch> lk(_mutex);
+ return _diffs.getCount();
+ }
+
+ void flushDiffsForTest(OperationContext* opCtx) {
+ _flushDiffs(opCtx);
+ }
+
+private:
+ ExecutorFuture<void> _addReadQuery(const UUID& sampleId,
+ const NamespaceString& nss,
+ SampledReadCommandNameEnum cmdName,
+ const BSONObj& filter,
+ const BSONObj& collation);
+
+ void _flushQueries(OperationContext* opCtx);
+ void _flushDiffs(OperationContext* opCtx);
+
+ /**
+ * The helper for '_flushQueries' and '_flushDiffs'. Inserts the documents in 'buffer' into the
+ * collection 'nss' in batches, and removes all the inserted documents from 'buffer'. Internally
+ * retries the inserts on retryable errors for a fixed number of times. Ignores DuplicateKey
+ * errors since they are expected for the following reasons:
+ * - For the query buffer, a sampled query that is idempotent (e.g. a read or retryable write)
+ * could get added to the buffer (across nodes) more than once due to retries.
+ * - For the diff buffer, a sampled multi-update query could end up generating multiple diffs
+ * and each diff is identified using the sample id of the sampled query that creates it.
+ *
+ * Throws an error if the inserts fail with any other error.
+ */
+ void _flush(OperationContext* opCtx, const NamespaceString& nss, Buffer* buffer);
+
+ /**
+ * Returns true if the total size of the buffered queries and diffs has exceeded the maximum
+ * amount of memory that the writer is allowed to use.
+ */
+ bool _exceedsMaxSizeBytes();
+
+ mutable Mutex _mutex = MONGO_MAKE_LATCH("QueryAnalysisWriter::_mutex");
+
+ PeriodicJobAnchor _periodicQueryWriter;
+ Buffer _queries;
+
+ PeriodicJobAnchor _periodicDiffWriter;
+ Buffer _diffs;
+
+ // Initialized on startup and joined on shutdown.
+ std::shared_ptr<executor::TaskExecutor> _executor;
+};
+
+} // namespace analyze_shard_key
+} // namespace mongo
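
The class comment above describes two buffers that are normally drained by periodic jobs but are flushed immediately once their combined size crosses the configured memory budget. A minimal sketch of that add-then-maybe-flush pattern (the budget constant is illustrative, not the queryAnalysisWriterMaxMemoryUsageBytes parameter):

    #include <mutex>
    #include <string>
    #include <utility>
    #include <vector>

    // Illustrative budget only.
    constexpr long long kMaxBufferedBytes = 1024 * 1024;

    class SampledQueryBufferSketch {
    public:
        // Returns true when the caller should flush now instead of waiting for the
        // periodic writer job.
        bool add(std::string doc) {
            std::lock_guard lk(_mutex);
            _bytes += static_cast<long long>(doc.size());
            _docs.push_back(std::move(doc));
            return _bytes >= kMaxBufferedBytes;
        }

        // Hands the buffered documents to the flush path and resets the byte count.
        std::vector<std::string> drain() {
            std::lock_guard lk(_mutex);
            _bytes = 0;
            return std::exchange(_docs, {});
        }

    private:
        std::mutex _mutex;
        long long _bytes = 0;
        std::vector<std::string> _docs;
    };

A caller would check the return value of add and trigger a drain followed by an insert when the budget is exceeded, mirroring the _exceedsMaxSizeBytes check in the futures above.
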
diff --git a/src/mongo/db/s/query_analysis_writer_test.cpp b/src/mongo/db/s/query_analysis_writer_test.cpp
new file mode 100644
index 00000000000..b17f39c82df
--- /dev/null
+++ b/src/mongo/db/s/query_analysis_writer_test.cpp
@@ -0,0 +1,1281 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/s/query_analysis_writer.h"
+
+#include "mongo/bson/unordered_fields_bsonobj_comparator.h"
+#include "mongo/db/db_raii.h"
+#include "mongo/db/dbdirectclient.h"
+#include "mongo/db/s/shard_server_test_fixture.h"
+#include "mongo/db/update/document_diff_calculator.h"
+#include "mongo/idl/server_parameter_test_util.h"
+#include "mongo/logv2/log.h"
+#include "mongo/s/analyze_shard_key_documents_gen.h"
+#include "mongo/unittest/death_test.h"
+#include "mongo/util/fail_point.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest
+
+namespace mongo {
+namespace analyze_shard_key {
+namespace {
+
+TEST(QueryAnalysisWriterBufferTest, AddBasic) {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc0 = BSON("a" << 0);
+ buffer.add(doc0);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc0.objsize());
+
+ auto doc1 = BSON("a" << BSON_ARRAY(0 << 1 << 2));
+ buffer.add(doc1);
+ ASSERT_EQ(buffer.getCount(), 2);
+ ASSERT_EQ(buffer.getSize(), doc0.objsize() + doc1.objsize());
+
+ ASSERT_BSONOBJ_EQ(buffer.at(0), doc0);
+ ASSERT_BSONOBJ_EQ(buffer.at(1), doc1);
+}
+
+TEST(QueryAnalysisWriterBufferTest, AddTooLarge) {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON(std::string(BSONObjMaxUserSize, 'a') << 1);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 0);
+ ASSERT_EQ(buffer.getSize(), 0);
+}
+
+TEST(QueryAnalysisWriterBufferTest, TruncateBasic) {
+ auto testTruncateCommon = [](int oldCount, int newCount) {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ std::vector<BSONObj> docs;
+ for (auto i = 0; i < oldCount; i++) {
+ docs.push_back(BSON("a" << i));
+ }
+ // The documents have the same size.
+ auto docSize = docs.back().objsize();
+
+ for (const auto& doc : docs) {
+ buffer.add(doc);
+ }
+ ASSERT_EQ(buffer.getCount(), oldCount);
+ ASSERT_EQ(buffer.getSize(), oldCount * docSize);
+
+ buffer.truncate(newCount, (oldCount - newCount) * docSize);
+ ASSERT_EQ(buffer.getCount(), newCount);
+ ASSERT_EQ(buffer.getSize(), newCount * docSize);
+ for (auto i = 0; i < newCount; i++) {
+ ASSERT_BSONOBJ_EQ(buffer.at(i), docs[i]);
+ }
+ };
+
+ testTruncateCommon(10 /* oldCount */, 6 /* newCount */);
+ testTruncateCommon(10 /* oldCount */, 0 /* newCount */); // Truncate all.
+ testTruncateCommon(10 /* oldCount */, 9 /* newCount */); // Truncate one.
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidIndex_Negative, "invariant") {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON("a" << 0);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc.objsize());
+
+ buffer.truncate(-1, doc.objsize());
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidIndex_Positive, "invariant") {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON("a" << 0);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc.objsize());
+
+ buffer.truncate(2, doc.objsize());
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Negative, "invariant") {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON("a" << 0);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc.objsize());
+
+ buffer.truncate(0, -doc.objsize());
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Zero, "invariant") {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON("a" << 0);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc.objsize());
+
+ buffer.truncate(0, 0);
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Positive, "invariant") {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON("a" << 0);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc.objsize());
+
+ buffer.truncate(0, doc.objsize() * 2);
+}
+
+void assertBsonObjEqualUnordered(const BSONObj& lhs, const BSONObj& rhs) {
+ UnorderedFieldsBSONObjComparator comparator;
+ ASSERT_EQ(comparator.compare(lhs, rhs), 0);
+}
+
+struct QueryAnalysisWriterTest : public ShardServerTestFixture {
+public:
+ void setUp() {
+ ShardServerTestFixture::setUp();
+ QueryAnalysisWriter::get(operationContext()).onStartup();
+
+ DBDirectClient client(operationContext());
+ client.createCollection(nss0.toString());
+ client.createCollection(nss1.toString());
+ }
+
+ void tearDown() {
+ QueryAnalysisWriter::get(operationContext()).onShutdown();
+ ShardServerTestFixture::tearDown();
+ }
+
+protected:
+ UUID getCollectionUUID(const NamespaceString& nss) const {
+ auto collectionCatalog = CollectionCatalog::get(operationContext());
+ return *collectionCatalog->lookupUUIDByNSS(operationContext(), nss);
+ }
+
+ BSONObj makeNonEmptyFilter() {
+ return BSON("_id" << UUID::gen());
+ }
+
+ BSONObj makeNonEmptyCollation() {
+ int strength = rand() % 5 + 1;
+ return BSON("locale"
+ << "en_US"
+ << "strength" << strength);
+ }
+
+ /*
+ * Makes an UpdateCommandRequest for the collection 'nss' such that the command contains
+ * 'numOps' updates and the ones whose indices are in 'markForSampling' are marked for sampling.
+ * Returns the UpdateCommandRequest and the map storing the expected sampled
+ * UpdateCommandRequests by sample id, if any.
+ */
+ std::pair<write_ops::UpdateCommandRequest, std::map<UUID, write_ops::UpdateCommandRequest>>
+ makeUpdateCommandRequest(const NamespaceString& nss,
+ int numOps,
+ std::set<int> markForSampling,
+ std::string filterFieldName = "a") {
+ write_ops::UpdateCommandRequest originalCmd(nss);
+ std::vector<write_ops::UpdateOpEntry> updateOps; // populated below.
+ originalCmd.setLet(let);
+ originalCmd.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation);
+
+ std::map<UUID, write_ops::UpdateCommandRequest> expectedSampledCmds;
+
+ for (auto i = 0; i < numOps; i++) {
+ auto updateOp = write_ops::UpdateOpEntry(
+ BSON(filterFieldName << i),
+ write_ops::UpdateModification(BSON("$set" << BSON("b.$[element]" << i))));
+ updateOp.setC(BSON("x" << i));
+ updateOp.setArrayFilters(std::vector<BSONObj>{BSON("element" << BSON("$gt" << i))});
+ updateOp.setMulti(_getRandomBool());
+ updateOp.setUpsert(_getRandomBool());
+ updateOp.setUpsertSupplied(_getRandomBool());
+ updateOp.setCollation(makeNonEmptyCollation());
+
+ if (markForSampling.find(i) != markForSampling.end()) {
+ auto sampleId = UUID::gen();
+ updateOp.setSampleId(sampleId);
+
+ write_ops::UpdateCommandRequest expectedSampledCmd = originalCmd;
+ expectedSampledCmd.setUpdates({updateOp});
+ expectedSampledCmd.getWriteCommandRequestBase().setEncryptionInformation(
+ boost::none);
+ expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd));
+ }
+ updateOps.push_back(updateOp);
+ }
+ originalCmd.setUpdates(updateOps);
+
+ return {originalCmd, expectedSampledCmds};
+ }
+
+ /*
+ * Makes a DeleteCommandRequest for the collection 'nss' such that the command contains
+ * 'numOps' deletes and the ones whose indices are in 'markForSampling' are marked for sampling.
+ * Returns the DeleteCommandRequest and the map storing the expected sampled
+ * DeleteCommandRequests by sample id, if any.
+ */
+ std::pair<write_ops::DeleteCommandRequest, std::map<UUID, write_ops::DeleteCommandRequest>>
+ makeDeleteCommandRequest(const NamespaceString& nss,
+ int numOps,
+ std::set<int> markForSampling,
+ std::string filterFieldName = "a") {
+ write_ops::DeleteCommandRequest originalCmd(nss);
+ std::vector<write_ops::DeleteOpEntry> deleteOps; // populated and set below.
+ originalCmd.setLet(let);
+ originalCmd.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation);
+
+ std::map<UUID, write_ops::DeleteCommandRequest> expectedSampledCmds;
+
+ for (auto i = 0; i < numOps; i++) {
+ auto deleteOp =
+ write_ops::DeleteOpEntry(BSON(filterFieldName << i), _getRandomBool() /* multi */);
+ deleteOp.setCollation(makeNonEmptyCollation());
+
+ if (markForSampling.find(i) != markForSampling.end()) {
+ auto sampleId = UUID::gen();
+ deleteOp.setSampleId(sampleId);
+
+ write_ops::DeleteCommandRequest expectedSampledCmd = originalCmd;
+ expectedSampledCmd.setDeletes({deleteOp});
+ expectedSampledCmd.getWriteCommandRequestBase().setEncryptionInformation(
+ boost::none);
+ expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd));
+ }
+ deleteOps.push_back(deleteOp);
+ }
+ originalCmd.setDeletes(deleteOps);
+
+ return {originalCmd, expectedSampledCmds};
+ }
+
+ /*
+ * Makes a FindAndModifyCommandRequest for the collection 'nss'. The findAndModify is an update
+ * if 'isUpdate' is true, and a remove otherwise. If 'markForSampling' is true, it is marked for
+ * sampling. Returns the FindAndModifyCommandRequest and the map storing the expected sampled
+ * FindAndModifyCommandRequests by sample id, if any.
+ */
+ std::pair<write_ops::FindAndModifyCommandRequest,
+ std::map<UUID, write_ops::FindAndModifyCommandRequest>>
+ makeFindAndModifyCommandRequest(const NamespaceString& nss,
+ bool isUpdate,
+ bool markForSampling,
+ std::string filterFieldName = "a") {
+ write_ops::FindAndModifyCommandRequest originalCmd(nss);
+ originalCmd.setQuery(BSON(filterFieldName << 0));
+ originalCmd.setUpdate(
+ write_ops::UpdateModification(BSON("$set" << BSON("b.$[element]" << 0))));
+ originalCmd.setArrayFilters(std::vector<BSONObj>{BSON("element" << BSON("$gt" << 10))});
+ originalCmd.setSort(BSON("_id" << 1));
+ if (isUpdate) {
+ originalCmd.setUpsert(_getRandomBool());
+ originalCmd.setNew(_getRandomBool());
+ }
+ originalCmd.setCollation(makeNonEmptyCollation());
+ originalCmd.setLet(let);
+ originalCmd.setEncryptionInformation(encryptionInformation);
+
+ std::map<UUID, write_ops::FindAndModifyCommandRequest> expectedSampledCmds;
+ if (markForSampling) {
+ auto sampleId = UUID::gen();
+ originalCmd.setSampleId(sampleId);
+
+ auto expectedSampledCmd = originalCmd;
+ expectedSampledCmd.setEncryptionInformation(boost::none);
+ expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd));
+ }
+
+ return {originalCmd, expectedSampledCmds};
+ }
+
+ void deleteSampledQueryDocuments() const {
+ DBDirectClient client(operationContext());
+ client.remove(NamespaceString::kConfigSampledQueriesNamespace.toString(), BSONObj());
+ }
+
+ /**
+ * Returns the number of documents for the collection 'nss' in the config.sampledQueries
+ * collection.
+ */
+ int getSampledQueryDocumentsCount(const NamespaceString& nss) {
+ return _getConfigDocumentsCount(NamespaceString::kConfigSampledQueriesNamespace, nss);
+ }
+
+ /*
+ * Asserts that there is a sampled read query document with the given sample id and that it has
+ * the given fields.
+ */
+ void assertSampledReadQueryDocument(const UUID& sampleId,
+ const NamespaceString& nss,
+ SampledReadCommandNameEnum cmdName,
+ const BSONObj& filter,
+ const BSONObj& collation) {
+ auto doc = _getConfigDocument(NamespaceString::kConfigSampledQueriesNamespace, sampleId);
+ auto parsedQueryDoc =
+ SampledReadQueryDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc);
+
+ ASSERT_EQ(parsedQueryDoc.getNs(), nss);
+ ASSERT_EQ(parsedQueryDoc.getCollectionUuid(), getCollectionUUID(nss));
+ ASSERT_EQ(parsedQueryDoc.getSampleId(), sampleId);
+ ASSERT(parsedQueryDoc.getCmdName() == cmdName);
+ auto parsedCmd = SampledReadCommand::parse(IDLParserContext("QueryAnalysisWriterTest"),
+ parsedQueryDoc.getCmd());
+ ASSERT_BSONOBJ_EQ(parsedCmd.getFilter(), filter);
+ ASSERT_BSONOBJ_EQ(parsedCmd.getCollation(), collation);
+ }
+
+ /*
+ * Asserts that there is a sampled write query document with the given sample id and that it has
+ * the given fields.
+ */
+ template <typename CommandRequestType>
+ void assertSampledWriteQueryDocument(const UUID& sampleId,
+ const NamespaceString& nss,
+ SampledWriteCommandNameEnum cmdName,
+ const CommandRequestType& expectedCmd) {
+ auto doc = _getConfigDocument(NamespaceString::kConfigSampledQueriesNamespace, sampleId);
+ auto parsedQueryDoc =
+ SampledWriteQueryDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc);
+
+ ASSERT_EQ(parsedQueryDoc.getNs(), nss);
+ ASSERT_EQ(parsedQueryDoc.getCollectionUuid(), getCollectionUUID(nss));
+ ASSERT_EQ(parsedQueryDoc.getSampleId(), sampleId);
+ ASSERT(parsedQueryDoc.getCmdName() == cmdName);
+ auto parsedCmd = CommandRequestType::parse(IDLParserContext("QueryAnalysisWriterTest"),
+ parsedQueryDoc.getCmd());
+ ASSERT_BSONOBJ_EQ(parsedCmd.toBSON({}), expectedCmd.toBSON({}));
+ }
+
+ /*
+ * Returns the number of documents for the collection 'nss' in the config.sampledQueriesDiff
+ * collection.
+ */
+ int getDiffDocumentsCount(const NamespaceString& nss) {
+ return _getConfigDocumentsCount(NamespaceString::kConfigSampledQueriesDiffNamespace, nss);
+ }
+
+ /*
+ * Asserts that there is a sampled diff document with the given sample id and that it has
+ * the given fields.
+ */
+ void assertDiffDocument(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& expectedDiff) {
+ auto doc =
+ _getConfigDocument(NamespaceString::kConfigSampledQueriesDiffNamespace, sampleId);
+ auto parsedDiffDoc =
+ SampledQueryDiffDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc);
+
+ ASSERT_EQ(parsedDiffDoc.getNs(), nss);
+ ASSERT_EQ(parsedDiffDoc.getCollectionUuid(), getCollectionUUID(nss));
+ ASSERT_EQ(parsedDiffDoc.getSampleId(), sampleId);
+ assertBsonObjEqualUnordered(parsedDiffDoc.getDiff(), expectedDiff);
+ }
+
+ const NamespaceString nss0{"testDb", "testColl0"};
+ const NamespaceString nss1{"testDb", "testColl1"};
+
+ // Test with both empty and non-empty filter and collation to verify that the
+ // QueryAnalysisWriter doesn't require filter or collation to be non-empty.
+ const BSONObj emptyFilter{};
+ const BSONObj emptyCollation{};
+
+ const BSONObj let = BSON("x" << 1);
+ // Test with EncryptionInformation to verify that QueryAnalysisWriter does not persist the
+ // WriteCommandRequestBase fields, especially this sensitive field.
+ const EncryptionInformation encryptionInformation{BSON("foo"
+ << "bar")};
+
+private:
+ bool _getRandomBool() {
+ return rand() % 2 == 0;
+ }
+
+ /**
+ * Returns the number of documents for the collection 'collNss' in the config collection
+ * 'configNss'.
+ */
+ int _getConfigDocumentsCount(const NamespaceString& configNss,
+ const NamespaceString& collNss) const {
+ DBDirectClient client(operationContext());
+ return client.count(configNss, BSON("ns" << collNss.toString()));
+ }
+
+ /**
+ * Returns the document with the given _id in the config collection 'configNss'.
+ */
+ BSONObj _getConfigDocument(const NamespaceString configNss, const UUID& id) const {
+ DBDirectClient client(operationContext());
+
+ FindCommandRequest findRequest{configNss};
+ findRequest.setFilter(BSON("_id" << id));
+ auto cursor = client.find(std::move(findRequest));
+ ASSERT(cursor->more());
+ return cursor->next();
+ }
+
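+ // Enables query sampling for the duration of each test; the fail point below is assumed to
+ // suppress the writer's periodic flush so tests flush explicitly via the *ForTest helpers.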
+ RAIIServerParameterControllerForTest _featureFlagController{"featureFlagAnalyzeShardKey", true};
+ FailPointEnableBlock _fp{"disableQueryAnalysisWriter"};
+};
+
+DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetIfFeatureFlagNotEnabled, "invariant") {
+ RAIIServerParameterControllerForTest _featureFlagController{"featureFlagAnalyzeShardKey",
+ false};
+ QueryAnalysisWriter::get(operationContext());
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetOnConfigServer, "invariant") {
+ serverGlobalParams.clusterRole = ClusterRole::ConfigServer;
+ QueryAnalysisWriter::get(operationContext());
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetOnNonShardServer, "invariant") {
+ serverGlobalParams.clusterRole = ClusterRole::None;
+ QueryAnalysisWriter::get(operationContext());
+}
+
+TEST_F(QueryAnalysisWriterTest, NoQueries) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+ writer.flushQueriesForTest(operationContext());
+}
+
+TEST_F(QueryAnalysisWriterTest, FindQuery) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto testFindCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) {
+ auto sampleId = UUID::gen();
+
+ writer.addFindQuery(sampleId, nss0, filter, collation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kFind, filter, collation);
+
+ deleteSampledQueryDocuments();
+ };
+
+ testFindCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation());
+ testFindCmdCommon(makeNonEmptyFilter(), emptyCollation);
+ testFindCmdCommon(emptyFilter, makeNonEmptyCollation());
+ testFindCmdCommon(emptyFilter, emptyCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, CountQuery) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto testCountCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) {
+ auto sampleId = UUID::gen();
+
+ writer.addCountQuery(sampleId, nss0, filter, collation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kCount, filter, collation);
+
+ deleteSampledQueryDocuments();
+ };
+
+ testCountCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation());
+ testCountCmdCommon(makeNonEmptyFilter(), emptyCollation);
+ testCountCmdCommon(emptyFilter, makeNonEmptyCollation());
+ testCountCmdCommon(emptyFilter, emptyCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, DistinctQuery) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto testDistinctCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) {
+ auto sampleId = UUID::gen();
+
+ writer.addDistinctQuery(sampleId, nss0, filter, collation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kDistinct, filter, collation);
+
+ deleteSampledQueryDocuments();
+ };
+
+ testDistinctCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation());
+ testDistinctCmdCommon(makeNonEmptyFilter(), emptyCollation);
+ testDistinctCmdCommon(emptyFilter, makeNonEmptyCollation());
+ testDistinctCmdCommon(emptyFilter, emptyCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, AggregateQuery) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto testAggregateCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) {
+ auto sampleId = UUID::gen();
+
+ writer.addAggregateQuery(sampleId, nss0, filter, collation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation);
+
+ deleteSampledQueryDocuments();
+ };
+
+ testAggregateCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation());
+ testAggregateCmdCommon(makeNonEmptyFilter(), emptyCollation);
+ testAggregateCmdCommon(emptyFilter, makeNonEmptyCollation());
+ testAggregateCmdCommon(emptyFilter, emptyCollation);
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, UpdateQueryNotMarkedForSampling, "invariant") {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+ auto [originalCmd, _] = makeUpdateCommandRequest(nss0, 1, {} /* markForSampling */);
+ writer.addUpdateQuery(originalCmd, 0).get();
+}
+
+TEST_F(QueryAnalysisWriterTest, UpdateQueriesMarkedForSampling) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto [originalCmd, expectedSampledCmds] =
+ makeUpdateCommandRequest(nss0, 3, {0, 2} /* markForSampling */);
+ ASSERT_EQ(expectedSampledCmds.size(), 2U);
+
+ writer.addUpdateQuery(originalCmd, 0).get();
+ writer.addUpdateQuery(originalCmd, 2).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 2);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+ for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) {
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kUpdate,
+ expectedSampledCmd);
+ }
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, DeleteQueryNotMarkedForSampling, "invariant") {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+ auto [originalCmd, _] = makeDeleteCommandRequest(nss0, 1, {} /* markForSampling */);
+ writer.addDeleteQuery(originalCmd, 0).get();
+}
+
+TEST_F(QueryAnalysisWriterTest, DeleteQueriesMarkedForSampling) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto [originalCmd, expectedSampledCmds] =
+ makeDeleteCommandRequest(nss0, 3, {1, 2} /* markForSampling */);
+ ASSERT_EQ(expectedSampledCmds.size(), 2U);
+
+ writer.addDeleteQuery(originalCmd, 1).get();
+ writer.addDeleteQuery(originalCmd, 2).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 2);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+ for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) {
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kDelete,
+ expectedSampledCmd);
+ }
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryNotMarkedForSampling, "invariant") {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+ auto [originalCmd, _] =
+ makeFindAndModifyCommandRequest(nss0, true /* isUpdate */, false /* markForSampling */);
+ writer.addFindAndModifyQuery(originalCmd).get();
+}
+
+TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryUpdateMarkedForSampling) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto [originalCmd, expectedSampledCmds] =
+ makeFindAndModifyCommandRequest(nss0, true /* isUpdate */, true /* markForSampling */);
+ ASSERT_EQ(expectedSampledCmds.size(), 1U);
+ auto [sampleId, expectedSampledCmd] = *expectedSampledCmds.begin();
+
+ writer.addFindAndModifyQuery(originalCmd).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kFindAndModify,
+ expectedSampledCmd);
+}
+
+TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryRemoveMarkedForSampling) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto [originalCmd, expectedSampledCmds] =
+ makeFindAndModifyCommandRequest(nss0, false /* isUpdate */, true /* markForSampling */);
+ ASSERT_EQ(expectedSampledCmds.size(), 1U);
+ auto [sampleId, expectedSampledCmd] = *expectedSampledCmds.begin();
+
+ writer.addFindAndModifyQuery(originalCmd).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kFindAndModify,
+ expectedSampledCmd);
+}
+
+TEST_F(QueryAnalysisWriterTest, MultipleQueriesAndCollections) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ // Make nss1 have two queries: a delete (here) and a count (added below).
+ auto [originalDeleteCmd, expectedSampledDeleteCmds] =
+ makeDeleteCommandRequest(nss1, 3, {1} /* markForSampling */);
+ ASSERT_EQ(expectedSampledDeleteCmds.size(), 1U);
+ auto [deleteSampleId, expectedSampledDeleteCmd] = *expectedSampledDeleteCmds.begin();
+
+ // Make nss0 have one query: an update.
+ auto [originalUpdateCmd, expectedSampledUpdateCmds] =
+ makeUpdateCommandRequest(nss0, 1, {0} /* markForSampling */);
+ ASSERT_EQ(expectedSampledUpdateCmds.size(), 1U);
+ auto [updateSampleId, expectedSampledUpdateCmd] = *expectedSampledUpdateCmds.begin();
+
+ auto countSampleId = UUID::gen();
+ auto originalCountFilter = makeNonEmptyFilter();
+ auto originalCountCollation = makeNonEmptyCollation();
+
+ writer.addDeleteQuery(originalDeleteCmd, 1).get();
+ writer.addUpdateQuery(originalUpdateCmd, 0).get();
+ writer.addCountQuery(countSampleId, nss1, originalCountFilter, originalCountCollation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 3);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledWriteQueryDocument(deleteSampleId,
+ expectedSampledDeleteCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kDelete,
+ expectedSampledDeleteCmd);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 2);
+ assertSampledWriteQueryDocument(updateSampleId,
+ expectedSampledUpdateCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kUpdate,
+ expectedSampledUpdateCmd);
+ assertSampledReadQueryDocument(countSampleId,
+ nss1,
+ SampledReadCommandNameEnum::kCount,
+ originalCountFilter,
+ originalCountCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, DuplicateQueries) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto findSampleId = UUID::gen();
+ auto originalFindFilter = makeNonEmptyFilter();
+ auto originalFindCollation = makeNonEmptyCollation();
+
+ auto [originalUpdateCmd, expectedSampledUpdateCmds] =
+ makeUpdateCommandRequest(nss0, 1, {0} /* markForSampling */);
+ ASSERT_EQ(expectedSampledUpdateCmds.size(), 1U);
+ auto [updateSampleId, expectedSampledUpdateCmd] = *expectedSampledUpdateCmds.begin();
+
+ auto countSampleId = UUID::gen();
+ auto originalCountFilter = makeNonEmptyFilter();
+ auto originalCountCollation = makeNonEmptyCollation();
+
+ writer.addFindQuery(findSampleId, nss0, originalFindFilter, originalFindCollation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(findSampleId,
+ nss0,
+ SampledReadCommandNameEnum::kFind,
+ originalFindFilter,
+ originalFindCollation);
+
+ writer.addUpdateQuery(originalUpdateCmd, 0).get();
+ writer.addFindQuery(findSampleId, nss0, originalFindFilter, originalFindCollation)
+ .get(); // This is a duplicate.
+ writer.addCountQuery(countSampleId, nss0, originalCountFilter, originalCountCollation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 3);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 3);
+ assertSampledWriteQueryDocument(updateSampleId,
+ expectedSampledUpdateCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kUpdate,
+ expectedSampledUpdateCmd);
+ assertSampledReadQueryDocument(findSampleId,
+ nss0,
+ SampledReadCommandNameEnum::kFind,
+ originalFindFilter,
+ originalFindCollation);
+ assertSampledReadQueryDocument(countSampleId,
+ nss0,
+ SampledReadCommandNameEnum::kCount,
+ originalCountFilter,
+ originalCountCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, QueriesMultipleBatches_MaxBatchSize) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 2};
+ auto numQueries = 5;
+
+ std::vector<std::tuple<UUID, BSONObj, BSONObj>> expectedSampledCmds;
+ for (auto i = 0; i < numQueries; i++) {
+ auto sampleId = UUID::gen();
+ auto filter = makeNonEmptyFilter();
+ auto collation = makeNonEmptyCollation();
+ writer.addAggregateQuery(sampleId, nss0, filter, collation).get();
+ expectedSampledCmds.push_back({sampleId, filter, collation});
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries);
+ for (const auto& [sampleId, filter, collation] : expectedSampledCmds) {
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, QueriesMultipleBatches_MaxBSONObjSize) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto numQueries = 3;
+ std::vector<std::tuple<UUID, BSONObj, BSONObj>> expectedSampledCmds;
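+ // Each filter is roughly half of the 16MB user BSON size limit, so the sampled query documents
+ // cannot all fit in a single insert batch.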
+ for (auto i = 0; i < numQueries; i++) {
+ auto sampleId = UUID::gen();
+ auto filter = BSON(std::string(BSONObjMaxUserSize / 2, 'a') << 1);
+ auto collation = makeNonEmptyCollation();
+ writer.addAggregateQuery(sampleId, nss0, filter, collation).get();
+ expectedSampledCmds.push_back({sampleId, filter, collation});
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries);
+ for (const auto& [sampleId, filter, collation] : expectedSampledCmds) {
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddReadIfExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto maxMemoryUsageBytes = 1024;
+ RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+ maxMemoryUsageBytes};
+
+ auto sampleId0 = UUID::gen();
+ auto filter0 = BSON(std::string(maxMemoryUsageBytes / 2, 'a') << 1);
+ auto collation0 = makeNonEmptyCollation();
+
+ auto sampleId1 = UUID::gen();
+ auto filter1 = BSON(std::string(maxMemoryUsageBytes / 2, 'b') << 1);
+ auto collation1 = makeNonEmptyCollation();
+
+ writer.addFindQuery(sampleId0, nss0, filter0, collation0).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ // Adding the next query causes the size to exceed the limit.
+ writer.addAggregateQuery(sampleId1, nss1, filter1, collation1).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(
+ sampleId0, nss0, SampledReadCommandNameEnum::kFind, filter0, collation0);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 1);
+ assertSampledReadQueryDocument(
+ sampleId1, nss1, SampledReadCommandNameEnum::kAggregate, filter1, collation1);
+}
+
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddUpdateIfExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto maxMemoryUsageBytes = 1024;
+ RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+ maxMemoryUsageBytes};
+ auto [originalCmd, expectedSampledCmds] =
+ makeUpdateCommandRequest(nss0,
+ 3,
+ {0, 2} /* markForSampling */,
+ std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */);
+ ASSERT_EQ(expectedSampledCmds.size(), 2U);
+
+ writer.addUpdateQuery(originalCmd, 0).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ // Adding the next query causes the size to exceed the limit.
+ writer.addUpdateQuery(originalCmd, 2).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+ for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) {
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kUpdate,
+ expectedSampledCmd);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddDeleteIfExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto maxMemoryUsageBytes = 1024;
+ RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+ maxMemoryUsageBytes};
+ auto [originalCmd, expectedSampledCmds] =
+ makeDeleteCommandRequest(nss0,
+ 3,
+ {0, 1} /* markForSampling */,
+ std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */);
+ ASSERT_EQ(expectedSampledCmds.size(), 2U);
+
+ writer.addDeleteQuery(originalCmd, 0).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ // Adding the next query causes the size to exceed the limit.
+ writer.addDeleteQuery(originalCmd, 1).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+ for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) {
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kDelete,
+ expectedSampledCmd);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddFindAndModifyIfExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto maxMemoryUsageBytes = 1024;
+ RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+ maxMemoryUsageBytes};
+
+ auto [originalCmd0, expectedSampledCmds0] = makeFindAndModifyCommandRequest(
+ nss0,
+ true /* isUpdate */,
+ true /* markForSampling */,
+ std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */);
+ ASSERT_EQ(expectedSampledCmds0.size(), 1U);
+ auto [sampleId0, expectedSampledCmd0] = *expectedSampledCmds0.begin();
+
+ auto [originalCmd1, expectedSampledCmds1] = makeFindAndModifyCommandRequest(
+ nss1,
+ false /* isUpdate */,
+ true /* markForSampling */,
+ std::string(maxMemoryUsageBytes / 2, 'b') /* filterFieldName */);
+ ASSERT_EQ(expectedSampledCmds1.size(), 1U);
+ auto [sampleId1, expectedSampledCmd1] = *expectedSampledCmds1.begin();
+
+ writer.addFindAndModifyQuery(originalCmd0).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ // Adding the next query causes the size to exceed the limit.
+ writer.addFindAndModifyQuery(originalCmd1).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledWriteQueryDocument(sampleId0,
+ expectedSampledCmd0.getNamespace(),
+ SampledWriteCommandNameEnum::kFindAndModify,
+ expectedSampledCmd0);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 1);
+ assertSampledWriteQueryDocument(sampleId1,
+ expectedSampledCmd1.getNamespace(),
+ SampledWriteCommandNameEnum::kFindAndModify,
+ expectedSampledCmd1);
+}
+
+TEST_F(QueryAnalysisWriterTest, AddQueriesBackAfterWriteError) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto originalFilter = makeNonEmptyFilter();
+ auto originalCollation = makeNonEmptyCollation();
+ auto numQueries = 8;
+
+ std::vector<UUID> sampleIds0;
+ for (auto i = 0; i < numQueries; i++) {
+ sampleIds0.push_back(UUID::gen());
+ writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get();
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries);
+
+ // Force the documents to get inserted in three batches of size 3, 3 and 2, respectively.
+ RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 3};
+
+ // Hang after inserting the documents in the first batch.
+ auto hangFp = globalFailPointRegistry().find("hangAfterCollectionInserts");
+ auto hangTimesEntered = hangFp->setMode(FailPoint::alwaysOn, 0);
+
+ auto future = stdx::async(stdx::launch::async, [&] {
+ ThreadClient tc(getServiceContext());
+ auto opCtx = makeOperationContext();
+ writer.flushQueriesForTest(opCtx.get());
+ });
+
+ hangFp->waitForTimesEntered(hangTimesEntered + 1);
+ // Force the second batch to fail so that it falls back to inserting one document at a time in
+ // order, and then force the first and second documents in that batch to fail.
+ auto failFp = globalFailPointRegistry().find("failCollectionInserts");
+ failFp->setMode(FailPoint::nTimes, 3);
+ hangFp->setMode(FailPoint::off, 0);
+
+ future.get();
+ // Verify that all the documents other than the ones in the first batch got added back to the
+ // buffer after the error. The last document in the second batch also got added back even though
+ // it was successfully inserted, because the writer has no way to tell whether the error caused
+ // the entire command to fail early.
+ ASSERT_EQ(writer.getQueriesCountForTest(), 5);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 4);
+
+ // Flush the remaining documents. If the documents were not added back correctly, some
+ // documents would be missing and the checks below would fail.
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries);
+ for (const auto& sampleId : sampleIds0) {
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, RemoveDuplicatesFromBufferAfterWriteError) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto originalFilter = makeNonEmptyFilter();
+ auto originalCollation = makeNonEmptyCollation();
+
+ auto numQueries0 = 3;
+
+ std::vector<UUID> sampleIds0;
+ for (auto i = 0; i < numQueries0; i++) {
+ sampleIds0.push_back(UUID::gen());
+ writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get();
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries0);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries0);
+ for (const auto& sampleId : sampleIds0) {
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation);
+ }
+
+ auto numQueries1 = 5;
+
+ std::vector<UUID> sampleIds1;
+ for (auto i = 0; i < numQueries1; i++) {
+ sampleIds1.push_back(UUID::gen());
+ writer.addFindQuery(sampleIds1[i], nss1, originalFilter, originalCollation).get();
+ // Re-add one of the nss0 queries already flushed above, i.e. a duplicate.
+ if (i < numQueries0) {
+ writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get();
+ }
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries0 + numQueries1);
+
+ // Force the batch to fail so that it falls back to inserting one document at a time in order.
+ auto failFp = globalFailPointRegistry().find("failCollectionInserts");
+ failFp->setMode(FailPoint::nTimes, 1);
+
+ // Hang after inserting the first non-duplicate document.
+ auto hangFp = globalFailPointRegistry().find("hangAfterCollectionInserts");
+ auto hangTimesEntered = hangFp->setMode(FailPoint::alwaysOn, 0);
+
+ auto future = stdx::async(stdx::launch::async, [&] {
+ ThreadClient tc(getServiceContext());
+ auto opCtx = makeOperationContext();
+ writer.flushQueriesForTest(opCtx.get());
+ });
+
+ hangFp->waitForTimesEntered(hangTimesEntered + 1);
+ // Force the next non-duplicate document to fail to insert.
+ failFp->setMode(FailPoint::nTimes, 1);
+ hangFp->setMode(FailPoint::off, 0);
+
+ future.get();
+ // Verify that the duplicate documents did not get added back to the buffer after the error.
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries1);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), numQueries1 - 1);
+
+ // Flush the remaining documents. If the documents were not added back correctly, the document
+ // that previously failed to insert would be missing and the checks below would fail.
+ failFp->setMode(FailPoint::off, 0);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), numQueries1);
+ for (const auto& sampleId : sampleIds1) {
+ assertSampledReadQueryDocument(
+ sampleId, nss1, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, NoDiffs) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+ writer.flushDiffsForTest(operationContext());
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffsBasic) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto collUuid0 = getCollectionUUID(nss0);
+ auto sampleId = UUID::gen();
+ auto preImage = BSON("a" << 0);
+ auto postImage = BSON("a" << 1);
+
+ writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 1);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+ assertDiffDocument(sampleId, nss0, *doc_diff::computeInlineDiff(preImage, postImage));
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffsMultipleQueriesAndCollections) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ // Make nss0 have a diff for one query.
+ auto collUuid0 = getCollectionUUID(nss0);
+
+ auto sampleId0 = UUID::gen();
+ auto preImage0 = BSON("a" << 0 << "b" << 0 << "c" << 0);
+ auto postImage0 = BSON("a" << 0 << "b" << 1 << "d" << 1);
+
+ // Make nss1 have diffs for two queries.
+ auto collUuid1 = getCollectionUUID(nss1);
+
+ auto sampleId1 = UUID::gen();
+ auto preImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1) << "d" << BSON("e" << 1));
+ auto postImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1 << 2) << "d" << BSON("e" << 2));
+
+ auto sampleId2 = UUID::gen();
+ auto preImage2 = BSON("a" << BSONObj());
+ auto postImage2 = BSON("a" << BSON("b" << 2));
+
+ writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get();
+ writer.addDiff(sampleId1, nss1, collUuid1, preImage1, postImage1).get();
+ writer.addDiff(sampleId2, nss1, collUuid1, preImage2, postImage2).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 3);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+ assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0));
+
+ ASSERT_EQ(getDiffDocumentsCount(nss1), 2);
+ assertDiffDocument(sampleId1, nss1, *doc_diff::computeInlineDiff(preImage1, postImage1));
+ assertDiffDocument(sampleId2, nss1, *doc_diff::computeInlineDiff(preImage2, postImage2));
+}
+
+TEST_F(QueryAnalysisWriterTest, DuplicateDiffs) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto collUuid0 = getCollectionUUID(nss0);
+
+ auto sampleId0 = UUID::gen();
+ auto preImage0 = BSON("a" << 0);
+ auto postImage0 = BSON("a" << 1);
+
+ auto sampleId1 = UUID::gen();
+ auto preImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1));
+ auto postImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1 << 2));
+
+ auto sampleId2 = UUID::gen();
+ auto preImage2 = BSON("a" << BSONObj());
+ auto postImage2 = BSON("a" << BSON("b" << 2));
+
+ writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 1);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+ assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0));
+
+ writer.addDiff(sampleId1, nss0, collUuid0, preImage1, postImage1).get();
+ writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0)
+ .get(); // This is a duplicate.
+ writer.addDiff(sampleId2, nss0, collUuid0, preImage2, postImage2).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 3);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 3);
+ assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0));
+ assertDiffDocument(sampleId1, nss0, *doc_diff::computeInlineDiff(preImage1, postImage1));
+ assertDiffDocument(sampleId2, nss0, *doc_diff::computeInlineDiff(preImage2, postImage2));
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffsMultipleBatches_MaxBatchSize) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 2};
+ auto numDiffs = 5;
+ auto collUuid0 = getCollectionUUID(nss0);
+
+ std::vector<std::pair<UUID, BSONObj>> expectedSampledDiffs;
+ for (auto i = 0; i < numDiffs; i++) {
+ auto sampleId = UUID::gen();
+ auto preImage = BSON("a" << 0);
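+ // Vary the post-image field name so each pre/post pair produces a distinct, non-empty diff.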
+ auto postImage = BSON(("a" + std::to_string(i)) << 1);
+ writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get();
+ expectedSampledDiffs.push_back(
+ {sampleId, *doc_diff::computeInlineDiff(preImage, postImage)});
+ }
+ ASSERT_EQ(writer.getDiffsCountForTest(), numDiffs);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), numDiffs);
+ for (const auto& [sampleId, diff] : expectedSampledDiffs) {
+ assertDiffDocument(sampleId, nss0, diff);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffsMultipleBatches_MaxBSONObjSize) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto numDiffs = 3;
+ auto collUuid0 = getCollectionUUID(nss0);
+
+ std::vector<std::pair<UUID, BSONObj>> expectedSampledDiffs;
+ for (auto i = 0; i < numDiffs; i++) {
+ auto sampleId = UUID::gen();
+ auto preImage = BSON("a" << 0);
+ auto postImage = BSON(std::string(BSONObjMaxUserSize / 2, 'a') << 1);
+ writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get();
+ expectedSampledDiffs.push_back(
+ {sampleId, *doc_diff::computeInlineDiff(preImage, postImage)});
+ }
+ ASSERT_EQ(writer.getDiffsCountForTest(), numDiffs);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), numDiffs);
+ for (const auto& [sampleId, diff] : expectedSampledDiffs) {
+ assertDiffDocument(sampleId, nss0, diff);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddDiffIfExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto maxMemoryUsageBytes = 1024;
+ RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+ maxMemoryUsageBytes};
+
+ auto collUuid0 = getCollectionUUID(nss0);
+ auto sampleId0 = UUID::gen();
+ auto preImage0 = BSON("a" << 0);
+ auto postImage0 = BSON(std::string(maxMemoryUsageBytes / 2, 'a') << 1);
+
+ auto collUuid1 = getCollectionUUID(nss1);
+ auto sampleId1 = UUID::gen();
+ auto preImage1 = BSON("a" << 0);
+ auto postImage1 = BSON(std::string(maxMemoryUsageBytes / 2, 'b') << 1);
+
+ writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 1);
+ // Adding the next diff causes the size to exceed the limit.
+ writer.addDiff(sampleId1, nss1, collUuid1, preImage1, postImage1).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+ assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0));
+ ASSERT_EQ(getDiffDocumentsCount(nss1), 1);
+ assertDiffDocument(sampleId1, nss1, *doc_diff::computeInlineDiff(preImage1, postImage1));
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffEmpty) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto collUuid0 = getCollectionUUID(nss0);
+ auto sampleId = UUID::gen();
+ auto preImage = BSON("a" << 1);
+ auto postImage = preImage;
+
+ writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 0);
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto collUuid0 = getCollectionUUID(nss0);
+ auto sampleId = UUID::gen();
+ auto preImage = BSON(std::string(BSONObjMaxUserSize, 'a') << 1);
+ auto postImage = BSONObj();
+
+ writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 0);
+}
+
+} // namespace
+} // namespace analyze_shard_key
+} // namespace mongo
diff --git a/src/mongo/db/s/range_deleter_service.cpp b/src/mongo/db/s/range_deleter_service.cpp
index acabcc40bea..ca80b24aa3e 100644
--- a/src/mongo/db/s/range_deleter_service.cpp
+++ b/src/mongo/db/s/range_deleter_service.cpp
@@ -86,6 +86,50 @@ RangeDeleterService* RangeDeleterService::get(OperationContext* opCtx) {
return get(opCtx->getServiceContext());
}
+RangeDeleterService::ReadyRangeDeletionsProcessor::ReadyRangeDeletionsProcessor(
+ OperationContext* opCtx)
+ : _thread([this] { _runRangeDeletions(); }) {}
+
+RangeDeleterService::ReadyRangeDeletionsProcessor::~ReadyRangeDeletionsProcessor() {
+ shutdown();
+ invariant(_thread.joinable());
+ _thread.join();
+ invariant(!_threadOpCtxHolder,
+ "Thread operation context is still alive after joining main thread");
+}
+
+void RangeDeleterService::ReadyRangeDeletionsProcessor::shutdown() {
+ stdx::lock_guard<Latch> lock(_mutex);
+ if (_state == kStopped)
+ return;
+
+ _state = kStopped;
+
+ if (_threadOpCtxHolder) {
+ stdx::lock_guard<Client> scopedClientLock(*_threadOpCtxHolder->getClient());
+ _threadOpCtxHolder->markKilled(ErrorCodes::Interrupted);
+ }
+}
+
+bool RangeDeleterService::ReadyRangeDeletionsProcessor::_stopRequested() const {
+ stdx::unique_lock<Latch> lock(_mutex);
+ return _state == kStopped;
+}
+
+void RangeDeleterService::ReadyRangeDeletionsProcessor::emplaceRangeDeletion(
+ const RangeDeletionTask& rdt) {
+ stdx::unique_lock<Latch> lock(_mutex);
+ invariant(_state == kRunning);
+ _queue.push(rdt);
+ _condVar.notify_all();
+}
+
+void RangeDeleterService::ReadyRangeDeletionsProcessor::_completedRangeDeletion() {
+ stdx::unique_lock<Latch> lock(_mutex);
+ dassert(!_queue.empty());
+ _queue.pop();
+}
+
void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() {
Client::initThread(kRangeDeletionThreadName);
{
@@ -93,21 +137,22 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() {
cc().setSystemOperationKillableByStepdown(lk);
}
- auto opCtx = [&] {
- stdx::unique_lock<Latch> lock(_mutex);
+ {
+ stdx::lock_guard<Latch> lock(_mutex);
+ if (_state != kRunning) {
+ return;
+ }
_threadOpCtxHolder = cc().makeOperationContext();
- _condVar.notify_all();
- return (*_threadOpCtxHolder).get();
- }();
+ }
- opCtx->setAlwaysInterruptAtStepDownOrUp_UNSAFE();
+ auto opCtx = _threadOpCtxHolder.get();
ON_BLOCK_EXIT([this]() {
- stdx::unique_lock<Latch> lock(_mutex);
+ stdx::lock_guard<Latch> lock(_mutex);
_threadOpCtxHolder.reset();
});
- while (opCtx->checkForInterruptNoAssert().isOK()) {
+ while (!_stopRequested()) {
{
stdx::unique_lock<Latch> lock(_mutex);
try {
@@ -223,7 +268,7 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() {
// Release the thread only in case the operation context has been interrupted, as
// interruption only happens on shutdown/stepdown (this is fine because range
// deletions will be resumed on the next step up)
- if (!opCtx->checkForInterruptNoAssert().isOK()) {
+ if (_stopRequested()) {
break;
}
@@ -237,6 +282,17 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() {
}
}
+void RangeDeleterService::onStartup(OperationContext* opCtx) {
+ if (disableResumableRangeDeleter.load() ||
+ !feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) {
+ return;
+ }
+
+ auto opObserverRegistry =
+ checked_cast<OpObserverRegistry*>(opCtx->getServiceContext()->getOpObserver());
+ opObserverRegistry->addObserver(std::make_unique<RangeDeleterServiceOpObserver>());
+}
+
void RangeDeleterService::onStepUpComplete(OperationContext* opCtx, long long term) {
if (!feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) {
return;
@@ -249,22 +305,14 @@ void RangeDeleterService::onStepUpComplete(OperationContext* opCtx, long long te
return;
}
+ // Wait until all tasks and threads from the previous term have drained
+ _joinAndResetState();
+
auto lock = _acquireMutexUnconditionally();
dassert(_state == kDown, "Service expected to be down before stepping up");
_state = kInitializing;
- if (_executor) {
- // Join previously shutted down executor before reinstantiating it
- _executor->join();
- _executor.reset();
- } else {
- // Initializing the op observer, only executed once at the first step-up
- auto opObserverRegistry =
- checked_cast<OpObserverRegistry*>(opCtx->getServiceContext()->getOpObserver());
- opObserverRegistry->addObserver(std::make_unique<RangeDeleterServiceOpObserver>());
- }
-
const std::string kExecName("RangeDeleterServiceExecutor");
auto net = executor::makeNetworkInterface(kExecName);
auto pool = std::make_unique<executor::NetworkInterfaceThreadPool>(net.get());
@@ -303,9 +351,14 @@ void RangeDeleterService::_recoverRangeDeletionsOnStepUp(OperationContext* opCtx
LOGV2(6834800, "Resubmitting range deletion tasks");
- ScopedRangeDeleterLock rangeDeleterLock(opCtx);
- DBDirectClient client(opCtx);
+ // The Scoped lock is needed to serialize with concurrent range deletions
+ ScopedRangeDeleterLock rangeDeleterLock(opCtx, MODE_S);
+ // The collection lock is needed to serialize with donors trying to
+ // schedule local range deletions by updating the 'pending' field
+ AutoGetCollection rangeDeletionLock(
+ opCtx, NamespaceString::kRangeDeletionNamespace, MODE_S);
+ DBDirectClient client(opCtx);
int nRescheduledTasks = 0;
// (1) register range deletion tasks marked as "processing"
@@ -370,47 +423,55 @@ void RangeDeleterService::_recoverRangeDeletionsOnStepUp(OperationContext* opCtx
.semi();
}
-void RangeDeleterService::_stopService(bool joinExecutor) {
- if (!feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) {
- return;
- }
+void RangeDeleterService::_joinAndResetState() {
+ invariant(_state == kDown);
+ // Join the thread spawned on step-up to resume range deletions
+ _stepUpCompletedFuture.getNoThrow().ignore();
- {
- auto lock = _acquireMutexUnconditionally();
- _state = kDown;
- if (_initOpCtxHolder) {
- stdx::lock_guard<Client> lk(*_initOpCtxHolder->getClient());
- _initOpCtxHolder->markKilled(ErrorCodes::Interrupted);
- }
+ // Join and destruct the executor
+ if (_executor) {
+ _executor->join();
+ _executor.reset();
}
- // Join the thread spawned on step-up to resume range deletions
- _stepUpCompletedFuture.getNoThrow().ignore();
+ // Join and destruct the processor
+ _readyRangeDeletionsProcessorPtr.reset();
+
+ // Clear range deletions potentially created during recovery
+ _rangeDeletionTasks.clear();
+}
+void RangeDeleterService::_stopService() {
auto lock = _acquireMutexUnconditionally();
+ if (_state == kDown)
+ return;
+
+ _state = kDown;
+ if (_initOpCtxHolder) {
+ stdx::lock_guard<Client> lk(*_initOpCtxHolder->getClient());
+ _initOpCtxHolder->markKilled(ErrorCodes::Interrupted);
+ }
- // It may happen for the `onStepDown` hook to be invoked on a SECONDARY node transitioning
- // to ROLLBACK, hence the executor may have never been initialized
if (_executor) {
_executor->shutdown();
- if (joinExecutor) {
- _executor->join();
- }
}
- // Destroy the range deletion processor in order to stop range deletions
- _readyRangeDeletionsProcessorPtr.reset();
+ // Shut down the range deletion processor to interrupt range deletions
+ if (_readyRangeDeletionsProcessorPtr) {
+ _readyRangeDeletionsProcessorPtr->shutdown();
+ }
// Clear range deletion tasks map in order to notify potential waiters on completion futures
_rangeDeletionTasks.clear();
}
void RangeDeleterService::onStepDown() {
- _stopService(false /* joinExecutor */);
+ _stopService();
}
void RangeDeleterService::onShutdown() {
- _stopService(true /* joinExecutor */);
+ _stopService();
+ _joinAndResetState();
}
BSONObj RangeDeleterService::dumpState() {
@@ -472,10 +533,9 @@ SharedSemiFuture<void> RangeDeleterService::registerTask(
.then([this, rdt = rdt]() {
// Step 3: schedule the actual range deletion task
auto lock = _acquireMutexUnconditionally();
- invariant(
- _readyRangeDeletionsProcessorPtr || _state == kDown,
- "The range deletions processor must be instantiated if the state != kDown");
if (_state != kDown) {
+ invariant(_readyRangeDeletionsProcessorPtr,
+ "The range deletions processor is not initialized");
_readyRangeDeletionsProcessorPtr->emplaceRangeDeletion(rdt);
}
});
diff --git a/src/mongo/db/s/range_deleter_service.h b/src/mongo/db/s/range_deleter_service.h
index 68cca97454c..c816c8b9db2 100644
--- a/src/mongo/db/s/range_deleter_service.h
+++ b/src/mongo/db/s/range_deleter_service.h
@@ -110,58 +110,37 @@ private:
*/
class ReadyRangeDeletionsProcessor {
public:
- ReadyRangeDeletionsProcessor(OperationContext* opCtx) {
- _thread = stdx::thread([this] { _runRangeDeletions(); });
- stdx::unique_lock<Latch> lock(_mutex);
- opCtx->waitForConditionOrInterrupt(
- _condVar, lock, [&] { return _threadOpCtxHolder.is_initialized(); });
- }
-
- ~ReadyRangeDeletionsProcessor() {
- {
- stdx::unique_lock<Latch> lock(_mutex);
- // The `_threadOpCtxHolder` may have been already reset/interrupted in case the
- // thread got interrupted due to stepdown
- if (_threadOpCtxHolder) {
- stdx::lock_guard<Client> scopedClientLock(*(*_threadOpCtxHolder)->getClient());
- if ((*_threadOpCtxHolder)->checkForInterruptNoAssert().isOK()) {
- (*_threadOpCtxHolder)->markKilled(ErrorCodes::Interrupted);
- }
- }
- _condVar.notify_all();
- }
+ ReadyRangeDeletionsProcessor(OperationContext* opCtx);
+ ~ReadyRangeDeletionsProcessor();
- if (_thread.joinable()) {
- _thread.join();
- }
- }
+ /*
+ * Interrupt ongoing range deletions
+ */
+ void shutdown();
/*
* Schedule a range deletion at the end of the queue
*/
- void emplaceRangeDeletion(const RangeDeletionTask& rdt) {
- stdx::unique_lock<Latch> lock(_mutex);
- _queue.push(rdt);
- _condVar.notify_all();
- }
+ void emplaceRangeDeletion(const RangeDeletionTask& rdt);
private:
/*
+ * Return true if this processor has been shut down
+ */
+ bool _stopRequested() const;
+
+ /*
* Remove a range deletion from the head of the queue. Supposed to be called only once a
* range deletion successfully finishes.
*/
- void _completedRangeDeletion() {
- stdx::unique_lock<Latch> lock(_mutex);
- dassert(!_queue.empty());
- _queue.pop();
- }
+ void _completedRangeDeletion();
/*
* Code executed by the internal thread
*/
void _runRangeDeletions();
- Mutex _mutex = MONGO_MAKE_LATCH("ReadyRangeDeletionsProcessor");
+ mutable Mutex _mutex = MONGO_MAKE_LATCH("ReadyRangeDeletionsProcessor");
/*
* Condition variable notified when:
@@ -175,10 +154,13 @@ private:
std::queue<RangeDeletionTask> _queue;
/* Pointer to the (one and only) operation context used by the thread */
- boost::optional<ServiceContext::UniqueOperationContext> _threadOpCtxHolder;
+ ServiceContext::UniqueOperationContext _threadOpCtxHolder;
/* Thread consuming the range deletions queue */
stdx::thread _thread;
+
+ enum State { kRunning, kStopped };
+ State _state{kRunning};
};
// Keeping track of per-collection registered range deletion tasks
@@ -253,6 +235,7 @@ public:
const ChunkRange& range);
/* ReplicaSetAwareServiceShardSvr implemented methods */
+ void onStartup(OperationContext* opCtx) override;
void onStepUpComplete(OperationContext* opCtx, long long term) override;
void onStepDown() override;
void onShutdown() override;
@@ -276,17 +259,21 @@ public:
std::unique_ptr<ReadyRangeDeletionsProcessor> _readyRangeDeletionsProcessorPtr;
private:
+ /* Join all threads and the executor and reset the in-memory state of the service.
+ * Used on step-up completion and on shutdown.
+ */
+ void _joinAndResetState();
+
/* Asynchronously register range deletions on the service. To be called on on step-up */
void _recoverRangeDeletionsOnStepUp(OperationContext* opCtx);
- /* Called by shutdown/stepdown hooks to reset the service */
- void _stopService(bool joinExecutor);
+ /* Called by shutdown/stepdown hooks to interrupt the service */
+ void _stopService();
/* ReplicaSetAwareServiceShardSvr "empty implemented" methods */
- void onStartup(OperationContext* opCtx) override final{};
void onInitialDataAvailable(OperationContext* opCtx,
bool isMajorityDataAvailable) override final {}
- void onStepUpBegin(OperationContext* opCtx, long long term) override final {}
+ void onStepUpBegin(OperationContext* opCtx, long long term) override final{};
void onBecomeArbiter() override final {}
};
diff --git a/src/mongo/db/s/range_deleter_service_op_observer.cpp b/src/mongo/db/s/range_deleter_service_op_observer.cpp
index 9c1e51f9a6f..10bbbb6a924 100644
--- a/src/mongo/db/s/range_deleter_service_op_observer.cpp
+++ b/src/mongo/db/s/range_deleter_service_op_observer.cpp
@@ -35,6 +35,9 @@
#include "mongo/db/s/range_deleter_service.h"
#include "mongo/db/s/range_deletion_task_gen.h"
#include "mongo/db/update/update_oplog_entry_serialization.h"
+#include "mongo/logv2/log.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kShardingRangeDeleter
namespace mongo {
namespace {
@@ -54,10 +57,15 @@ void registerTaskWithOngoingQueriesOnOpLogEntryCommit(OperationContext* opCtx,
(void)RangeDeleterService::get(opCtx)->registerTask(
rdt, std::move(waitForActiveQueriesToComplete));
} catch (const DBException& ex) {
- dassert(ex.code() == ErrorCodes::NotYetInitialized,
- str::stream() << "No error different from `NotYetInitialized` is expected "
- "to be propagated to the range deleter observer. Got error: "
- << ex.toStatus());
+ if (ex.code() != ErrorCodes::NotYetInitialized &&
+ !ErrorCodes::isA<ErrorCategory::NotPrimaryError>(ex.code())) {
+ LOGV2_WARNING(7092800,
+ "No error different from `NotYetInitialized` or `NotPrimaryError` "
+ "category is expected to be propagated to the range deleter "
+ "observer. Range deletion task not registered.",
+ "error"_attr = redact(ex),
+ "task"_attr = rdt);
+ }
}
});
}
diff --git a/src/mongo/db/s/range_deleter_service_test.cpp b/src/mongo/db/s/range_deleter_service_test.cpp
index d2406120483..0807867da2d 100644
--- a/src/mongo/db/s/range_deleter_service_test.cpp
+++ b/src/mongo/db/s/range_deleter_service_test.cpp
@@ -45,6 +45,7 @@ void RangeDeleterServiceTest::setUp() {
ShardServerTestFixture::setUp();
WaitForMajorityService::get(getServiceContext()).startup(getServiceContext());
opCtx = operationContext();
+ RangeDeleterService::get(opCtx)->onStartup(opCtx);
RangeDeleterService::get(opCtx)->onStepUpComplete(opCtx, 0L);
RangeDeleterService::get(opCtx)->_waitForRangeDeleterServiceUp_FOR_TESTING();
@@ -275,12 +276,6 @@ TEST_F(RangeDeleterServiceTest, ScheduledTaskInvalidatedOnStepDown) {
// Manually trigger disabling of the service
rds->onStepDown();
- ON_BLOCK_EXIT([&] {
- // Re-enable the service for clean teardown
- rds->onStepUpComplete(opCtx, 0L);
- rds->_waitForRangeDeleterServiceUp_FOR_TESTING();
- });
-
try {
completionFuture.get(opCtx);
} catch (const ExceptionForCat<ErrorCategory::Interruption>&) {
@@ -293,12 +288,6 @@ TEST_F(RangeDeleterServiceTest, NoActionPossibleIfServiceIsDown) {
// Manually trigger disabling of the service
rds->onStepDown();
- ON_BLOCK_EXIT([&] {
- // Re-enable the service for clean teardown
- rds->onStepUpComplete(opCtx, 0L);
- rds->_waitForRangeDeleterServiceUp_FOR_TESTING();
- });
-
auto taskWithOngoingQueries = createRangeDeletionTaskWithOngoingQueries(
uuidCollA, BSON(kShardKey << 0), BSON(kShardKey << 10), CleanWhenEnum::kDelayed);
@@ -886,10 +875,6 @@ TEST_F(RangeDeleterServiceTest, WaitForOngoingQueriesInvalidatedOnStepDown) {
// Manually trigger disabling of the service
rds->onStepDown();
- ON_BLOCK_EXIT([&] {
- rds->onStepUpComplete(opCtx, 0L); // Re-enable the service
- });
-
try {
completionFuture.get(opCtx);
} catch (const ExceptionForCat<ErrorCategory::Interruption>&) {
diff --git a/src/mongo/db/s/range_deletion_util.cpp b/src/mongo/db/s/range_deletion_util.cpp
index 3a8b512ddba..707516db64f 100644
--- a/src/mongo/db/s/range_deletion_util.cpp
+++ b/src/mongo/db/s/range_deletion_util.cpp
@@ -322,7 +322,7 @@ Status deleteRangeInBatches(OperationContext* opCtx,
const ChunkRange& range) {
suspendRangeDeletion.pauseWhileSet(opCtx);
- SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow);
+ SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow);
bool allDocsRemoved = false;
// Delete all batches in this range unless a stepdown error occurs. Do not yield the
@@ -584,7 +584,7 @@ void persistUpdatedNumOrphans(OperationContext* opCtx,
const auto query = getQueryFilterForRangeDeletionTask(collectionUuid, range);
try {
PersistentTaskStore<RangeDeletionTask> store(NamespaceString::kRangeDeletionNamespace);
- ScopedRangeDeleterLock rangeDeleterLock(opCtx, collectionUuid);
+ ScopedRangeDeleterLock rangeDeleterLock(opCtx, LockMode::MODE_IX);
// The DBDirectClient will not retry WriteConflictExceptions internally while holding an X
// mode lock, so we need to retry at this level.
writeConflictRetry(
diff --git a/src/mongo/db/s/shard_server_op_observer.cpp b/src/mongo/db/s/shard_server_op_observer.cpp
index 744c79a1eca..ae059d19512 100644
--- a/src/mongo/db/s/shard_server_op_observer.cpp
+++ b/src/mongo/db/s/shard_server_op_observer.cpp
@@ -509,20 +509,24 @@ void ShardServerOpObserver::onModifyShardedCollectionGlobalIndexCatalogEntry(
IDLParserContext("onModifyShardedCollectionGlobalIndexCatalogEntry"),
indexDoc["entry"].Obj());
auto indexVersion = indexDoc["entry"][IndexCatalogType::kLastmodFieldName].timestamp();
- opCtx->recoveryUnit()->onCommit([opCtx, nss, indexVersion, indexEntry](auto _) {
+ auto uuid = uassertStatusOK(
+ UUID::parse(indexDoc["entry"][IndexCatalogType::kCollectionUUIDFieldName]));
+ opCtx->recoveryUnit()->onCommit([opCtx, nss, indexVersion, indexEntry, uuid](auto _) {
AutoGetCollection autoColl(opCtx, nss, MODE_IX);
CollectionShardingRuntime::assertCollectionLockedAndAcquire(
opCtx, nss, CSRAcquisitionMode::kExclusive)
- ->addIndex(opCtx, indexEntry, indexVersion);
+ ->addIndex(opCtx, indexEntry, {uuid, indexVersion});
});
} else {
auto indexName = indexDoc["entry"][IndexCatalogType::kNameFieldName].str();
auto indexVersion = indexDoc["entry"][IndexCatalogType::kLastmodFieldName].timestamp();
- opCtx->recoveryUnit()->onCommit([opCtx, nss, indexName, indexVersion](auto _) {
+ auto uuid = uassertStatusOK(
+ UUID::parse(indexDoc["entry"][IndexCatalogType::kCollectionUUIDFieldName]));
+ opCtx->recoveryUnit()->onCommit([opCtx, nss, indexName, indexVersion, uuid](auto _) {
AutoGetCollection autoColl(opCtx, nss, MODE_IX);
CollectionShardingRuntime::assertCollectionLockedAndAcquire(
opCtx, nss, CSRAcquisitionMode::kExclusive)
- ->removeIndex(opCtx, indexName, indexVersion);
+ ->removeIndex(opCtx, indexName, {uuid, indexVersion});
});
}
}
diff --git a/src/mongo/db/s/sharding_recovery_service.cpp b/src/mongo/db/s/sharding_recovery_service.cpp
index d47e4056534..02dab8f3773 100644
--- a/src/mongo/db/s/sharding_recovery_service.cpp
+++ b/src/mongo/db/s/sharding_recovery_service.cpp
@@ -532,7 +532,7 @@ void ShardingRecoveryService::recoverIndexesCatalog(OperationContext* opCtx) {
AutoGetCollection collLock(opCtx, nss, MODE_X);
CollectionShardingRuntime::assertCollectionLockedAndAcquire(
opCtx, collLock->ns(), CSRAcquisitionMode::kExclusive)
- ->addIndex(opCtx, indexEntry, indexVersion);
+ ->addIndex(opCtx, indexEntry, {indexEntry.getCollectionUUID(), indexVersion});
}
}
LOGV2_DEBUG(6686502, 2, "Recovered all index versions");
diff --git a/src/mongo/db/s/sharding_write_router_bm.cpp b/src/mongo/db/s/sharding_write_router_bm.cpp
index e92d8292052..89fa3ded04d 100644
--- a/src/mongo/db/s/sharding_write_router_bm.cpp
+++ b/src/mongo/db/s/sharding_write_router_bm.cpp
@@ -38,6 +38,7 @@
#include "mongo/db/s/collection_metadata.h"
#include "mongo/db/s/collection_sharding_runtime.h"
#include "mongo/db/s/collection_sharding_state_factory_shard.h"
+#include "mongo/db/s/collection_sharding_state_factory_standalone.h"
#include "mongo/db/s/operation_sharding_state.h"
#include "mongo/db/s/sharding_state.h"
#include "mongo/db/s/sharding_write_router.h"
@@ -214,6 +215,10 @@ void BM_UnshardedDestinedRecipient(benchmark::State& state) {
serviceContext->registerClientObserver(std::make_unique<LockerNoopClientObserver>());
const auto opCtx = client->makeOperationContext();
+ CollectionShardingStateFactory::set(
+ opCtx->getServiceContext(),
+ std::make_unique<CollectionShardingStateFactoryStandalone>(opCtx->getServiceContext()));
+
const auto catalogCache = CatalogCacheMock::make();
for (auto keepRunning : state) {
diff --git a/src/mongo/db/s/transaction_coordinator_util.cpp b/src/mongo/db/s/transaction_coordinator_util.cpp
index 736e5fb1cf4..9bd3114a84f 100644
--- a/src/mongo/db/s/transaction_coordinator_util.cpp
+++ b/src/mongo/db/s/transaction_coordinator_util.cpp
@@ -445,7 +445,7 @@ Future<repl::OpTime> persistDecision(txn::AsyncWorkScheduler& scheduler,
// Do not acquire a storage ticket in order to avoid unnecessary serialization
// with other prepared transactions that are holding a storage ticket
// themselves; see SERVER-60682.
- SetTicketAquisitionPriorityForLock setTicketAquisition(
+ SetAdmissionPriorityForLock setTicketAquisition(
opCtx, AdmissionContext::Priority::kImmediate);
getTransactionCoordinatorWorkerCurOpRepository()->set(
opCtx, lsid, txnNumberAndRetryCounter, CoordinatorAction::kWritingDecision);
diff --git a/src/mongo/db/server_feature_flags.idl b/src/mongo/db/server_feature_flags.idl
index b391bb5a749..c02079f3f37 100644
--- a/src/mongo/db/server_feature_flags.idl
+++ b/src/mongo/db/server_feature_flags.idl
@@ -62,4 +62,7 @@ feature_flags:
cpp_varname: gFeatureFlagUseNewCompactStructuredEncryptionDataCoordinator
default: true
version: 6.1
-
+ featureFlagOIDC:
+ description: "Feature flag for OIDC support"
+ cpp_varname: gFeatureFlagOIDC
+ default: false
diff --git a/src/mongo/db/service_context.cpp b/src/mongo/db/service_context.cpp
index 013e3209f95..61e20dc82e2 100644
--- a/src/mongo/db/service_context.cpp
+++ b/src/mongo/db/service_context.cpp
@@ -100,7 +100,7 @@ void setGlobalServiceContext(ServiceContext::UniqueServiceContext&& serviceConte
ServiceContext::ServiceContext()
: _opIdRegistry(UniqueOperationIdRegistry::create()),
- _tickSource(std::make_unique<SystemTickSource>()),
+ _tickSource(makeSystemTickSource()),
_fastClockSource(std::make_unique<SystemClockSource>()),
_preciseClockSource(std::make_unique<SystemClockSource>()) {}
diff --git a/src/mongo/db/service_entry_point_common.cpp b/src/mongo/db/service_entry_point_common.cpp
index 0c4766739e4..84f46d900dd 100644
--- a/src/mongo/db/service_entry_point_common.cpp
+++ b/src/mongo/db/service_entry_point_common.cpp
@@ -2105,7 +2105,20 @@ DbResponse makeCommandResponse(std::shared_ptr<HandleRequest::ExecutionContext>
}
}
- dbResponse.response = replyBuilder->done();
+ try {
+ dbResponse.response = replyBuilder->done();
+ } catch (const ExceptionFor<ErrorCodes::BSONObjectTooLarge>& ex) {
+ // Create a new reply builder as subsequently calling any methods on a builder after
+ // 'done()' results in undefined behavior.
+ auto errorReplyBuilder = execContext->getReplyBuilder();
+ BSONObjBuilder metadataBob;
+ BSONObjBuilder extraFieldsBuilder;
+ appendClusterAndOperationTime(
+ opCtx, &extraFieldsBuilder, &metadataBob, LogicalTime::kUninitialized);
+ generateErrorResponse(
+ opCtx, errorReplyBuilder, ex.toStatus(), metadataBob.obj(), extraFieldsBuilder.obj());
+ dbResponse.response = errorReplyBuilder->done();
+ }
CurOp::get(opCtx)->debug().responseLength = dbResponse.response.header().dataLen();
return dbResponse;
@@ -2347,10 +2360,26 @@ BSONObj ServiceEntryPointCommon::getRedactedCopyForLogging(const Command* comman
return bob.obj();
}
-void onHandleRequestException(const Status& status) {
+void logHandleRequestFailure(const Status& status) {
LOGV2_ERROR(4879802, "Failed to handle request", "error"_attr = redact(status));
}
+void onHandleRequestException(const HandleRequest& hr, const Status& status) {
+ auto isMirrorOp = [&] {
+ const auto& obj = hr.executionContext->getRequest().body;
+ if (auto e = obj.getField("mirrored"); MONGO_unlikely(e.ok() && e.boolean()))
+ return true;
+ return false;
+ };
+
+ // TODO SERVER-70510 revert changes introduced by SERVER-60553 that suppress errors occurring
+ // while handling mirrored operations on recovering secondaries.
+ if (MONGO_unlikely(status == ErrorCodes::NotWritablePrimary && isMirrorOp()))
+ return;
+
+ logHandleRequestFailure(status);
+}
+
Future<DbResponse> ServiceEntryPointCommon::handleRequest(
OperationContext* opCtx,
const Message& m,
@@ -2362,7 +2391,7 @@ Future<DbResponse> ServiceEntryPointCommon::handleRequest(
invariant(opRunner);
return opRunner->run()
- .then([hr = std::move(hr)](DbResponse response) mutable {
+ .then([&hr](DbResponse response) mutable {
hr.completeOperation(response);
auto opCtx = hr.executionContext->getOpCtx();
@@ -2377,10 +2406,10 @@ Future<DbResponse> ServiceEntryPointCommon::handleRequest(
return response;
})
- .tapError([](Status status) { onHandleRequestException(status); });
+ .tapError([hr = std::move(hr)](Status status) { onHandleRequestException(hr, status); });
} catch (const DBException& ex) {
auto status = ex.toStatus();
- onHandleRequestException(status);
+ logHandleRequestFailure(status);
return status;
}
diff --git a/src/mongo/db/stats/SConscript b/src/mongo/db/stats/SConscript
index 784019b4532..f7745f59c9a 100644
--- a/src/mongo/db/stats/SConscript
+++ b/src/mongo/db/stats/SConscript
@@ -128,6 +128,7 @@ env.Library(
'$BUILD_DIR/mongo/db/commands/server_status_core',
'$BUILD_DIR/mongo/db/index/index_access_method',
'$BUILD_DIR/mongo/db/pipeline/document_sources_idl',
+ '$BUILD_DIR/mongo/db/s/balancer_stats_registry',
'$BUILD_DIR/mongo/db/server_base',
'$BUILD_DIR/mongo/db/shard_role',
'$BUILD_DIR/mongo/db/timeseries/bucket_catalog',
diff --git a/src/mongo/db/storage/external_record_store.cpp b/src/mongo/db/storage/external_record_store.cpp
index e63286fbf6a..2a4bb9a8796 100644
--- a/src/mongo/db/storage/external_record_store.cpp
+++ b/src/mongo/db/storage/external_record_store.cpp
@@ -33,8 +33,10 @@
#include "mongo/db/storage/record_store.h"
namespace mongo {
+// 'ident' is an identifier for a WT table, and a virtual collection has no persistent data
+// in WT. So, we set a "dummy" ident for a virtual collection.
ExternalRecordStore::ExternalRecordStore(StringData ns, const VirtualCollectionOptions& vopts)
- : RecordStore(ns, /*identName=*/ns, /*isCapped=*/false), _vopts(vopts) {}
+ : RecordStore(ns, /*identName=*/"dummy"_sd, /*isCapped=*/false), _vopts(vopts) {}
/**
* Returns a MultiBsonStreamCursor for this record store. Reverse scans are not currently supported
@@ -42,7 +44,6 @@ ExternalRecordStore::ExternalRecordStore(StringData ns, const VirtualCollectionO
*/
std::unique_ptr<SeekableRecordCursor> ExternalRecordStore::getCursor(OperationContext* opCtx,
bool forward) const {
-
if (forward) {
return std::make_unique<MultiBsonStreamCursor>(getOptions());
}
diff --git a/src/mongo/db/storage/external_record_store.h b/src/mongo/db/storage/external_record_store.h
index 152d766e6cb..ac0c61392c6 100644
--- a/src/mongo/db/storage/external_record_store.h
+++ b/src/mongo/db/storage/external_record_store.h
@@ -33,16 +33,8 @@
#include "mongo/db/catalog/virtual_collection_options.h"
#include "mongo/db/storage/record_store.h"
#include "mongo/util/assert_util.h"
-#include "mongo/util/errno_util.h"
namespace mongo {
-namespace {
-inline std::string getErrorMessage(StringData op, const std::string& path) {
- using namespace fmt::literals;
- return "Failed to {} {}: {}"_format(op, path, errorMessage(lastSystemError()));
-}
-} // namespace
-
class ExternalRecordStore : public RecordStore {
public:
ExternalRecordStore(StringData ns, const VirtualCollectionOptions& vopts);
diff --git a/src/mongo/db/storage/external_record_store_test.cpp b/src/mongo/db/storage/external_record_store_test.cpp
index f4c42fb03ad..e5474f43133 100644
--- a/src/mongo/db/storage/external_record_store_test.cpp
+++ b/src/mongo/db/storage/external_record_store_test.cpp
@@ -32,11 +32,9 @@
#include <ctime>
#include <fmt/format.h>
-#include "mongo/db/storage/external_record_store.h"
#include "mongo/db/storage/input_stream.h"
#include "mongo/db/storage/multi_bson_stream_cursor.h"
#include "mongo/db/storage/named_pipe.h"
-#include "mongo/logv2/log.h"
#include "mongo/stdx/thread.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/assert_util.h"
@@ -45,21 +43,33 @@
namespace mongo {
using namespace fmt::literals;
-#ifndef _WIN32
-static const std::string pipePath1 = "/tmp/named_pipe1";
-static const std::string pipePath2 = "/tmp/named_pipe2";
-static const std::string nonExistingPath = "/tmp/non-existing";
-#else
-// "//./pipe" is the required path start of all named pipes on Windows, where "//." is the
-// abbreviation for the local server name and "/pipe" is a literal. (These also work with
-// Windows-native backslashes instead of forward slashes.)
-static const std::string pipePath1 = R"(//./pipe/named_pipe1)";
-static const std::string pipePath2 = R"(//./pipe/named_pipe2)";
-static const std::string nonExistingPath = R"(//./pipe/non-existing)";
-#endif
+static const std::string pipePath1 = "named_pipe1";
+static const std::string pipePath2 = "named_pipe2";
+static const std::string nonExistingPath = "non-existing";
static constexpr int kNumPipes = 2;
static const std::string pipePaths[kNumPipes] = {pipePath1, pipePath2};
+class PipeWaiter {
+public:
+ void notify() {
+ {
+ stdx::unique_lock lk(m);
+ pipeCreated = true;
+ }
+ cv.notify_one();
+ }
+
+ void wait() {
+ stdx::unique_lock lk(m);
+ cv.wait(lk, [&] { return pipeCreated; });
+ }
+
+private:
+ Mutex m;
+ stdx::condition_variable cv;
+ bool pipeCreated = false;
+};
+
class ExternalRecordStoreTest : public unittest::Test {
public:
// Gets a random string of 'count' length consisting of printable ASCII chars (32-126).
@@ -74,11 +84,11 @@ public:
return buf;
}
-
static constexpr int kBufferSize = 1024;
char _buffer[kBufferSize]; // buffer amply big enough to fit any BSONObj used in this test
- static void createNamedPipe(const std::string& pipePath,
+ static void createNamedPipe(PipeWaiter* pw,
+ const std::string& pipePath,
long numToWrite,
const std::vector<BSONObj>& bsonObjs);
@@ -88,13 +98,16 @@ public:
};
// Creates a named pipe of BSON objects.
+// pipeWaiter - synchronization for pipe creation
// pipePath - file path for the named pipe
// numToWrite - number of bsons to write to the pipe
// bsonObjs - vector of bsons to write round-robin to the pipe
-void ExternalRecordStoreTest::createNamedPipe(const std::string& pipePath,
+void ExternalRecordStoreTest::createNamedPipe(PipeWaiter* pw,
+ const std::string& pipePath,
long numToWrite,
const std::vector<BSONObj>& bsonObjs) {
- NamedPipeOutput pipeWriter(pipePath.c_str());
+ NamedPipeOutput pipeWriter(pipePath);
+ pw->notify();
pipeWriter.open();
const int numObjs = bsonObjs.size();
@@ -110,8 +123,10 @@ void ExternalRecordStoreTest::createNamedPipe(const std::string& pipePath,
TEST_F(ExternalRecordStoreTest, NamedPipeBasicRead) {
auto srcBsonObj = BSON("a" << 1);
auto count = srcBsonObj.objsize();
+ PipeWaiter pw;
stdx::thread producer([&] {
- NamedPipeOutput pipeWriter(pipePath1.c_str());
+ NamedPipeOutput pipeWriter(pipePath1);
+ pw.notify();
pipeWriter.open();
for (int i = 0; i < 100; ++i) {
@@ -123,11 +138,11 @@ TEST_F(ExternalRecordStoreTest, NamedPipeBasicRead) {
ON_BLOCK_EXIT([&] { producer.join(); });
// Gives some time to the producer so that it can initialize a named pipe.
- stdx::this_thread::sleep_for(stdx::chrono::seconds(1));
+ pw.wait();
- auto inputStream = std::make_unique<InputStream<NamedPipeInput>>(pipePath1.c_str());
+ auto inputStream = InputStream<NamedPipeInput>(pipePath1);
for (int i = 0; i < 100; ++i) {
- int nRead = inputStream->readBytes(count, _buffer);
+ int nRead = inputStream.readBytes(count, _buffer);
ASSERT_EQ(nRead, count) << "Failed to read data up to {} bytes"_format(count);
ASSERT_EQ(std::memcmp(srcBsonObj.objdata(), _buffer, count), 0)
<< "Read data is not same as the source data";
@@ -137,8 +152,10 @@ TEST_F(ExternalRecordStoreTest, NamedPipeBasicRead) {
TEST_F(ExternalRecordStoreTest, NamedPipeReadPartialData) {
auto srcBsonObj = BSON("a" << 1);
auto count = srcBsonObj.objsize();
+ PipeWaiter pw;
stdx::thread producer([&] {
- NamedPipeOutput pipeWriter(pipePath1.c_str());
+ NamedPipeOutput pipeWriter(pipePath1);
+ pw.notify();
pipeWriter.open();
pipeWriter.write(srcBsonObj.objdata(), count);
pipeWriter.close();
@@ -146,11 +163,11 @@ TEST_F(ExternalRecordStoreTest, NamedPipeReadPartialData) {
ON_BLOCK_EXIT([&] { producer.join(); });
// Gives some time to the producer so that it can initialize a named pipe.
- stdx::this_thread::sleep_for(stdx::chrono::seconds(1));
+ pw.wait();
- auto inputStream = std::make_unique<InputStream<NamedPipeInput>>(pipePath1.c_str());
- // Request more data than the pipe contains. Should only get the bytes it does contain.
- int nRead = inputStream->readBytes(kBufferSize, _buffer);
+ auto inputStream = InputStream<NamedPipeInput>(pipePath1);
+ // Requests more data than the pipe contains. Should only get the bytes it does contain.
+ int nRead = inputStream.readBytes(kBufferSize, _buffer);
ASSERT_EQ(nRead, count) << "Expected nRead == {} but got {}"_format(count, nRead);
ASSERT_EQ(std::memcmp(srcBsonObj.objdata(), _buffer, count), 0)
<< "Read data is not same as the source data";
@@ -160,8 +177,10 @@ TEST_F(ExternalRecordStoreTest, NamedPipeReadUntilProducerDone) {
auto srcBsonObj = BSON("a" << 1);
auto count = srcBsonObj.objsize();
const auto nSent = std::rand() % 100;
+ PipeWaiter pw;
stdx::thread producer([&] {
- NamedPipeOutput pipeWriter(pipePath1.c_str());
+ NamedPipeOutput pipeWriter(pipePath1);
+ pw.notify();
pipeWriter.open();
for (int i = 0; i < nSent; ++i) {
@@ -173,12 +192,12 @@ TEST_F(ExternalRecordStoreTest, NamedPipeReadUntilProducerDone) {
ON_BLOCK_EXIT([&] { producer.join(); });
// Gives some time to the producer so that it can initialize a named pipe.
- stdx::this_thread::sleep_for(stdx::chrono::seconds(1));
+ pw.wait();
- auto inputStream = std::make_unique<InputStream<NamedPipeInput>>(pipePath1.c_str());
+ auto inputStream = InputStream<NamedPipeInput>(pipePath1);
auto nReceived = 0;
while (true) {
- int nRead = inputStream->readBytes(count, _buffer);
+ int nRead = inputStream.readBytes(count, _buffer);
if (nRead != count) {
ASSERT_EQ(nRead, 0) << "Expected nRead == 0 for EOF but got something else {}"_format(
nRead);
@@ -195,7 +214,7 @@ TEST_F(ExternalRecordStoreTest, NamedPipeReadUntilProducerDone) {
TEST_F(ExternalRecordStoreTest, NamedPipeOpenNonExisting) {
ASSERT_THROWS_CODE(
- [] { (void)std::make_unique<InputStream<NamedPipeInput>>(nonExistingPath.c_str()); }(),
+ [] { (void)std::make_unique<InputStream<NamedPipeInput>>(nonExistingPath); }(),
DBException,
ErrorCodes::FileNotOpen);
}
@@ -210,9 +229,10 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes1) {
// Create two pipes. The first has only "a" objects and the second has only "zed" objects.
stdx::thread pipeThreads[kNumPipes];
+ PipeWaiter pw[kNumPipes];
for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) {
- pipeThreads[pipeIdx] =
- stdx::thread(createNamedPipe, pipePaths[pipeIdx], kObjsPerPipe, bsonObjs[pipeIdx]);
+ pipeThreads[pipeIdx] = stdx::thread(
+ createNamedPipe, &pw[pipeIdx], pipePaths[pipeIdx], kObjsPerPipe, bsonObjs[pipeIdx]);
}
ON_BLOCK_EXIT([&] {
for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) {
@@ -221,7 +241,9 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes1) {
});
// Gives some time to the producers so they can initialize the named pipes.
- stdx::this_thread::sleep_for(stdx::chrono::seconds(1));
+ for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) {
+ pw[pipeIdx].wait();
+ }
// Create metadata describing the pipes and a MultiBsonStreamCursor to read them.
VirtualCollectionOptions vopts;
@@ -302,12 +324,13 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes2) {
// so they will cause several wraps. The largish size of the objects makes it highly likely that
// some reads will leave a partial object that must be completed on a later next() call.
stdx::thread pipeThreads[kNumPipes];
+ PipeWaiter pw[kNumPipes];
long numToWrites[] = {(3 * groupsIn32Mb * numObjs), (5 * groupsIn32Mb * numObjs)};
long numToWrite = 0;
for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) {
- pipeThreads[pipeIdx] =
- stdx::thread(createNamedPipe, pipePaths[pipeIdx], numToWrites[pipeIdx], bsonObjs);
+ pipeThreads[pipeIdx] = stdx::thread(
+ createNamedPipe, &pw[pipeIdx], pipePaths[pipeIdx], numToWrites[pipeIdx], bsonObjs);
numToWrite += numToWrites[pipeIdx];
}
ON_BLOCK_EXIT([&] {
@@ -317,7 +340,9 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes2) {
});
// Gives some time to the producers so they can initialize the named pipes.
- stdx::this_thread::sleep_for(stdx::chrono::seconds(1));
+ for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) {
+ pw[pipeIdx].wait();
+ }
// Create metadata describing the pipes and a MultiBsonStreamCursor to read them.
VirtualCollectionOptions vopts;
@@ -383,12 +408,13 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes3) {
// Create pipes with large bsons.
stdx::thread pipeThreads[kNumPipes];
+ PipeWaiter pw[kNumPipes];
long numToWrites[] = {19, 17};
long numToWrite = 0;
for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) {
- pipeThreads[pipeIdx] =
- stdx::thread(createNamedPipe, pipePaths[pipeIdx], numToWrites[pipeIdx], bsonObjs);
+ pipeThreads[pipeIdx] = stdx::thread(
+ createNamedPipe, &pw[pipeIdx], pipePaths[pipeIdx], numToWrites[pipeIdx], bsonObjs);
numToWrite += numToWrites[pipeIdx];
}
ON_BLOCK_EXIT([&] {
@@ -398,7 +424,9 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes3) {
});
// Gives some time to the producers so they can initialize the named pipes.
- stdx::this_thread::sleep_for(stdx::chrono::seconds(1));
+ for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) {
+ pw[pipeIdx].wait();
+ }
// Create metadata describing the pipes and a MultiBsonStreamCursor to read them.
VirtualCollectionOptions vopts;
@@ -450,6 +478,7 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes4) {
constexpr int kNumPipes = 100; // shadows the global
std::string pipePaths[kNumPipes]; // shadows the global
stdx::thread pipeThreads[kNumPipes]; // pipe producer threads
+ PipeWaiter pw[kNumPipes]; // pipe waiters
std::vector<BSONObj> pipeBsonObjs[kNumPipes]; // vector of BSON objects for each pipe
size_t objsWritten = 0; // number of objects written to all pipes
@@ -465,12 +494,9 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes4) {
// Create the pipes.
for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) {
-#ifndef _WIN32
- pipePaths[pipeIdx] = "/tmp/named_pipe{}"_format(pipeIdx);
-#else
- pipePaths[pipeIdx] = "//./pipe/named_pipe{}"_format(pipeIdx);
-#endif
+ pipePaths[pipeIdx] = "named_pipe{}"_format(pipeIdx);
pipeThreads[pipeIdx] = stdx::thread(createNamedPipe,
+ &pw[pipeIdx],
pipePaths[pipeIdx],
pipeBsonObjs[pipeIdx].size(),
pipeBsonObjs[pipeIdx]);
@@ -482,7 +508,9 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes4) {
});
// Gives some time to the producers so they can initialize the named pipes.
- stdx::this_thread::sleep_for(stdx::chrono::seconds(1));
+ for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) {
+ pw[pipeIdx].wait();
+ }
// Create metadata describing the pipes and a MultiBsonStreamCursor to read them.
VirtualCollectionOptions vopts;
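
A condensed, hedged sketch of the producer/consumer handshake the tests above switch to, replacing the one-second sleeps: the producer signals through PipeWaiter as soon as the pipe exists, and the reader blocks on wait() before opening it.

PipeWaiter pw;
std::string pipeName = "named_pipe1";
auto obj = BSON("a" << 1);
stdx::thread producer([&] {
    NamedPipeOutput pipeWriter(pipeName);  // resolved under kDefaultPipePath
    pw.notify();                           // pipe file now exists; reader may proceed
    pipeWriter.open();
    pipeWriter.write(obj.objdata(), obj.objsize());
    pipeWriter.close();
});
pw.wait();  // deterministic; no arbitrary sleep needed
auto inputStream = InputStream<NamedPipeInput>(pipeName);
char buffer[1024];
int nRead = inputStream.readBytes(obj.objsize(), buffer);
producer.join();
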
diff --git a/src/mongo/db/storage/input_object.h b/src/mongo/db/storage/input_object.h
index 9f420b8288b..5788d5c7b0b 100644
--- a/src/mongo/db/storage/input_object.h
+++ b/src/mongo/db/storage/input_object.h
@@ -41,7 +41,7 @@ class StreamableInput {
public:
virtual ~StreamableInput() {}
- virtual const char* getPath() const = 0;
+ virtual const std::string& getAbsolutePath() const = 0;
void open() {
if (isOpen()) {
diff --git a/src/mongo/db/storage/input_stream.h b/src/mongo/db/storage/input_stream.h
index 95d0b6d9bd8..45f33522f80 100644
--- a/src/mongo/db/storage/input_stream.h
+++ b/src/mongo/db/storage/input_stream.h
@@ -32,12 +32,12 @@
#include <fmt/format.h>
#include <utility>
+#include "mongo/db/storage/io_error_message.h"
#include "mongo/logv2/log.h"
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage
namespace mongo {
-
/**
* This template class provides a standardized input facility over StreamableInput or SeekableInput.
*
@@ -57,7 +57,7 @@ public:
using namespace fmt::literals;
InputT::open();
uassert(ErrorCodes::FileNotOpen,
- "error"_format(getErrorMessage("open"_sd, InputT::getPath())),
+ "error"_format(getErrorMessage("open"_sd, InputT::getAbsolutePath())),
InputT::isOpen());
}
@@ -95,14 +95,14 @@ public:
// If we reach this point, we accumulated fewer than 'count' bytes.
if (MONGO_likely(InputT::isEof())) {
- LOGV2_INFO(7005001, "Named pipe is closed", "path"_attr = InputT::getPath());
+ LOGV2_INFO(7005001, "Named pipe is closed", "path"_attr = InputT::getAbsolutePath());
return nReadTotal;
}
tassert(7005002, "Expected an error condition but succeeded", InputT::isFailed());
LOGV2_ERROR(7005003,
"Failed to read a named pipe",
- "error"_attr = getErrorMessage("read", InputT::getPath()));
+ "error"_attr = getErrorMessage("read", InputT::getAbsolutePath()));
return -1;
}
diff --git a/src/mongo/db/storage/io_error_message.h b/src/mongo/db/storage/io_error_message.h
new file mode 100644
index 00000000000..37979358286
--- /dev/null
+++ b/src/mongo/db/storage/io_error_message.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <fmt/format.h>
+
+#include "mongo/util/errno_util.h"
+
+namespace mongo {
+namespace {
+inline std::string getErrorMessage(StringData op, const std::string& path) {
+ using namespace fmt::literals;
+ return "Failed to {} {}: {}"_format(op, path, errorMessage(lastSystemError()));
+}
+} // namespace
+} // namespace mongo
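
A hedged usage sketch of the relocated helper (the call site is illustrative, not from the tree): it folds the operation name, the path, and the current errno text into one message, so it should be called right after the failing system call.

if (mkfifo("/tmp/named_pipe1", 0664) != 0) {
    // e.g. "Failed to mkfifo /tmp/named_pipe1: File exists"
    std::string msg = getErrorMessage("mkfifo", "/tmp/named_pipe1");
}
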
diff --git a/src/mongo/db/storage/multi_bson_stream_cursor.cpp b/src/mongo/db/storage/multi_bson_stream_cursor.cpp
index 189131b0208..749a21e2aff 100644
--- a/src/mongo/db/storage/multi_bson_stream_cursor.cpp
+++ b/src/mongo/db/storage/multi_bson_stream_cursor.cpp
@@ -28,7 +28,10 @@
*/
#include "mongo/db/storage/multi_bson_stream_cursor.h"
+
+#include "mongo/db/catalog/virtual_collection_options.h"
#include "mongo/db/storage/record_store.h"
+
namespace mongo {
/**
@@ -37,10 +40,10 @@ namespace mongo {
* and updates bookkeeping. This can never expand the buffer larger than (2 * BSONObjMaxUserSize).
*/
void MultiBsonStreamCursor::expandBuffer(int32_t bsonSize) {
- tassert(6968308,
+ uassert(6968308,
"bsonSize {} > BSONObjMaxUserSize {}"_format(bsonSize, BSONObjMaxUserSize),
(bsonSize <= BSONObjMaxUserSize));
- tassert(6968309, "bsonSize {} < 0"_format(bsonSize), (bsonSize >= 0));
+ uassert(6968309, "bsonSize {} < 0"_format(bsonSize), (bsonSize >= 0));
int newSizeTarget = 2 * bsonSize;
do {
@@ -143,6 +146,23 @@ boost::optional<Record> MultiBsonStreamCursor::nextFromCurrentStream() {
}
/**
+ * Returns an input stream for a named pipe mapped from 'url'.
+ *
+ * While creating an input stream, it strips off the file protocol part from the 'url'.
+ */
+std::unique_ptr<InputStream<NamedPipeInput>> MultiBsonStreamCursor::getInputStream(
+ const std::string& url) {
+ auto filePathPos = url.find(ExternalDataSourceMetadata::kUrlProtocolFile.toString());
+ tassert(
+ ErrorCodes::BadValue, "Invalid file url: {}"_format(url), filePathPos != std::string::npos);
+
+ auto filePathStr =
+ url.substr(filePathPos + ExternalDataSourceMetadata::kUrlProtocolFile.size());
+
+ return std::make_unique<InputStream<NamedPipeInput>>(filePathStr);
+}
+
+/**
* Returns the next record from the vector of streams or boost::none if exhausted or error.
* '_streamReader' is initialized to the first stream, if there is one, in the constructor.
*/
@@ -154,8 +174,7 @@ boost::optional<Record> MultiBsonStreamCursor::next() {
}
++_streamIdx;
if (_streamIdx < _numStreams) {
- _streamReader = std::make_unique<InputStream<NamedPipeInput>>(
- _vopts.dataSources[_streamIdx].url.c_str());
+ _streamReader = getInputStream(_vopts.dataSources[_streamIdx].url);
}
}
return boost::none;
diff --git a/src/mongo/db/storage/multi_bson_stream_cursor.h b/src/mongo/db/storage/multi_bson_stream_cursor.h
index 1b896f6a6c5..d71e7240326 100644
--- a/src/mongo/db/storage/multi_bson_stream_cursor.h
+++ b/src/mongo/db/storage/multi_bson_stream_cursor.h
@@ -29,9 +29,10 @@
#pragma once
-#include "mongo/db/storage/external_record_store.h"
+#include "mongo/db/catalog/virtual_collection_options.h"
#include "mongo/db/storage/input_stream.h"
#include "mongo/db/storage/named_pipe.h"
+#include "mongo/db/storage/record_store.h"
namespace mongo {
class MultiBsonStreamCursor : public SeekableRecordCursor {
@@ -39,8 +40,7 @@ public:
MultiBsonStreamCursor(const VirtualCollectionOptions& vopts)
: _numStreams(vopts.dataSources.size()), _vopts(vopts) {
tassert(6968310, "_numStreams {} <= 0"_format(_numStreams), _numStreams > 0);
- _streamReader =
- std::make_unique<InputStream<NamedPipeInput>>(_vopts.dataSources[0].url.c_str());
+ _streamReader = getInputStream(_vopts.dataSources[_streamIdx].url);
}
boost::optional<Record> next() override;
@@ -71,6 +71,7 @@ public:
private:
void expandBuffer(int32_t bsonSize);
boost::optional<Record> nextFromCurrentStream();
+ static std::unique_ptr<InputStream<NamedPipeInput>> getInputStream(const std::string& url);
// The size in bytes of a BSON object's "size" prefix.
static constexpr int kSizeSize = static_cast<int>(sizeof(int32_t));
diff --git a/src/mongo/db/storage/named_pipe.h b/src/mongo/db/storage/named_pipe.h
index 61e6c87ac87..587661e5b6a 100644
--- a/src/mongo/db/storage/named_pipe.h
+++ b/src/mongo/db/storage/named_pipe.h
@@ -34,20 +34,30 @@
#else
#include <windows.h>
#endif
+#include <string>
#include "mongo/db/storage/input_object.h"
namespace mongo {
+#ifndef _WIN32
+static constexpr auto kDefaultPipePath = "/tmp/"_sd;
+#else
+// "//./pipe/" is the required path start of all named pipes on Windows, where "//." is the
+// abbreviation for the local server name and "/pipe" is a literal. (These also work with
+// Windows-native backslashes instead of forward slashes.)
+static constexpr auto kDefaultPipePath = "//./pipe/"_sd;
+#endif
+
class NamedPipeOutput {
public:
- NamedPipeOutput(const char* pipePath);
+ NamedPipeOutput(const std::string& pipeRelativePath);
~NamedPipeOutput();
void open();
int write(const char* data, int size);
void close();
private:
- const char* _pipePath;
+ std::string _pipeAbsolutePath;
#ifndef _WIN32
std::ofstream _ofs;
#else
@@ -58,10 +68,10 @@ private:
class NamedPipeInput : public StreamableInput {
public:
- NamedPipeInput(const char* pipePath);
+ NamedPipeInput(const std::string& pipeRelativePath);
~NamedPipeInput() override;
- const char* getPath() const override {
- return _pipePath;
+ const std::string& getAbsolutePath() const override {
+ return _pipeAbsolutePath;
}
bool isOpen() const override;
bool isGood() const override;
@@ -74,7 +84,7 @@ protected:
void doClose() override;
private:
- const char* _pipePath;
+ std::string _pipeAbsolutePath;
#ifndef _WIN32
std::ifstream _ifs;
#else
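
A hedged usage sketch of the new relative-path constructors: callers now pass only the pipe's name, and the class prepends kDefaultPipePath ("/tmp/" on POSIX, "//./pipe/" on Windows) to form the absolute path. The pipe name below is hypothetical.

auto doc = BSON("a" << 1);
NamedPipeOutput writer("metrics_pipe");   // becomes /tmp/metrics_pipe on POSIX builds
writer.open();
writer.write(doc.objdata(), doc.objsize());
writer.close();

NamedPipeInput reader("metrics_pipe");    // resolves to the same absolute path
// reader.getAbsolutePath() == "/tmp/metrics_pipe" on POSIX builds
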
diff --git a/src/mongo/db/storage/named_pipe_posix.cpp b/src/mongo/db/storage/named_pipe_posix.cpp
index c1e04be6f27..c34079fc916 100644
--- a/src/mongo/db/storage/named_pipe_posix.cpp
+++ b/src/mongo/db/storage/named_pipe_posix.cpp
@@ -31,10 +31,11 @@
#include "named_pipe.h"
#include <fmt/format.h>
+#include <string>
#include <sys/stat.h>
#include <sys/types.h>
-#include "mongo/db/storage/external_record_store.h"
+#include "mongo/db/storage/io_error_message.h"
#include "mongo/logv2/log.h"
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage
@@ -43,23 +44,26 @@
namespace mongo {
using namespace fmt::literals;
-NamedPipeOutput::NamedPipeOutput(const char* pipePath) : _pipePath(pipePath), _ofs() {
- remove(_pipePath);
+NamedPipeOutput::NamedPipeOutput(const std::string& pipeRelativePath)
+ : _pipeAbsolutePath(kDefaultPipePath + pipeRelativePath), _ofs() {
+ remove(_pipeAbsolutePath.c_str());
uassert(7005005,
- "Failed to create a named pipe, error={}"_format(getErrorMessage("mkfifo", _pipePath)),
- mkfifo(_pipePath, 0664) == 0);
+ "Failed to create a named pipe, error={}"_format(
+ getErrorMessage("mkfifo", _pipeAbsolutePath)),
+ mkfifo(_pipeAbsolutePath.c_str(), 0664) == 0);
}
NamedPipeOutput::~NamedPipeOutput() {
close();
+ remove(_pipeAbsolutePath.c_str());
}
void NamedPipeOutput::open() {
- _ofs.open(_pipePath, std::ios::binary | std::ios::app);
+ _ofs.open(_pipeAbsolutePath.c_str(), std::ios::binary | std::ios::app);
if (!_ofs.is_open() || !_ofs.good()) {
LOGV2_ERROR(7005009,
"Failed to open a named pipe",
- "error"_attr = getErrorMessage("open", _pipePath));
+ "error"_attr = getErrorMessage("open", _pipeAbsolutePath));
}
}
@@ -78,14 +82,15 @@ void NamedPipeOutput::close() {
}
}
-NamedPipeInput::NamedPipeInput(const char* pipePath) : _pipePath(pipePath), _ifs() {}
+NamedPipeInput::NamedPipeInput(const std::string& pipeRelativePath)
+ : _pipeAbsolutePath(kDefaultPipePath + pipeRelativePath), _ifs() {}
NamedPipeInput::~NamedPipeInput() {
close();
}
void NamedPipeInput::doOpen() {
- _ifs.open(_pipePath, std::ios::binary | std::ios::in);
+ _ifs.open(_pipeAbsolutePath.c_str(), std::ios::binary | std::ios::in);
}
int NamedPipeInput::doRead(char* data, int size) {
diff --git a/src/mongo/db/storage/named_pipe_windows.cpp b/src/mongo/db/storage/named_pipe_windows.cpp
index 8b0aab88987..aa4d530cab9 100644
--- a/src/mongo/db/storage/named_pipe_windows.cpp
+++ b/src/mongo/db/storage/named_pipe_windows.cpp
@@ -31,9 +31,12 @@
#include "named_pipe.h"
#include <fmt/format.h>
+#include <string>
+#include <system_error>
-#include "mongo/db/storage/external_record_store.h"
+#include "mongo/db/storage/io_error_message.h"
#include "mongo/logv2/log.h"
+#include "mongo/util/errno_util.h"
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage
@@ -41,14 +44,20 @@
namespace mongo {
using namespace fmt::literals;
-NamedPipeOutput::NamedPipeOutput(const char* pipePath)
- : _pipePath(pipePath),
- _pipe(CreateNamedPipeA(
- pipePath, PIPE_ACCESS_OUTBOUND, (PIPE_TYPE_BYTE | PIPE_WAIT), 1, 0, 0, 0, nullptr)),
+NamedPipeOutput::NamedPipeOutput(const std::string& pipeRelativePath)
+ : _pipeAbsolutePath(kDefaultPipePath + pipeRelativePath),
+ _pipe(CreateNamedPipeA(_pipeAbsolutePath.c_str(),
+ PIPE_ACCESS_OUTBOUND,
+ (PIPE_TYPE_BYTE | PIPE_WAIT),
+ 1, // nMaxInstances
+ 0, // nOutBufferSize
+ 0, // nInBufferSize
+ 0, // nDefaultTimeOut
+ nullptr)), // lpSecurityAttributes
_isOpen(false) {
uassert(7005006,
"Failed to create a named pipe, error={}"_format(
- getErrorMessage("CreateNamedPipe", _pipePath)),
+ getErrorMessage("CreateNamedPipe", _pipeAbsolutePath)),
_pipe != INVALID_HANDLE_VALUE);
}
@@ -64,7 +73,7 @@ void NamedPipeOutput::open() {
if (!res) {
LOGV2_ERROR(7005007,
"Failed to connect a named pipe",
- "error"_attr = getErrorMessage("ConnectNamedPipe", _pipePath));
+ "error"_attr = getErrorMessage("ConnectNamedPipe", _pipeAbsolutePath));
return;
}
_isOpen = true;
@@ -83,7 +92,7 @@ int NamedPipeOutput::write(const char* data, int size) {
if (!res || size != nWritten) {
LOGV2_ERROR(7005008,
"Failed to write to a named pipe",
- "error"_attr = getErrorMessage("write", _pipePath));
+ "error"_attr = getErrorMessage("write", _pipeAbsolutePath));
return -1;
}
@@ -102,8 +111,8 @@ void NamedPipeOutput::close() {
}
}
-NamedPipeInput::NamedPipeInput(const char* pipePath)
- : _pipePath(pipePath),
+NamedPipeInput::NamedPipeInput(const std::string& pipeRelativePath)
+ : _pipeAbsolutePath(kDefaultPipePath + pipeRelativePath),
_pipe(INVALID_HANDLE_VALUE),
_isOpen(false),
_isGood(false),
@@ -114,7 +123,8 @@ NamedPipeInput::~NamedPipeInput() {
}
void NamedPipeInput::doOpen() {
- _pipe = CreateFileA(_pipePath, GENERIC_READ, 0, nullptr, OPEN_EXISTING, 0, nullptr);
+ _pipe =
+ CreateFileA(_pipeAbsolutePath.c_str(), GENERIC_READ, 0, nullptr, OPEN_EXISTING, 0, nullptr);
if (_pipe == INVALID_HANDLE_VALUE) {
return;
}
diff --git a/src/mongo/db/storage/oplog_cap_maintainer_thread.cpp b/src/mongo/db/storage/oplog_cap_maintainer_thread.cpp
index 0627f0e41ce..582df818b5a 100644
--- a/src/mongo/db/storage/oplog_cap_maintainer_thread.cpp
+++ b/src/mongo/db/storage/oplog_cap_maintainer_thread.cpp
@@ -66,8 +66,7 @@ bool OplogCapMaintainerThread::_deleteExcessDocuments() {
// Maintaining the Oplog cap is crucial to the stability of the server so that we don't let the
// oplog grow unbounded. We mark the operation as having immediate priority to skip ticket
// acquisition and flow control.
- SetTicketAquisitionPriorityForLock priority(opCtx.get(),
- AdmissionContext::Priority::kImmediate);
+ SetAdmissionPriorityForLock priority(opCtx.get(), AdmissionContext::Priority::kImmediate);
try {
// A Global IX lock should be good enough to protect the oplog truncation from
diff --git a/src/mongo/db/storage/recovery_unit.h b/src/mongo/db/storage/recovery_unit.h
index 20c49c637d3..9e173d632e7 100644
--- a/src/mongo/db/storage/recovery_unit.h
+++ b/src/mongo/db/storage/recovery_unit.h
@@ -37,6 +37,7 @@
#include "mongo/bson/timestamp.h"
#include "mongo/db/repl/read_concern_level.h"
#include "mongo/db/storage/snapshot.h"
+#include "mongo/db/storage/storage_stats.h"
#include "mongo/util/decorable.h"
namespace mongo {
@@ -76,37 +77,6 @@ enum class PrepareConflictBehavior {
};
/**
- * Storage statistics management class, with interfaces to provide the statistics in the BSON format
- * and an operator to add the statistics values.
- */
-class StorageStats {
- StorageStats(const StorageStats&) = delete;
- StorageStats& operator=(const StorageStats&) = delete;
-
-public:
- StorageStats() = default;
-
- virtual ~StorageStats(){};
-
- /**
- * Provides the storage statistics in the form of a BSONObj.
- */
- virtual BSONObj toBSON() = 0;
-
- /**
- * Add the statistics values.
- */
- virtual StorageStats& operator+=(const StorageStats&) = 0;
-
- /**
- * Provides the ability to create an instance of this class outside of the storage integration
- * layer.
- */
- virtual std::shared_ptr<StorageStats> getCopy() = 0;
-};
-
-
-/**
* A RecoveryUnit is responsible for ensuring that data is persisted.
* All on-disk information must be mutated through this interface.
*/
diff --git a/src/mongo/db/storage/storage_engine_impl.cpp b/src/mongo/db/storage/storage_engine_impl.cpp
index b5b3b2f0d80..6eb5123cf92 100644
--- a/src/mongo/db/storage/storage_engine_impl.cpp
+++ b/src/mongo/db/storage/storage_engine_impl.cpp
@@ -971,10 +971,12 @@ Status StorageEngineImpl::_dropCollectionsNoTimestamp(OperationContext* opCtx,
audit::logDropCollection(opCtx->getClient(), coll->ns());
- Status result = catalog::dropCollection(
- opCtx, coll->ns(), coll->getCatalogId(), coll->getSharedIdent());
- if (!result.isOK() && firstError.isOK()) {
- firstError = result;
+ if (auto sharedIdent = coll->getSharedIdent()) {
+ Status result =
+ catalog::dropCollection(opCtx, coll->ns(), coll->getCatalogId(), sharedIdent);
+ if (!result.isOK() && firstError.isOK()) {
+ firstError = result;
+ }
}
CollectionCatalog::get(opCtx)->dropCollection(
diff --git a/src/mongo/db/query/optimizer/cascades/cost_derivation.h b/src/mongo/db/storage/storage_stats.h
index b71a020316e..b0afd2ba16d 100644
--- a/src/mongo/db/query/optimizer/cascades/cost_derivation.h
+++ b/src/mongo/db/storage/storage_stats.h
@@ -29,22 +29,29 @@
#pragma once
-#include "mongo/db/query/optimizer/cascades/interfaces.h"
-#include "mongo/db/query/optimizer/cascades/memo.h"
+#include "mongo/bson/bsonobj.h"
-namespace mongo::optimizer::cascades {
+namespace mongo {
/**
- * Default costing for physical nodes with logical delegator (not-yet-optimized) inputs.
+ * Manages statistics from the storage engine, allowing addition of statistics and serialization to
+ * BSON.
*/
-class DefaultCosting : public CostingInterface {
+class StorageStats {
public:
- CostAndCE deriveCost(const Metadata& metadata,
- const Memo& memo,
- const properties::PhysProps& physProps,
- ABT::reference_type physNodeRef,
- const ChildPropsType& childProps,
- const NodeCEMap& nodeCEMap) const override final;
+ StorageStats() = default;
+
+ StorageStats(const StorageStats&) = delete;
+ StorageStats(StorageStats&&) = delete;
+ StorageStats& operator=(const StorageStats&) = delete;
+
+ virtual ~StorageStats() = default;
+
+ virtual BSONObj toBSON() const = 0;
+
+ virtual std::shared_ptr<StorageStats> clone() const = 0;
+
+ virtual StorageStats& operator+=(const StorageStats&) = 0;
};
-} // namespace mongo::optimizer::cascades
+} // namespace mongo
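
A hedged sketch of how the reshaped interface is meant to compose (variable names are illustrative): per-operation statistics accumulate through the virtual operator+= and can be snapshotted via clone() for later reporting.

void accumulate(StorageStats& lifetimeTotals, const StorageStats& opStats) {
    lifetimeTotals += opStats;                         // dispatches to the subclass's +=
    std::shared_ptr<StorageStats> snapshot = lifetimeTotals.clone();
    BSONObj report = snapshot->toBSON();               // e.g. for slow-operation reporting
    (void)report;
}
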
diff --git a/src/mongo/db/storage/wiredtiger/SConscript b/src/mongo/db/storage/wiredtiger/SConscript
index 68a76639ed5..638cf0b933c 100644
--- a/src/mongo/db/storage/wiredtiger/SConscript
+++ b/src/mongo/db/storage/wiredtiger/SConscript
@@ -37,6 +37,7 @@ wtEnv.Library(
'wiredtiger_index.cpp',
'wiredtiger_index_util.cpp',
'wiredtiger_kv_engine.cpp',
+ 'wiredtiger_operation_stats.cpp',
'wiredtiger_oplog_manager.cpp',
'wiredtiger_parameters.cpp',
'wiredtiger_prepare_conflict.cpp',
@@ -133,6 +134,7 @@ wtEnv.CppUnitTest(
'wiredtiger_init_test.cpp',
'wiredtiger_c_api_test.cpp',
'wiredtiger_kv_engine_test.cpp',
+ 'wiredtiger_operation_stats_test.cpp',
'wiredtiger_recovery_unit_test.cpp',
'wiredtiger_session_cache_test.cpp',
'wiredtiger_util_test.cpp',
diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_column_store.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_column_store.cpp
index 1f034c86e26..1a31948a00e 100644
--- a/src/mongo/db/storage/wiredtiger/wiredtiger_column_store.cpp
+++ b/src/mongo/db/storage/wiredtiger/wiredtiger_column_store.cpp
@@ -39,7 +39,6 @@
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage
-
namespace mongo {
StatusWith<std::string> WiredTigerColumnStore::generateCreateString(
const std::string& engineName,
@@ -218,9 +217,19 @@ void WiredTigerColumnStore::WriteCursor::update(PathView path, RowId rid, CellVi
void WiredTigerColumnStore::fullValidate(OperationContext* opCtx,
int64_t* numKeysOut,
IndexValidateResults* fullResults) const {
- // TODO SERVER-65484: Validation for column indexes.
- // uasserted(ErrorCodes::NotImplemented, "WiredTigerColumnStore::fullValidate()");
- return;
+ dassert(opCtx->lockState()->isReadLocked());
+ if (!WiredTigerIndexUtil::validateStructure(opCtx, _uri, fullResults)) {
+ return;
+ }
+ auto cursor = newCursor(opCtx);
+ long long count = 0;
+
+ while (cursor->next()) {
+ count++;
+ }
+ if (numKeysOut) {
+ *numKeysOut = count;
+ }
}
class WiredTigerColumnStore::Cursor final : public ColumnStore::Cursor,
diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_index.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_index.cpp
index a63022f4ad2..d5297781aaa 100644
--- a/src/mongo/db/storage/wiredtiger/wiredtiger_index.cpp
+++ b/src/mongo/db/storage/wiredtiger/wiredtiger_index.cpp
@@ -298,32 +298,8 @@ void WiredTigerIndex::fullValidate(OperationContext* opCtx,
long long* numKeysOut,
IndexValidateResults* fullResults) const {
dassert(opCtx->lockState()->isReadLocked());
- if (fullResults && !WiredTigerRecoveryUnit::get(opCtx)->getSessionCache()->isEphemeral()) {
- int err = WiredTigerUtil::verifyTable(opCtx, _uri, &(fullResults->errors));
- if (err == EBUSY) {
- std::string msg = str::stream()
- << "Could not complete validation of " << _uri << ". "
- << "This is a transient issue as the collection was actively "
- "in use by other operations.";
-
- LOGV2_WARNING(51781,
- "Could not complete validation. This is a transient issue as "
- "the collection was actively in use by other operations",
- "uri"_attr = _uri);
- fullResults->warnings.push_back(msg);
- } else if (err) {
- std::string msg = str::stream()
- << "verify() returned " << wiredtiger_strerror(err) << ". "
- << "This indicates structural damage. "
- << "Not examining individual index entries.";
- LOGV2_ERROR(51782,
- "verify() returned an error. This indicates structural damage. Not "
- "examining individual index entries.",
- "error"_attr = wiredtiger_strerror(err));
- fullResults->errors.push_back(msg);
- fullResults->valid = false;
- return;
- }
+ if (!WiredTigerIndexUtil::validateStructure(opCtx, _uri, fullResults)) {
+ return;
}
auto cursor = newCursor(opCtx);
diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.cpp
index 97278595091..273b692c375 100644
--- a/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.cpp
+++ b/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.cpp
@@ -30,6 +30,9 @@
#include "mongo/db/storage/wiredtiger/wiredtiger_index_util.h"
#include "mongo/db/storage/wiredtiger/wiredtiger_prepare_conflict.h"
+#include "mongo/logv2/log.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage
namespace mongo {
@@ -115,4 +118,37 @@ bool WiredTigerIndexUtil::isEmpty(OperationContext* opCtx,
return false;
}
+bool WiredTigerIndexUtil::validateStructure(OperationContext* opCtx,
+ const std::string& uri,
+ IndexValidateResults* fullResults) {
+ if (fullResults && !WiredTigerRecoveryUnit::get(opCtx)->getSessionCache()->isEphemeral()) {
+ int err = WiredTigerUtil::verifyTable(opCtx, uri, &(fullResults->errors));
+ if (err == EBUSY) {
+ std::string msg = str::stream()
+ << "Could not complete validation of " << uri << ". "
+ << "This is a transient issue as the collection was actively "
+ "in use by other operations.";
+
+ LOGV2_WARNING(51781,
+ "Could not complete validation. This is a transient issue as "
+ "the collection was actively in use by other operations",
+ "uri"_attr = uri);
+ fullResults->warnings.push_back(msg);
+ } else if (err) {
+ std::string msg = str::stream()
+ << "verify() returned " << wiredtiger_strerror(err) << ". "
+ << "This indicates structural damage. "
+ << "Not examining individual index entries.";
+ LOGV2_ERROR(51782,
+ "verify() returned an error. This indicates structural damage. Not "
+ "examining individual index entries.",
+ "error"_attr = wiredtiger_strerror(err));
+ fullResults->errors.push_back(msg);
+ fullResults->valid = false;
+ return false;
+ }
+ }
+ return true;
+}
+
} // namespace mongo
diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.h b/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.h
index e98c5a7b1dc..e80067dd6d1 100644
--- a/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.h
+++ b/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.h
@@ -31,6 +31,7 @@
#include "mongo/base/status.h"
#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/db/catalog/validate_results.h"
#include "mongo/db/operation_context.h"
namespace mongo {
@@ -51,6 +52,10 @@ public:
static Status compact(OperationContext* opCtx, const std::string& uri);
static bool isEmpty(OperationContext* opCtx, const std::string& uri, uint64_t tableId);
+
+ static bool validateStructure(OperationContext* opCtx,
+ const std::string& uri,
+ IndexValidateResults* fullResults);
};
} // namespace mongo
diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.cpp
new file mode 100644
index 00000000000..ebf3842e7bc
--- /dev/null
+++ b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.cpp
@@ -0,0 +1,139 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h"
+
+#include "mongo/bson/bsonobjbuilder.h"
+#include "mongo/db/storage/wiredtiger/wiredtiger_util.h"
+
+namespace mongo {
+namespace {
+
+enum class StatType { kData, kWait };
+
+struct StatInfo {
+ StringData name;
+ StatType type;
+};
+
+stdx::unordered_map<int, StatInfo> statInfo = {
+ {WT_STAT_SESSION_BYTES_READ, {"bytesRead"_sd, StatType::kData}},
+ {WT_STAT_SESSION_BYTES_WRITE, {"bytesWritten"_sd, StatType::kData}},
+ {WT_STAT_SESSION_LOCK_DHANDLE_WAIT, {"handleLock"_sd, StatType::kWait}},
+ {WT_STAT_SESSION_READ_TIME, {"timeReadingMicros"_sd, StatType::kData}},
+ {WT_STAT_SESSION_WRITE_TIME, {"timeWritingMicros"_sd, StatType::kData}},
+ {WT_STAT_SESSION_LOCK_SCHEMA_WAIT, {"schemaLock"_sd, StatType::kWait}},
+ {WT_STAT_SESSION_CACHE_TIME, {"cache"_sd, StatType::kWait}}};
+
+} // namespace
+
+WiredTigerOperationStats::WiredTigerOperationStats(WT_SESSION* session) {
+ invariant(session);
+
+ WT_CURSOR* c;
+ uassert(ErrorCodes::CursorNotFound,
+ "Unable to open statistics cursor",
+ !session->open_cursor(session, "statistics:session", nullptr, "statistics=(fast)", &c));
+
+ ScopeGuard guard{[c] { c->close(c); }};
+
+ int32_t key;
+ uint64_t value;
+ while (c->next(c) == 0 && c->get_key(c, &key) == 0) {
+ fassert(51035, c->get_value(c, nullptr, nullptr, &value) == 0);
+ _stats[key] = WiredTigerUtil::castStatisticsValue<long long>(value);
+ }
+
+ // Reset the statistics so that the next fetch gives the recent values.
+ invariantWTOK(c->reset(c), c->session);
+}
+
+BSONObj WiredTigerOperationStats::toBSON() const {
+ boost::optional<BSONObjBuilder> dataSection;
+ boost::optional<BSONObjBuilder> waitSection;
+
+ for (auto&& [stat, value] : _stats) {
+ if (value == 0) {
+ continue;
+ }
+
+ auto it = statInfo.find(stat);
+ if (it == statInfo.end()) {
+ continue;
+ }
+ auto&& [name, type] = it->second;
+
+ auto appendToSection = [name = name,
+ value = value](boost::optional<BSONObjBuilder>& section) {
+ if (!section) {
+ section.emplace();
+ }
+ section->append(name, value);
+ };
+
+ switch (type) {
+ case StatType::kData:
+ appendToSection(dataSection);
+ break;
+ case StatType::kWait:
+ appendToSection(waitSection);
+ break;
+ }
+ }
+
+ BSONObjBuilder builder;
+ if (dataSection) {
+ builder.append("data", dataSection->obj());
+ }
+ if (waitSection) {
+ builder.append("timeWaitingMicros", waitSection->obj());
+ }
+
+ return builder.obj();
+}
+
+std::shared_ptr<StorageStats> WiredTigerOperationStats::clone() const {
+ auto copy = std::make_shared<WiredTigerOperationStats>();
+ *copy += *this;
+ return copy;
+}
+
+WiredTigerOperationStats& WiredTigerOperationStats::operator+=(
+ const WiredTigerOperationStats& other) {
+ for (auto&& [stat, value] : other._stats) {
+ _stats[stat] += value;
+ }
+ return *this;
+}
+
+StorageStats& WiredTigerOperationStats::operator+=(const StorageStats& other) {
+ return *this += checked_cast<const WiredTigerOperationStats&>(other);
+}
+
+} // namespace mongo
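
A hedged usage sketch, assuming a live WT_SESSION* named 'session': constructing the stats object snapshots (and then resets) the session's operation counters, which can be rendered as BSON or folded into a running total.

WiredTigerOperationStats opStats{session};   // reads, then resets, the session statistics
BSONObj bson = opStats.toBSON();             // {data: {...}, timeWaitingMicros: {...}}

WiredTigerOperationStats runningTotal;
runningTotal += opStats;                     // typed += avoids the checked_cast path
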
diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h
new file mode 100644
index 00000000000..581303e30b8
--- /dev/null
+++ b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h
@@ -0,0 +1,55 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <wiredtiger.h>
+
+#include "mongo/db/storage/storage_stats.h"
+
+namespace mongo {
+
+class WiredTigerOperationStats final : public StorageStats {
+public:
+ WiredTigerOperationStats() = default;
+ WiredTigerOperationStats(WT_SESSION* session);
+
+ BSONObj toBSON() const final;
+
+ std::shared_ptr<StorageStats> clone() const final;
+
+ StorageStats& operator+=(const StorageStats&) final;
+
+ WiredTigerOperationStats& operator+=(const WiredTigerOperationStats&);
+
+private:
+ std::map<int, long long> _stats;
+};
+
+} // namespace mongo
diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats_test.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats_test.cpp
new file mode 100644
index 00000000000..5151f59cb48
--- /dev/null
+++ b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats_test.cpp
@@ -0,0 +1,215 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h"
+#include "mongo/unittest/temp_dir.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+namespace {
+
+#define ASSERT_WT_OK(result) ASSERT_EQ(result, 0) << wiredtiger_strerror(result)
+
+class WiredTigerOperationStatsTest : public unittest::Test {
+protected:
+ void setUp() override {
+ ASSERT_WT_OK(
+ wiredtiger_open(_path.path().c_str(), nullptr, "create,statistics=(fast),", &_conn));
+ ASSERT_WT_OK(_conn->open_session(_conn, nullptr, "isolation=snapshot", &_session));
+ ASSERT_WT_OK(_session->create(
+ _session, _uri.c_str(), "type=file,key_format=q,value_format=u,log=(enabled=false)"));
+ }
+
+ void tearDown() override {
+ ASSERT_EQ(_conn->close(_conn, nullptr), 0);
+ }
+
+ /**
+ * Writes the given data using WT. Causes the bytesWritten and timeWritingMicros stats to be
+ * incremented.
+ */
+ void write(const std::string& data) {
+ ASSERT_WT_OK(_session->begin_transaction(_session, nullptr));
+
+ WT_CURSOR* cursor;
+ ASSERT_WT_OK(_session->open_cursor(_session, _uri.c_str(), nullptr, nullptr, &cursor));
+
+ cursor->set_key(cursor, _key++);
+
+ WT_ITEM item{data.data(), data.size()};
+ cursor->set_value(cursor, &item);
+
+ ASSERT_WT_OK(cursor->insert(cursor));
+ ASSERT_WT_OK(cursor->close(cursor));
+ ASSERT_WT_OK(_session->commit_transaction(_session, nullptr));
+ ASSERT_WT_OK(_session->checkpoint(_session, nullptr));
+ }
+
+ /**
+ * Reads all of the previously written data from WT. Causes the bytesRead and timeReadingMicros
+ * stats to be incremented.
+ */
+ void read() {
+ tearDown();
+ setUp();
+
+ ASSERT_WT_OK(_session->begin_transaction(_session, nullptr));
+
+ WT_CURSOR* cursor;
+ ASSERT_WT_OK(_session->open_cursor(_session, _uri.c_str(), nullptr, nullptr, &cursor));
+
+ for (int64_t i = 0; i < _key; ++i) {
+ cursor->set_key(cursor, i);
+ ASSERT_WT_OK(cursor->search(cursor));
+
+ WT_ITEM value;
+ ASSERT_WT_OK(cursor->get_value(cursor, &value));
+ }
+
+ ASSERT_WT_OK(cursor->close(cursor));
+ ASSERT_WT_OK(_session->commit_transaction(_session, nullptr));
+ }
+
+ unittest::TempDir _path{"wiredtiger_operation_stats_test"};
+ std::string _uri{"table:wiredtiger_operation_stats_test"};
+ WT_CONNECTION* _conn;
+ WT_SESSION* _session;
+ int64_t _key = 0;
+};
+
+TEST_F(WiredTigerOperationStatsTest, Empty) {
+ ASSERT_BSONOBJ_EQ(WiredTigerOperationStats{_session}.toBSON(), BSONObj{});
+}
+
+TEST_F(WiredTigerOperationStatsTest, Write) {
+ write("a");
+
+ auto statsObj = WiredTigerOperationStats{_session}.toBSON();
+
+ auto dataSection = statsObj["data"];
+ ASSERT_EQ(dataSection.type(), BSONType::Object) << statsObj;
+
+ ASSERT(dataSection["bytesWritten"]) << statsObj;
+ for (auto&& [name, value] : dataSection.Obj()) {
+ ASSERT_EQ(value.type(), BSONType::NumberLong) << statsObj;
+ ASSERT_GT(value.numberLong(), 0) << statsObj;
+ }
+}
+
+TEST_F(WiredTigerOperationStatsTest, Read) {
+ write("a");
+ read();
+
+ auto statsObj = WiredTigerOperationStats{_session}.toBSON();
+
+ auto dataSection = statsObj["data"];
+ ASSERT_EQ(dataSection.type(), BSONType::Object) << statsObj;
+
+ ASSERT(dataSection["bytesRead"]) << statsObj;
+ for (auto&& [name, value] : dataSection.Obj()) {
+ ASSERT_EQ(value.type(), BSONType::NumberLong) << statsObj;
+ ASSERT_GT(value.numberLong(), 0) << statsObj;
+ }
+}
+
+TEST_F(WiredTigerOperationStatsTest, Large) {
+ auto remaining = static_cast<int64_t>(std::numeric_limits<uint32_t>::max()) + 1;
+ while (remaining > 0) {
+ std::string data(1024 * 1024, 'a');
+ remaining -= data.size();
+ write(data);
+ }
+
+ auto statsObj = WiredTigerOperationStats{_session}.toBSON();
+ ASSERT_GT(statsObj["data"]["bytesWritten"].numberLong(), std::numeric_limits<uint32_t>::max())
+ << statsObj;
+
+ read();
+
+ statsObj = WiredTigerOperationStats{_session}.toBSON();
+ ASSERT_GT(statsObj["data"]["bytesRead"].numberLong(), std::numeric_limits<uint32_t>::max())
+ << statsObj;
+}
+
+TEST_F(WiredTigerOperationStatsTest, Add) {
+ std::vector<std::unique_ptr<WiredTigerOperationStats>> stats;
+
+ write("a");
+ stats.push_back(std::make_unique<WiredTigerOperationStats>(_session));
+
+ read();
+ stats.push_back(std::make_unique<WiredTigerOperationStats>(_session));
+
+ write("aa");
+ stats.push_back(std::make_unique<WiredTigerOperationStats>(_session));
+
+ read();
+ stats.push_back(std::make_unique<WiredTigerOperationStats>(_session));
+
+ long long bytesWritten = 0;
+ long long timeWritingMicros = 0;
+ long long bytesRead = 0;
+ long long timeReadingMicros = 0;
+
+ WiredTigerOperationStats combined;
+
+ for (auto&& op : stats) {
+ auto statsObj = op->toBSON();
+
+ bytesWritten += statsObj["data"]["bytesWritten"].numberLong();
+ timeWritingMicros += statsObj["data"]["timeWritingMicros"].numberLong();
+ bytesRead += statsObj["data"]["bytesRead"].numberLong();
+ timeReadingMicros += statsObj["data"]["timeReadingMicros"].numberLong();
+
+ combined += *op;
+ }
+
+ auto combinedObj = combined.toBSON();
+ auto dataSection = combinedObj["data"];
+ ASSERT_EQ(dataSection.type(), BSONType::Object) << combinedObj;
+ ASSERT_EQ(dataSection["bytesWritten"].numberLong(), bytesWritten) << combinedObj;
+ ASSERT_EQ(dataSection["timeWritingMicros"].numberLong(), timeWritingMicros) << combinedObj;
+ ASSERT_EQ(dataSection["bytesRead"].numberLong(), bytesRead) << combinedObj;
+ ASSERT_EQ(dataSection["timeReadingMicros"].numberLong(), timeReadingMicros) << combinedObj;
+}
+
+TEST_F(WiredTigerOperationStatsTest, Clone) {
+ write("a");
+
+ WiredTigerOperationStats stats{_session};
+ auto clone = stats.clone();
+
+ ASSERT_BSONOBJ_EQ(stats.toBSON(), clone->toBSON());
+
+ stats += *clone;
+ ASSERT_BSONOBJ_NE(stats.toBSON(), clone->toBSON());
+}
+
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.cpp
index 1f4d9d567b7..26f91072bbd 100644
--- a/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.cpp
+++ b/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.cpp
@@ -37,6 +37,7 @@
#include "mongo/db/server_options.h"
#include "mongo/db/storage/wiredtiger/wiredtiger_begin_transaction_block.h"
#include "mongo/db/storage/wiredtiger/wiredtiger_kv_engine.h"
+#include "mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h"
#include "mongo/db/storage/wiredtiger/wiredtiger_prepare_conflict.h"
#include "mongo/db/storage/wiredtiger/wiredtiger_session_cache.h"
#include "mongo/db/storage/wiredtiger/wiredtiger_util.h"
@@ -82,107 +83,6 @@ void handleWriteContextForDebugging(WiredTigerRecoveryUnit& ru, Timestamp& ts) {
AtomicWord<std::int64_t> snapshotTooOldErrorCount{0};
-using Section = WiredTigerOperationStats::Section;
-
-std::map<int, std::pair<StringData, Section>> WiredTigerOperationStats::_statNameMap = {
- {WT_STAT_SESSION_BYTES_READ, std::make_pair("bytesRead"_sd, Section::DATA)},
- {WT_STAT_SESSION_BYTES_WRITE, std::make_pair("bytesWritten"_sd, Section::DATA)},
- {WT_STAT_SESSION_LOCK_DHANDLE_WAIT, std::make_pair("handleLock"_sd, Section::WAIT)},
- {WT_STAT_SESSION_READ_TIME, std::make_pair("timeReadingMicros"_sd, Section::DATA)},
- {WT_STAT_SESSION_WRITE_TIME, std::make_pair("timeWritingMicros"_sd, Section::DATA)},
- {WT_STAT_SESSION_LOCK_SCHEMA_WAIT, std::make_pair("schemaLock"_sd, Section::WAIT)},
- {WT_STAT_SESSION_CACHE_TIME, std::make_pair("cache"_sd, Section::WAIT)}};
-
-std::shared_ptr<StorageStats> WiredTigerOperationStats::getCopy() {
- std::shared_ptr<WiredTigerOperationStats> copy = std::make_shared<WiredTigerOperationStats>();
- *copy += *this;
- return copy;
-}
-
-void WiredTigerOperationStats::fetchStats(WT_SESSION* session,
- const std::string& uri,
- const std::string& config) {
- invariant(session);
-
- WT_CURSOR* c = nullptr;
- const char* cursorConfig = config.empty() ? nullptr : config.c_str();
- int ret = session->open_cursor(session, uri.c_str(), nullptr, cursorConfig, &c);
- uassert(ErrorCodes::CursorNotFound, "Unable to open statistics cursor", ret == 0);
-
- invariant(c);
- ON_BLOCK_EXIT([&] { c->close(c); });
-
- const char* desc;
- uint64_t value;
- int32_t key;
- while (c->next(c) == 0 && c->get_key(c, &key) == 0) {
- fassert(51035, c->get_value(c, &desc, nullptr, &value) == 0);
- _stats[key] = WiredTigerUtil::castStatisticsValue<long long>(value);
- }
-
- // Reset the statistics so that the next fetch gives the recent values.
- invariantWTOK(c->reset(c), c->session);
-}
-
-BSONObj WiredTigerOperationStats::toBSON() {
- BSONObjBuilder bob;
- std::unique_ptr<BSONObjBuilder> dataSection;
- std::unique_ptr<BSONObjBuilder> waitSection;
-
- for (auto const& stat : _stats) {
- // Find the user consumable name for this statistic.
- auto statIt = _statNameMap.find(stat.first);
-
- // Ignore the session statistic that is not reported for the slow operations.
- if (statIt == _statNameMap.end())
- continue;
-
- auto statName = statIt->second.first;
- Section subs = statIt->second.second;
- long long val = stat.second;
- // Add this statistic only if higher than zero.
- if (val > 0) {
- // Gather the statistic into its own subsection in the BSONObj.
- switch (subs) {
- case Section::DATA:
- if (!dataSection)
- dataSection = std::make_unique<BSONObjBuilder>();
-
- dataSection->append(statName, val);
- break;
- case Section::WAIT:
- if (!waitSection)
- waitSection = std::make_unique<BSONObjBuilder>();
-
- waitSection->append(statName, val);
- break;
- default:
- MONGO_UNREACHABLE;
- }
- }
- }
-
- if (dataSection)
- bob.append("data", dataSection->obj());
- if (waitSection)
- bob.append("timeWaitingMicros", waitSection->obj());
-
- return bob.obj();
-}
-
-WiredTigerOperationStats& WiredTigerOperationStats::operator+=(
- const WiredTigerOperationStats& other) {
- for (auto const& otherStat : other._stats) {
- _stats[otherStat.first] += otherStat.second;
- }
- return (*this);
-}
-
-StorageStats& WiredTigerOperationStats::operator+=(const StorageStats& other) {
- *this += checked_cast<const WiredTigerOperationStats&>(other);
- return (*this);
-}
-
WiredTigerRecoveryUnit::WiredTigerRecoveryUnit(WiredTigerSessionCache* sc)
: WiredTigerRecoveryUnit(sc, sc->getKVEngine()->getOplogManager()) {}
@@ -996,18 +896,7 @@ void WiredTigerRecoveryUnit::beginIdle() {
}
std::shared_ptr<StorageStats> WiredTigerRecoveryUnit::getOperationStatistics() const {
- std::shared_ptr<WiredTigerOperationStats> statsPtr(nullptr);
-
- if (!_session)
- return statsPtr;
-
- WT_SESSION* s = _session->getSession();
- invariant(s);
-
- statsPtr = std::make_shared<WiredTigerOperationStats>();
- statsPtr->fetchStats(s, "statistics:session", "statistics=(fast)");
-
- return statsPtr;
+ return _session ? std::make_shared<WiredTigerOperationStats>(_session->getSession()) : nullptr;
}
void WiredTigerRecoveryUnit::setCatalogConflictingTimestamp(Timestamp timestamp) {
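
With the refactor above, getOperationStatistics() hands back the snapshot directly, or nullptr when the recovery unit has no open session. A sketch of how a caller might consume it, assuming opCtx->recoveryUnit() resolves to a WiredTigerRecoveryUnit and that runningTotal is an existing std::shared_ptr<StorageStats> (both assumptions, not part of the patch):

    std::shared_ptr<StorageStats> opStats = opCtx->recoveryUnit()->getOperationStatistics();
    if (opStats) {
        // Accumulate into a running total and report it; toBSON() only emits non-zero stats.
        *runningTotal += *opStats;
        BSONObj report = runningTotal->toBSON();
    }
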
diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.h b/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.h
index 26102686dbd..2d38888dc90 100644
--- a/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.h
+++ b/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.h
@@ -56,41 +56,6 @@ extern AtomicWord<std::int64_t> snapshotTooOldErrorCount;
class BSONObjBuilder;
-class WiredTigerOperationStats final : public StorageStats {
-public:
- /**
- * There are two types of statistics provided by WiredTiger engine - data and wait.
- */
- enum class Section { DATA, WAIT };
-
- BSONObj toBSON() final;
-
- StorageStats& operator+=(const StorageStats&) final;
-
- WiredTigerOperationStats& operator+=(const WiredTigerOperationStats&);
-
- /**
- * Fetches an operation's storage statistics from WiredTiger engine.
- */
- void fetchStats(WT_SESSION*, const std::string&, const std::string&);
-
- std::shared_ptr<StorageStats> getCopy() final;
-
-private:
- /**
- * Each statistic in WiredTiger has an integer key, which this map associates with a section
- * (either DATA or WAIT) and user-readable name.
- */
- static std::map<int, std::pair<StringData, Section>> _statNameMap;
-
- /**
- * Stores the value for each statistic returned by a WiredTiger cursor. Each statistic is
- * associated with an integer key, which can be mapped to a name and section using the
- * '_statNameMap'.
- */
- std::map<int, long long> _stats;
-};
-
class WiredTigerRecoveryUnit final : public RecoveryUnit {
public:
WiredTigerRecoveryUnit(WiredTigerSessionCache* sc);
diff --git a/src/mongo/db/timeseries/bucket_catalog.cpp b/src/mongo/db/timeseries/bucket_catalog.cpp
index 31cb39aeab0..854a2805d66 100644
--- a/src/mongo/db/timeseries/bucket_catalog.cpp
+++ b/src/mongo/db/timeseries/bucket_catalog.cpp
@@ -860,7 +860,11 @@ Status BucketCatalog::prepareCommit(std::shared_ptr<WriteBatch> batch) {
_useBucketInState(&stripe, stripeLock, batch->bucket().id, BucketState::kPrepared);
if (batch->finished()) {
- // Someone may have aborted it while we were waiting.
+ // Someone may have aborted it while we were waiting. Since we have the prepared batch, we
+ // should now be able to fully abort the bucket.
+ if (bucket) {
+ _abort(&stripe, stripeLock, batch, getBatchStatus());
+ }
return getBatchStatus();
} else if (!bucket) {
_abort(&stripe, stripeLock, batch, getTimeseriesBucketClearedError(batch->bucket().id));
diff --git a/src/mongo/db/timeseries/bucket_catalog_test.cpp b/src/mongo/db/timeseries/bucket_catalog_test.cpp
index 6733f4c9d89..7253111cbf2 100644
--- a/src/mongo/db/timeseries/bucket_catalog_test.cpp
+++ b/src/mongo/db/timeseries/bucket_catalog_test.cpp
@@ -911,6 +911,73 @@ TEST_F(BucketCatalogTest, CannotConcurrentlyCommitBatchesForSameBucket) {
_bucketCatalog->finish(batch2, {});
}
+TEST_F(BucketCatalogTest, AbortingBatchEnsuresBucketIsEventuallyClosed) {
+ auto batch1 = _bucketCatalog
+ ->insert(_opCtx,
+ _ns1,
+ _getCollator(_ns1),
+ _getTimeseriesOptions(_ns1),
+ BSON(_timeField << Date_t::now()),
+ BucketCatalog::CombineWithInsertsFromOtherClients::kDisallow)
+ .getValue()
+ .batch;
+
+ auto batch2 = _bucketCatalog
+ ->insert(_makeOperationContext().second.get(),
+ _ns1,
+ _getCollator(_ns1),
+ _getTimeseriesOptions(_ns1),
+ BSON(_timeField << Date_t::now()),
+ BucketCatalog::CombineWithInsertsFromOtherClients::kDisallow)
+ .getValue()
+ .batch;
+ auto batch3 = _bucketCatalog
+ ->insert(_makeOperationContext().second.get(),
+ _ns1,
+ _getCollator(_ns1),
+ _getTimeseriesOptions(_ns1),
+ BSON(_timeField << Date_t::now()),
+ BucketCatalog::CombineWithInsertsFromOtherClients::kDisallow)
+ .getValue()
+ .batch;
+ ASSERT_EQ(batch1->bucket().id, batch2->bucket().id);
+ ASSERT_EQ(batch1->bucket().id, batch3->bucket().id);
+
+ ASSERT(batch1->claimCommitRights());
+ ASSERT(batch2->claimCommitRights());
+ ASSERT(batch3->claimCommitRights());
+
+ // Batch 2 will not be able to commit until batch 1 has finished.
+ ASSERT_OK(_bucketCatalog->prepareCommit(batch1));
+ auto task = Task{[&]() { ASSERT_OK(_bucketCatalog->prepareCommit(batch2)); }};
+ // Add a little extra wait to make sure prepareCommit actually gets to the blocking point.
+ stdx::this_thread::sleep_for(stdx::chrono::milliseconds(10));
+ ASSERT(task.future().valid());
+ ASSERT(stdx::future_status::timeout == task.future().wait_for(stdx::chrono::microseconds(1)))
+ << "prepareCommit finished earlier than expected";
+
+ // If we abort the third batch, it should abort the second one too, as it isn't prepared.
+ // However, since the first batch is prepared, we can't abort it or clean up the bucket. We can
+ // then finish the first batch, which will allow the second batch to proceed. It should
+ // recognize it has been aborted and clean up the bucket.
+ _bucketCatalog->abort(batch3, Status{ErrorCodes::TimeseriesBucketCleared, "cleared"});
+ _bucketCatalog->finish(batch1, {});
+ task.future().wait();
+ ASSERT(batch2->finished());
+
+ // Make sure a new batch ends up in a new bucket.
+ auto batch4 = _bucketCatalog
+ ->insert(_opCtx,
+ _ns1,
+ _getCollator(_ns1),
+ _getTimeseriesOptions(_ns1),
+ BSON(_timeField << Date_t::now()),
+ BucketCatalog::CombineWithInsertsFromOtherClients::kDisallow)
+ .getValue()
+ .batch;
+ ASSERT_NE(batch2->bucket().id, batch4->bucket().id);
+}
+
TEST_F(BucketCatalogTest, DuplicateNewFieldNamesAcrossConcurrentBatches) {
auto batch1 = _bucketCatalog
->insert(_opCtx,
diff --git a/src/mongo/db/transaction/transaction_api.cpp b/src/mongo/db/transaction/transaction_api.cpp
index 36635d16694..1150c5b5750 100644
--- a/src/mongo/db/transaction/transaction_api.cpp
+++ b/src/mongo/db/transaction/transaction_api.cpp
@@ -300,17 +300,32 @@ SemiFuture<BSONObj> SEPTransactionClient::runCommand(StringData dbName, BSONObj
BSONObjBuilder cmdBuilder(_behaviors->maybeModifyCommand(std::move(cmdObj)));
_hooks->runRequestHook(&cmdBuilder);
+ auto modifiedCmdObj = cmdBuilder.obj();
+ bool isAbortTxnCmd = modifiedCmdObj.firstElementFieldNameStringData() == "abortTransaction";
auto client = _serviceContext->makeClient("SEP-internal-txn-client");
AlternativeClientRegion clientRegion(client);
+
// Note that _token is only cancelled once the caller of the transaction no longer cares about
// its result, so CancelableOperationContexts only being interrupted by ErrorCodes::Interrupted
- // shouldn't impact any upstream retry logic.
- CancelableOperationContextFactory opCtxFactory(_token, _executor);
+ // shouldn't impact any upstream retry logic. If _bestEffortAbort() is invoked, a new
+ // cancellation token must be used when constructing the opCtx for running abortTransaction.
+ // This is because an operation that has already been interrupted would cancel the parent
+ // cancellation token, and using that same token for abortTransaction would prevent it from
+ // being sent, leaving the transaction open longer than necessary.
+ auto opCtxFactory = isAbortTxnCmd
+ ? CancelableOperationContextFactory(CancellationToken::uncancelable(), _executor)
+ : CancelableOperationContextFactory(_token, _executor);
+
auto cancellableOpCtx = opCtxFactory.makeOperationContext(&cc());
+
+ // abortTransaction should still be interruptible on stepdown/shutdown.
+ if (isAbortTxnCmd) {
+ cancellableOpCtx->setAlwaysInterruptAtStepDownOrUp_UNSAFE();
+ }
primeInternalClient(&cc());
- auto opMsgRequest = OpMsgRequest::fromDBAndBody(dbName, cmdBuilder.obj());
+ auto opMsgRequest = OpMsgRequest::fromDBAndBody(dbName, modifiedCmdObj);
auto requestMessage = opMsgRequest.serialize();
return _behaviors->handleRequest(cancellableOpCtx.get(), requestMessage)
.then([this](DbResponse dbResponse) {
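
The same decoupling can be read in isolation: best-effort abort work must not inherit a token that the caller may already have cancelled, or the abort itself would be interrupted before it is sent. A condensed sketch of the pattern used above (setup and error handling elided):

    // Ordinary commands honor the caller's token, so cancelling _token interrupts them.
    CancelableOperationContextFactory normalFactory(_token, _executor);

    // abortTransaction uses an uncancelable token so a caller that has already given up
    // cannot suppress the abort; stepdown/shutdown interruption is re-enabled explicitly.
    CancelableOperationContextFactory abortFactory(CancellationToken::uncancelable(), _executor);
    auto abortOpCtx = abortFactory.makeOperationContext(&cc());
    abortOpCtx->setAlwaysInterruptAtStepDownOrUp_UNSAFE();
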
diff --git a/src/mongo/db/transaction/transaction_metrics_observer.cpp b/src/mongo/db/transaction/transaction_metrics_observer.cpp
index 0a503e3b135..7a6544a1b3b 100644
--- a/src/mongo/db/transaction/transaction_metrics_observer.cpp
+++ b/src/mongo/db/transaction/transaction_metrics_observer.cpp
@@ -188,7 +188,7 @@ void TransactionMetricsObserver::onTransactionOperation(OperationContext* opCtx,
if (storageStats) {
CurOp::get(opCtx)->debug().storageStats = storageStats;
if (!_singleTransactionStats.getOpDebug()->storageStats) {
- _singleTransactionStats.getOpDebug()->storageStats = storageStats->getCopy();
+ _singleTransactionStats.getOpDebug()->storageStats = storageStats->clone();
} else {
*_singleTransactionStats.getOpDebug()->storageStats += *storageStats;
}
diff --git a/src/mongo/db/ttl.cpp b/src/mongo/db/ttl.cpp
index 12c02d2c090..f56b3985261 100644
--- a/src/mongo/db/ttl.cpp
+++ b/src/mongo/db/ttl.cpp
@@ -320,7 +320,7 @@ void TTLMonitor::shutdown() {
void TTLMonitor::_doTTLPass() {
const ServiceContext::UniqueOperationContext opCtxPtr = cc().makeOperationContext();
OperationContext* opCtx = opCtxPtr.get();
- SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow);
+ SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow);
hangTTLMonitorBetweenPasses.pauseWhileSet(opCtx);