Diffstat (limited to 'src/mongo/db')
230 files changed, 10213 insertions, 3510 deletions
diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript index 06f719c69ac..0f59f24014c 100644 --- a/src/mongo/db/SConscript +++ b/src/mongo/db/SConscript @@ -2337,6 +2337,7 @@ env.Library( '$BUILD_DIR/mongo/db/change_streams_cluster_parameter', '$BUILD_DIR/mongo/db/pipeline/change_stream_expired_pre_image_remover', '$BUILD_DIR/mongo/db/query/ce/query_ce_histogram', + '$BUILD_DIR/mongo/db/s/query_analysis_writer', '$BUILD_DIR/mongo/db/set_change_stream_state_coordinator', '$BUILD_DIR/mongo/idl/cluster_server_parameter', '$BUILD_DIR/mongo/idl/cluster_server_parameter_op_observer', diff --git a/src/mongo/db/catalog/SConscript b/src/mongo/db/catalog/SConscript index 7d495903c0f..302d3782c72 100644 --- a/src/mongo/db/catalog/SConscript +++ b/src/mongo/db/catalog/SConscript @@ -508,6 +508,7 @@ env.Library( env.Library( target='catalog_helpers', source=[ + "$BUILD_DIR/mongo/db/commands/external_data_source_scope_guard.cpp", 'capped_utils.cpp', 'coll_mod_index.cpp', 'coll_mod.cpp', diff --git a/src/mongo/db/catalog/collection.h b/src/mongo/db/catalog/collection.h index 3fac44cff7a..76905a8fe6f 100644 --- a/src/mongo/db/catalog/collection.h +++ b/src/mongo/db/catalog/collection.h @@ -67,6 +67,9 @@ struct CollectionUpdateArgs { std::vector<StmtId> stmtIds = {kUninitializedStmtId}; + // The unique sample id for this update if it has been chosen for sampling. + boost::optional<UUID> sampleId; + // The document before modifiers were applied. boost::optional<BSONObj> preImageDoc; diff --git a/src/mongo/db/catalog/collection_catalog.cpp b/src/mongo/db/catalog/collection_catalog.cpp index 1d644b93e69..735a4b897b0 100644 --- a/src/mongo/db/catalog/collection_catalog.cpp +++ b/src/mongo/db/catalog/collection_catalog.cpp @@ -52,6 +52,13 @@ namespace mongo { namespace { +// Sentinel id for marking a catalogId mapping range as unknown. Must use an invalid RecordId. +static RecordId kUnknownRangeMarkerId = RecordId::minLong(); +// Maximum number of entries in catalogId mapping when inserting catalogId missing at timestamp. +// Used to avoid quadratic behavior when inserting entries at the beginning. When threshold is +// reached we will fall back to more durable catalog scans. +static constexpr int kMaxCatalogIdMappingLengthForMissingInsert = 1000; + struct LatestCollectionCatalog { std::shared_ptr<CollectionCatalog> catalog = std::make_shared<CollectionCatalog>(); }; @@ -1288,7 +1295,11 @@ CollectionCatalog::CatalogIdLookup CollectionCatalog::lookupCatalogIdByNSS( // iterator to get the last entry where the time is less or equal. auto catalogId = (--rangeIt)->id; if (catalogId) { - return {*catalogId, CatalogIdLookup::NamespaceExistence::kExists}; + if (*catalogId != kUnknownRangeMarkerId) { + return {*catalogId, CatalogIdLookup::NamespaceExistence::kExists}; + } else { + return {RecordId{}, CatalogIdLookup::NamespaceExistence::kUnknown}; + } } return {RecordId{}, CatalogIdLookup::NamespaceExistence::kNotExists}; } @@ -1615,12 +1626,15 @@ std::shared_ptr<Collection> CollectionCatalog::deregisterCollection( // TODO SERVER-68674: Remove feature flag check. 
if (feature_flags::gPointInTimeCatalogLookups.isEnabledAndIgnoreFCV() && isDropPending) { - auto ident = coll->getSharedIdent()->getIdent(); - LOGV2_DEBUG(6825300, 1, "Registering drop pending collection ident", "ident"_attr = ident); + if (auto sharedIdent = coll->getSharedIdent(); sharedIdent) { + auto ident = sharedIdent->getIdent(); + LOGV2_DEBUG( + 6825300, 1, "Registering drop pending collection ident", "ident"_attr = ident); - auto it = _dropPendingCollection.find(ident); - invariant(it == _dropPendingCollection.end()); - _dropPendingCollection[ident] = coll; + auto it = _dropPendingCollection.find(ident); + invariant(it == _dropPendingCollection.end()); + _dropPendingCollection[ident] = coll; + } } _orderedCollections.erase(dbIdPair); @@ -1735,7 +1749,8 @@ void CollectionCatalog::_pushCatalogIdForNSS(const NamespaceString& nss, return; } - // Re-write latest entry if timestamp match (multiple changes occured in this transaction) + // An entry could exist already if concurrent writes are performed, keep the latest change in + // that case. if (!ids.empty() && ids.back().ts == *ts) { ids.back().id = catalogId; return; @@ -1771,8 +1786,8 @@ void CollectionCatalog::_pushCatalogIdForRename(const NamespaceString& from, auto& fromIds = _catalogIds.at(from); invariant(!fromIds.empty()); - // Re-write latest entry if timestamp match (multiple changes occured in this transaction), - // otherwise push at end + // An entry could exist already if concurrent writes are performed, keep the latest change in + // that case. if (!toIds.empty() && toIds.back().ts == *ts) { toIds.back().id = fromIds.back().id; } else { @@ -1792,6 +1807,83 @@ void CollectionCatalog::_pushCatalogIdForRename(const NamespaceString& from, } } +void CollectionCatalog::_insertCatalogIdForNSSAfterScan(const NamespaceString& nss, + boost::optional<RecordId> catalogId, + Timestamp ts) { + // TODO SERVER-68674: Remove feature flag check. + if (!feature_flags::gPointInTimeCatalogLookups.isEnabledAndIgnoreFCV()) { + // No-op. + return; + } + + auto& ids = _catalogIds[nss]; + + // Binary search for to the entry with same or larger timestamp + auto it = + std::lower_bound(ids.begin(), ids.end(), ts, [](const auto& entry, const Timestamp& ts) { + return entry.ts < ts; + }); + + // The logic of what we need to do differs whether we are inserting a valid catalogId or not. + if (catalogId) { + if (it != ids.end()) { + // An entry could exist already if concurrent writes are performed, keep the latest + // change in that case. + if (it->ts == ts) { + it->id = catalogId; + return; + } + + // If next element has same catalogId, we can adjust its timestamp to cover a longer + // range + if (it->id == catalogId) { + it->ts = ts; + _markNamespaceForCatalogIdCleanupIfNeeded(nss, ids); + return; + } + } + + // Otherwise insert new entry at timestamp + ids.insert(it, TimestampedCatalogId{catalogId, ts}); + _markNamespaceForCatalogIdCleanupIfNeeded(nss, ids); + return; + } + + // Avoid inserting missing mapping when the list has grown past the threshold. Will cause the + // system to fall back to scanning the durable catalog. + if (ids.size() >= kMaxCatalogIdMappingLengthForMissingInsert) { + return; + } + + if (it != ids.end() && it->ts == ts) { + // An entry could exist already if concurrent writes are performed, keep the latest change + // in that case. 
+ it->id = boost::none; + } else { + // Otherwise insert new entry + it = ids.insert(it, TimestampedCatalogId{boost::none, ts}); + } + + // The iterator is positioned on the added/modified element above, reposition it to the next + // entry + ++it; + + // We don't want to assume that the namespace remains not existing until the next entry, as + // there can be times where the namespace actually does exist. To make sure we trigger the + // scanning of the durable catalog in this range we will insert a bogus entry using an invalid + // RecordId at the next timestamp. This will treat the range forward as unknown. + auto nextTs = ts + 1; + + // If the next entry is on the next timestamp already, we can skip adding the bogus entry. If + // this function is called for a previously unknown namespace, we may not have any future valid + // entries and the iterator would be positioned at and at this point. + if (it == ids.end() || it->ts != nextTs) { + ids.insert(it, TimestampedCatalogId{kUnknownRangeMarkerId, nextTs}); + } + + _markNamespaceForCatalogIdCleanupIfNeeded(nss, ids); +} + void CollectionCatalog::_markNamespaceForCatalogIdCleanupIfNeeded( const NamespaceString& nss, const std::vector<TimestampedCatalogId>& ids) { diff --git a/src/mongo/db/catalog/collection_catalog.h b/src/mongo/db/catalog/collection_catalog.h index d3d3446f1be..0b2561f0996 100644 --- a/src/mongo/db/catalog/collection_catalog.h +++ b/src/mongo/db/catalog/collection_catalog.h @@ -736,6 +736,15 @@ private: const NamespaceString& to, boost::optional<Timestamp> ts); + // TODO SERVER-70150: Make private again +public: + // Inserts a catalogId for namespace at given Timestamp. Used after scanning the durable catalog + // for a correct mapping at the given timestamp. + void _insertCatalogIdForNSSAfterScan(const NamespaceString& nss, + boost::optional<RecordId> catalogId, + Timestamp ts); + +private: // Helper to calculate if a namespace needs to be marked for cleanup for a set of timestamped // catalogIds void _markNamespaceForCatalogIdCleanupIfNeeded(const NamespaceString& nss, diff --git a/src/mongo/db/catalog/collection_catalog_test.cpp b/src/mongo/db/catalog/collection_catalog_test.cpp index 6a554f8fe59..8710d807f73 100644 --- a/src/mongo/db/catalog/collection_catalog_test.cpp +++ b/src/mongo/db/catalog/collection_catalog_test.cpp @@ -2114,6 +2114,170 @@ TEST_F(CollectionCatalogTimestampTest, CatalogIdMappingRollback) { CollectionCatalog::CatalogIdLookup::NamespaceExistence::kNotExists); } +TEST_F(CollectionCatalogTimestampTest, CatalogIdMappingInsert) { + RAIIServerParameterControllerForTest featureFlagController( + "featureFlagPointInTimeCatalogLookups", true); + + NamespaceString nss("a.b"); + + // Create a collection on the namespace + createCollection(opCtx.get(), nss, Timestamp(1, 10)); + dropCollection(opCtx.get(), nss, Timestamp(1, 20)); + createCollection(opCtx.get(), nss, Timestamp(1, 30)); + + auto rid1 = catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 10)).id; + auto rid2 = catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).id; + + // Simulate startup where we have a range [oldest, stable] by creating and dropping collections + // and then advancing the oldest timestamp and then reading behind it. 
+ CollectionCatalog::write(opCtx.get(), [](CollectionCatalog& catalog) { + catalog.cleanupForOldestTimestampAdvanced(Timestamp(1, 40)); + }); + + // Confirm that the mappings have been cleaned up + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 15)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown); + + // TODO SERVER-70150: Use openCollection + CollectionCatalog::write(opCtx.get(), [&](CollectionCatalog& catalog) { + catalog._insertCatalogIdForNSSAfterScan(nss, rid1, Timestamp(1, 17)); + }); + + // Lookups before the inserted timestamp is still unknown + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 11)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown); + + // Lookups at or after the inserted timestamp is found, even if they don't match with WT + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).id, rid1); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).id, rid1); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).id, rid1); + // The entry at Timestamp(1, 30) is unaffected + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).id, rid2); + + // TODO SERVER-70150: Use openCollection + CollectionCatalog::write(opCtx.get(), [&](CollectionCatalog& catalog) { + catalog._insertCatalogIdForNSSAfterScan(nss, rid1, Timestamp(1, 12)); + }); + + // We should now have extended the range from Timestamp(1, 17) to Timestamp(1, 12) + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 12)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 12)).id, rid1); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 16)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 16)).id, rid1); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).id, rid1); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).id, rid1); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).id, rid1); + // The entry at Timestamp(1, 30) is unaffected + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).id, rid2); + + // TODO SERVER-70150: Use openCollection + CollectionCatalog::write(opCtx.get(), [&](CollectionCatalog& catalog) { + 
catalog._insertCatalogIdForNSSAfterScan(nss, boost::none, Timestamp(1, 25)); + }); + + // Check the entries, most didn't change + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).id, rid1); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).id, rid1); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 22)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 22)).id, rid1); + // At Timestamp(1, 25) we now return kNotExists + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kNotExists); + // But next timestamp returns unknown + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 26)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown); + // The entry at Timestamp(1, 30) is unaffected + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).id, rid2); + + // TODO SERVER-70150: Use openCollection + CollectionCatalog::write(opCtx.get(), [&](CollectionCatalog& catalog) { + catalog._insertCatalogIdForNSSAfterScan(nss, boost::none, Timestamp(1, 26)); + }); + + // We should not have re-written the existing entry at Timestamp(1, 26) + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 17)).id, rid1); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 19)).id, rid1); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 22)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 22)).id, rid1); + // At Timestamp(1, 25) we now return kNotExists + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 25)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kNotExists); + // But next timestamp returns unknown + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 26)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kNotExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 27)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown); + // The entry at Timestamp(1, 30) is unaffected + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kExists); + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 30)).id, rid2); + + // Clean up, check so we are back to the original state + CollectionCatalog::write(opCtx.get(), [](CollectionCatalog& catalog) { + catalog.cleanupForOldestTimestampAdvanced(Timestamp(1, 41)); + }); + + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 15)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown); +} + 
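// ---------------------------------------------------------------------------
// Editor's note (not part of the patch): the CatalogIdMappingInsert test above
// exercises the new timestamped catalogId mapping added in collection_catalog.cpp,
// where a sentinel entry (kUnknownRangeMarkerId) marks a range that must fall
// back to scanning the durable catalog. The standalone sketch below only
// illustrates the lookup rule; Entry, Lookup, and the integer RecordId/Timestamp
// aliases are simplified stand-ins for the real TimestampedCatalogId and
// CatalogIdLookup types, and std::optional replaces boost::optional. Treat it
// as an approximation of the behaviour, not the actual implementation.
// ---------------------------------------------------------------------------
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

namespace catalog_sketch {

using Timestamp = std::uint64_t;
using RecordId = std::int64_t;

// Simplified sentinel; the real code uses RecordId::minLong().
constexpr RecordId kUnknownRangeMarkerId = -1;

enum class Existence { kExists, kNotExists, kUnknown };

struct Entry {
    std::optional<RecordId> id;  // nullopt == namespace known to not exist at 'ts'
    Timestamp ts;
};

struct Lookup {
    std::optional<RecordId> id;
    Existence result;
};

// Mirrors the lookup rule: take the last entry with ts <= the requested
// timestamp and interpret its id (valid id, "not exists", or unknown sentinel).
inline Lookup lookup(const std::vector<Entry>& ids, Timestamp ts) {
    auto it = std::upper_bound(
        ids.begin(), ids.end(), ts, [](Timestamp t, const Entry& e) { return t < e.ts; });
    if (it == ids.begin())
        return {std::nullopt, Existence::kUnknown};  // before the oldest known mapping
    const Entry& e = *--it;
    if (!e.id)
        return {std::nullopt, Existence::kNotExists};
    if (*e.id == kUnknownRangeMarkerId)
        return {std::nullopt, Existence::kUnknown};  // range marked as needing a catalog scan
    return {e.id, Existence::kExists};
}

}  // namespace catalog_sketch

// Tiny usage example matching the shape of the assertions above: a mapping that
// exists at ts=17, is known missing at ts=25, and is unknown from ts=26 onward.
inline void catalogSketchExample() {
    using namespace catalog_sketch;
    std::vector<Entry> ids{{RecordId{7}, 17}, {std::nullopt, 25}, {kUnknownRangeMarkerId, 26}};
    assert(lookup(ids, 20).result == Existence::kExists);
    assert(lookup(ids, 25).result == Existence::kNotExists);
    assert(lookup(ids, 27).result == Existence::kUnknown);
    assert(lookup(ids, 10).result == Existence::kUnknown);
}
// --------------------------- end editor's note -----------------------------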
+TEST_F(CollectionCatalogTimestampTest, CatalogIdMappingInsertUnknown) { + RAIIServerParameterControllerForTest featureFlagController( + "featureFlagPointInTimeCatalogLookups", true); + + NamespaceString nss("a.b"); + + // Simulate startup where we have a range [oldest, stable] by advancing the oldest timestamp and + // then reading behind it. + CollectionCatalog::write(opCtx.get(), [](CollectionCatalog& catalog) { + catalog.cleanupForOldestTimestampAdvanced(Timestamp(1, 40)); + }); + + // Reading before the oldest is unknown + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 15)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kUnknown); + + // Try to instantiate a non existing collection at this timestamp. + // TODO SERVER-70150: Use openCollection + CollectionCatalog::write(opCtx.get(), [&](CollectionCatalog& catalog) { + catalog._insertCatalogIdForNSSAfterScan(nss, boost::none, Timestamp(1, 15)); + }); + + // Lookup should now be not existing + ASSERT_EQ(catalog()->lookupCatalogIdByNSS(nss, Timestamp(1, 15)).result, + CollectionCatalog::CatalogIdLookup::NamespaceExistence::kNotExists); +} + TEST_F(CollectionCatalogTimestampTest, CollectionLifetimeTiedToStorageTransactionLifetime) { RAIIServerParameterControllerForTest featureFlagController( "featureFlagPointInTimeCatalogLookups", true); diff --git a/src/mongo/db/catalog/collection_impl.cpp b/src/mongo/db/catalog/collection_impl.cpp index b4f3078a3c9..c028252f3b1 100644 --- a/src/mongo/db/catalog/collection_impl.cpp +++ b/src/mongo/db/catalog/collection_impl.cpp @@ -462,6 +462,7 @@ Status CollectionImpl::initFromExisting(OperationContext* opCtx, // Update the idents for the newly initialized indexes. for (const auto& sharedIdent : sharedIdents) { auto desc = getIndexCatalog()->findIndexByName(opCtx, sharedIdent.first); + invariant(desc); auto entry = getIndexCatalog()->getEntryShared(desc); entry->setIdent(sharedIdent.second); } diff --git a/src/mongo/db/catalog/create_collection.cpp b/src/mongo/db/catalog/create_collection.cpp index 79fecf84cd7..69d71d25468 100644 --- a/src/mongo/db/catalog/create_collection.cpp +++ b/src/mongo/db/catalog/create_collection.cpp @@ -409,7 +409,7 @@ Status _createTimeseries(OperationContext* opCtx, AutoGetCollection::Options{}.viewMode(auto_get_collection::ViewMode::kViewsPermitted)); Lock::CollectionLock systemDotViewsLock( opCtx, - NamespaceString(ns.db(), NamespaceString::kSystemDotViewsCollectionName), + NamespaceString(ns.dbName(), NamespaceString::kSystemDotViewsCollectionName), MODE_X); auto db = autoColl.ensureDbExists(opCtx); diff --git a/src/mongo/db/catalog/create_collection.h b/src/mongo/db/catalog/create_collection.h index 08754d89daa..655644c4e36 100644 --- a/src/mongo/db/catalog/create_collection.h +++ b/src/mongo/db/catalog/create_collection.h @@ -27,6 +27,8 @@ * it in the license file. 
*/ +#pragma once + #include <boost/optional.hpp> #include <string> diff --git a/src/mongo/db/catalog/create_collection_test.cpp b/src/mongo/db/catalog/create_collection_test.cpp index 73c724585f3..97d1384f40f 100644 --- a/src/mongo/db/catalog/create_collection_test.cpp +++ b/src/mongo/db/catalog/create_collection_test.cpp @@ -324,10 +324,8 @@ TEST_F(CreateCollectionTest, ValidationDisabledForTemporaryReshardingCollection) ASSERT_OK(status); } -static const std::string kValidUrl1 = - ExternalDataSourceMetadata::kDefaultFileUrlPrefix + "named_pipe1"; -static const std::string kValidUrl2 = - ExternalDataSourceMetadata::kDefaultFileUrlPrefix + "named_pipe2"; +const auto kValidUrl1 = ExternalDataSourceMetadata::kUrlProtocolFile + "named_pipe1"s; +const auto kValidUrl2 = ExternalDataSourceMetadata::kUrlProtocolFile + "named_pipe2"s; TEST_F(CreateVirtualCollectionTest, VirtualCollectionOptionsWithOneSource) { NamespaceString vcollNss("myDb", "vcoll.name"); @@ -344,8 +342,7 @@ TEST_F(CreateVirtualCollectionTest, VirtualCollectionOptionsWithOneSource) { auto vcollOpts = getVirtualCollectionOptions(opCtx.get(), vcollNss); ASSERT_EQ(vcollOpts.dataSources.size(), 1); - ASSERT_EQ(ExternalDataSourceMetadata::kUrlProtocolFile + vcollOpts.dataSources[0].url, - kValidUrl1); + ASSERT_EQ(vcollOpts.dataSources[0].url, kValidUrl1); ASSERT_EQ(stdx::to_underlying(vcollOpts.dataSources[0].storageType), stdx::to_underlying(StorageTypeEnum::pipe)); ASSERT_EQ(stdx::to_underlying(vcollOpts.dataSources[0].fileType), @@ -389,7 +386,7 @@ TEST_F(CreateVirtualCollectionTest, InvalidVirtualCollectionOptions) { { bool exceptionOccurred = false; VirtualCollectionOptions reqVcollOpts; - constexpr auto kInvalidUrl = "file:///abc/named_pipe"; + constexpr auto kInvalidUrl = "fff://abc/named_pipe"_sd; try { reqVcollOpts.dataSources.emplace_back( kInvalidUrl, StorageTypeEnum::pipe, FileTypeEnum::bson); diff --git a/src/mongo/db/catalog/database_impl.cpp b/src/mongo/db/catalog/database_impl.cpp index 618ecedc65a..0a3c9754e16 100644 --- a/src/mongo/db/catalog/database_impl.cpp +++ b/src/mongo/db/catalog/database_impl.cpp @@ -573,20 +573,22 @@ Status DatabaseImpl::_finishDropCollection(OperationContext* opCtx, "namespace"_attr = nss, "uuid"_attr = uuid); - auto status = catalog::dropCollection( - opCtx, collection->ns(), collection->getCatalogId(), collection->getSharedIdent()); - if (!status.isOK()) - return status; + // A virtual collection does not have a durable catalog entry. 
+ if (auto sharedIdent = collection->getSharedIdent()) { + auto status = catalog::dropCollection( + opCtx, collection->ns(), collection->getCatalogId(), sharedIdent); + if (!status.isOK()) + return status; - opCtx->recoveryUnit()->onCommit( - [opCtx, nss, uuid, ident = collection->getSharedIdent()->getIdent()]( - boost::optional<Timestamp> commitTime) { + opCtx->recoveryUnit()->onCommit([opCtx, nss, uuid, ident = sharedIdent->getIdent()]( + boost::optional<Timestamp> commitTime) { if (!commitTime) { return; } HistoricalIdentTracker::get(opCtx).recordDrop(ident, nss, uuid, commitTime.value()); }); + } CollectionCatalog::get(opCtx)->dropCollection( opCtx, collection, opCtx->getServiceContext()->getStorageEngine()->supportsPendingDrops()); diff --git a/src/mongo/db/catalog/virtual_collection_impl.cpp b/src/mongo/db/catalog/virtual_collection_impl.cpp index 1c7e2f4e341..bd5c80d50c5 100644 --- a/src/mongo/db/catalog/virtual_collection_impl.cpp +++ b/src/mongo/db/catalog/virtual_collection_impl.cpp @@ -31,6 +31,7 @@ #include "mongo/db/catalog/collection_impl.h" #include "mongo/db/catalog/collection_options.h" +#include "mongo/db/catalog/index_catalog_impl.h" #include "mongo/db/operation_context.h" #include "mongo/db/storage/external_record_store.h" @@ -43,8 +44,9 @@ VirtualCollectionImpl::VirtualCollectionImpl(OperationContext* opCtx, std::unique_ptr<ExternalRecordStore> recordStore) : _nss(nss), _options(options), - _recordStore(std::move(recordStore)), - _collator(CollectionImpl::parseCollation(opCtx, nss, options.collation)) { + _shared(std::make_shared<SharedState>( + std::move(recordStore), CollectionImpl::parseCollation(opCtx, nss, options.collation))), + _indexCatalog(std::make_unique<IndexCatalogImpl>()) { tassert(6968503, "Cannot create _id index for a virtual collection", options.autoIndexId == CollectionOptions::NO && options.idIndex.isEmpty()); @@ -57,16 +59,4 @@ std::shared_ptr<Collection> VirtualCollectionImpl::make(OperationContext* opCtx, return std::make_shared<VirtualCollectionImpl>( opCtx, nss, options, std::make_unique<ExternalRecordStore>(nss.ns(), vopts)); } - -std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> VirtualCollectionImpl::makePlanExecutor( - OperationContext* opCtx, - const CollectionPtr& yieldableCollection, - PlanYieldPolicy::YieldPolicy yieldPolicy, - ScanDirection scanDirection, - const boost::optional<RecordId>& resumeAfterRecordId) const { - // TODO SERVER-69683 Implement this function when implementing MultiBsonStreamCursor since - // we can scan when it's done. We don't support scan yet. 
- unimplementedTasserted(); - return nullptr; -} } // namespace mongo diff --git a/src/mongo/db/catalog/virtual_collection_impl.h b/src/mongo/db/catalog/virtual_collection_impl.h index a490f4f9ee1..499e6add895 100644 --- a/src/mongo/db/catalog/virtual_collection_impl.h +++ b/src/mongo/db/catalog/virtual_collection_impl.h @@ -50,19 +50,20 @@ public: const CollectionOptions& options, std::unique_ptr<ExternalRecordStore> recordStore); + VirtualCollectionImpl(const VirtualCollectionImpl&) = default; + ~VirtualCollectionImpl() = default; const VirtualCollectionOptions& getVirtualCollectionOptions() const { - return _recordStore->getOptions(); + return _shared->_recordStore->getOptions(); } std::shared_ptr<Collection> clone() const final { - unimplementedTasserted(); - return nullptr; + return std::make_shared<VirtualCollectionImpl>(*this); } SharedCollectionDecorations* getSharedDecorations() const final { - return nullptr; + return &_shared->_sharedDecorations; } Status initFromExisting(OperationContext* opCtx, @@ -96,19 +97,21 @@ public: } const IndexCatalog* getIndexCatalog() const final { - return nullptr; + return _indexCatalog.get(); } IndexCatalog* getIndexCatalog() final { - return nullptr; + return _indexCatalog.get(); } RecordStore* getRecordStore() const final { - return _recordStore.get(); + return _shared->_recordStore.get(); } + // A virtual collection can't have an 'ident' because 'ident' is an identifier to a WT table + // which a virtual colelction does not have. So returns nullptr. std::shared_ptr<Ident> getSharedIdent() const final { - return _recordStore->getSharedIdent(); + return nullptr; } void setIdent(std::shared_ptr<Ident> newIdent) final { @@ -149,7 +152,7 @@ public: std::unique_ptr<SeekableRecordCursor> getCursor(OperationContext* opCtx, bool forward = true) const final { - return _recordStore->getCursor(opCtx, forward); + return _shared->_recordStore->getCursor(opCtx, forward); } void deleteDocument( @@ -368,35 +371,26 @@ public: } int getTotalIndexCount() const final { - unimplementedTasserted(); return 0; } int getCompletedIndexCount() const final { - unimplementedTasserted(); return 0; } BSONObj getIndexSpec(StringData indexName) const final { - unimplementedTasserted(); return BSONObj(); } - void getAllIndexes(std::vector<std::string>* names) const final { - unimplementedTasserted(); - } + void getAllIndexes(std::vector<std::string>* names) const final {} - void getReadyIndexes(std::vector<std::string>* names) const final { - unimplementedTasserted(); - } + void getReadyIndexes(std::vector<std::string>* names) const final {} bool isIndexPresent(StringData indexName) const final { - unimplementedTasserted(); return false; } bool isIndexReady(StringData indexName) const final { - unimplementedTasserted(); return false; } @@ -428,15 +422,15 @@ public: } long long numRecords(OperationContext* opCtx) const final { - return _recordStore->numRecords(opCtx); + return _shared->_recordStore->numRecords(opCtx); } long long dataSize(OperationContext* opCtx) const final { - return _recordStore->dataSize(opCtx); + return _shared->_recordStore->dataSize(opCtx); } bool isEmpty(OperationContext* opCtx) const final { - return _recordStore->dataSize(opCtx) == 0LL; + return _shared->_recordStore->dataSize(opCtx) == 0LL; } inline int averageObjectSize(OperationContext* opCtx) const { @@ -478,7 +472,7 @@ public: * Collection is destroyed. 
*/ const CollatorInterface* getDefaultCollator() const final { - return _collator.get(); + return _shared->_collator.get(); } const CollectionOptions& getCollectionOptions() const final { @@ -491,12 +485,17 @@ public: return Status(ErrorCodes::UnknownError, "unknown"); } + // This method is used in context of rollback and index build which are not supported for a + // virtual collection. std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> makePlanExecutor( OperationContext* opCtx, const CollectionPtr& yieldableCollection, PlanYieldPolicy::YieldPolicy yieldPolicy, ScanDirection scanDirection, - const boost::optional<RecordId>& resumeAfterRecordId) const final; + const boost::optional<RecordId>& resumeAfterRecordId) const final { + unimplementedTasserted(); + return nullptr; + } void indexBuildSuccess(OperationContext* opCtx, IndexCatalogEntry* index) final { unimplementedTasserted(); @@ -509,13 +508,32 @@ private: MONGO_UNIMPLEMENTED_TASSERT(6968504); } + struct SharedState { + SharedState(std::unique_ptr<ExternalRecordStore> recordStore, + std::unique_ptr<CollatorInterface> collator) + : _recordStore(std::move(recordStore)), _collator(std::move(collator)) {} + + ~SharedState() = default; + + std::unique_ptr<ExternalRecordStore> _recordStore; + + // This object is decorable and decorated with unversioned data related to the collection. + // Not associated with any particular Collection instance for the collection, but shared + // across all instances for the same collection. This is a vehicle for users of a collection + // to cache unversioned state for a collection that is accessible across all of the + // Collection instances. + SharedCollectionDecorations _sharedDecorations; + + // The default collation which is applied to operations and indices which have no collation + // of their own. The collection's validator will respect this collation. If null, the + // default collation is simple binary compare. + std::unique_ptr<CollatorInterface> _collator; + }; + NamespaceString _nss; CollectionOptions _options; - std::unique_ptr<ExternalRecordStore> _recordStore; - // The default collation which is applied to operations and indices which have no collation - // of their own. The collection's validator will respect this collation. If null, the - // default collation is simple binary compare. - std::unique_ptr<CollatorInterface> _collator; + std::shared_ptr<SharedState> _shared; + clonable_ptr<IndexCatalog> _indexCatalog; }; } // namespace mongo diff --git a/src/mongo/db/catalog/virtual_collection_options.h b/src/mongo/db/catalog/virtual_collection_options.h index 433976ab3e5..cd8a3a01936 100644 --- a/src/mongo/db/catalog/virtual_collection_options.h +++ b/src/mongo/db/catalog/virtual_collection_options.h @@ -29,46 +29,33 @@ #pragma once -#include <cstdint> -#include <fmt/format.h> #include <string> #include <vector> #include "mongo/base/string_data.h" -#include "mongo/db/pipeline/aggregate_command_gen.h" #include "mongo/db/pipeline/external_data_source_option_gen.h" #include "mongo/util/assert_util.h" namespace mongo { - /** * Metadata for external data source. 
*/ struct ExternalDataSourceMetadata { static constexpr auto kUrlProtocolFile = "file://"_sd; -#ifndef _WIN32 - static constexpr auto kDefaultFileUrlPrefix = "file:///tmp/"_sd; -#else - static constexpr auto kDefaultFileUrlPrefix = "file:////./pipe/"_sd; -#endif - ExternalDataSourceMetadata(const std::string& url, - StorageTypeEnum storageType, - FileTypeEnum fileType) - : storageType(storageType), fileType(fileType) { - using namespace fmt::literals; + ExternalDataSourceMetadata(StringData urlStr, + StorageTypeEnum storageTypeEnum, + FileTypeEnum fileTypeEnum) + : url(urlStr), storageType(storageTypeEnum), fileType(fileTypeEnum) { uassert(6968500, - "File url must start with {}"_format(kDefaultFileUrlPrefix), - url.find(kDefaultFileUrlPrefix.toString()) == 0); + "File url must start with {}"_format(kUrlProtocolFile), + urlStr.startsWith(kUrlProtocolFile)); uassert(6968501, "Storage type must be 'pipe'", storageType == StorageTypeEnum::pipe); uassert(6968502, "File type must be 'bson'", fileType == FileTypeEnum::bson); - - // Strip off the protocol prefix. - this->url = url.substr(kUrlProtocolFile.size()); } ExternalDataSourceMetadata(const ExternalDataSourceInfo& dataSourceInfo) - : ExternalDataSourceMetadata(dataSourceInfo.getUrl().toString(), + : ExternalDataSourceMetadata(dataSourceInfo.getUrl(), dataSourceInfo.getStorageType(), dataSourceInfo.getFileType()) {} diff --git a/src/mongo/db/change_stream_pre_images_collection_manager.cpp b/src/mongo/db/change_stream_pre_images_collection_manager.cpp index 103ce9a2fb1..7d606635708 100644 --- a/src/mongo/db/change_stream_pre_images_collection_manager.cpp +++ b/src/mongo/db/change_stream_pre_images_collection_manager.cpp @@ -77,6 +77,7 @@ boost::optional<std::int64_t> getExpireAfterSecondsFromChangeStreamOptions( // Returns pre-images expiry time in milliseconds since the epoch time if configured, boost::none // otherwise. boost::optional<Date_t> getPreImageExpirationTime(OperationContext* opCtx, Date_t currentTime) { + invariant(!change_stream_serverless_helpers::isChangeCollectionsModeActive()); boost::optional<std::int64_t> expireAfterSeconds = boost::none; // Get the expiration time directly from the change stream manager. @@ -140,7 +141,6 @@ void ChangeStreamPreImagesCollectionManager::insertPreImage(OperationContext* op << preImage.getId().getApplyOpsIndex(), preImage.getId().getApplyOpsIndex() >= 0); - // TODO SERVER-66642 Consider using internal test-tenant id if applicable. const auto preImagesCollectionNamespace = NamespaceString::makePreImageCollectionNSS( change_stream_serverless_helpers::resolveTenantId(tenantId)); @@ -231,49 +231,15 @@ boost::optional<UUID> findNextCollectionUUID(OperationContext* opCtx, * | applyIndex: 0 | | applyIndex: 0 | | applyIndex: 0 | | applyIndex: 1 | * +-------------------+ +-------------------+ +-------------------+ +-------------------+ */ -size_t deleteExpiredChangeStreamPreImages(OperationContext* opCtx, - Date_t currentTimeForTimeBasedExpiration) { - // Acquire intent-exclusive lock on the pre-images collection. Early exit if the collection - // doesn't exist. - // TODO SERVER-66642 Account for multitenancy. - AutoGetCollection autoColl( - opCtx, NamespaceString::makePreImageCollectionNSS(boost::none), MODE_IX); - const auto& preImagesColl = autoColl.getCollection(); - if (!preImagesColl) { - return 0; - } - - // Do not run the job on secondaries. 
- if (!repl::ReplicationCoordinator::get(opCtx)->canAcceptWritesForDatabase( - opCtx, NamespaceString::kConfigDb)) { - return 0; - } - - // Get the timestamp of the earliest oplog entry. - const auto currentEarliestOplogEntryTs = - repl::StorageInterface::get(opCtx->getServiceContext())->getEarliestOplogTimestamp(opCtx); - - const bool isBatchedRemoval = gBatchedExpiredChangeStreamPreImageRemoval.load(); +size_t _deleteExpiredChangeStreamPreImagesCommon(OperationContext* opCtx, + const CollectionPtr& preImageColl, + const MatchExpression* filterPtr, + Timestamp maxRecordIdTimestamp) { size_t numberOfRemovals = 0; - const auto preImageExpirationTime = change_stream_pre_image_helpers::getPreImageExpirationTime( - opCtx, currentTimeForTimeBasedExpiration); - - // Configure the filter for the case when expiration parameter is set. - OrMatchExpression filter; - const MatchExpression* filterPtr = nullptr; - if (preImageExpirationTime) { - filter.add( - std::make_unique<LTMatchExpression>("_id.ts"_sd, Value(currentEarliestOplogEntryTs))); - filter.add(std::make_unique<LTEMatchExpression>("operationTime"_sd, - Value(*preImageExpirationTime))); - filterPtr = &filter; - } - const bool shouldReturnEofOnFilterMismatch = preImageExpirationTime.has_value(); - - // TODO SERVER-66642 Account for multitenancy. + const bool isBatchedRemoval = gBatchedExpiredChangeStreamPreImageRemoval.load(); boost::optional<UUID> currentCollectionUUID = boost::none; while ((currentCollectionUUID = - findNextCollectionUUID(opCtx, &preImagesColl, currentCollectionUUID))) { + findNextCollectionUUID(opCtx, &preImageColl, currentCollectionUUID))) { writeConflictRetry( opCtx, "ChangeStreamExpiredPreImagesRemover", @@ -288,22 +254,14 @@ size_t deleteExpiredChangeStreamPreImages(OperationContext* opCtx, } RecordIdBound minRecordId( toRecordId(ChangeStreamPreImageId(*currentCollectionUUID, Timestamp(), 0))); - - // If the expiration parameter is set, the 'maxRecord' is set to the maximum - // RecordId for this collection. Whether the pre-image has to be deleted will be - // determined by the filtering MatchExpression. - // - // If the expiration parameter is not set, then the last expired pre-image timestamp - // equals to one increment before the 'currentEarliestOplogEntryTs'. - RecordIdBound maxRecordId = RecordIdBound(toRecordId(ChangeStreamPreImageId( - *currentCollectionUUID, - preImageExpirationTime ? Timestamp::max() - : Timestamp(currentEarliestOplogEntryTs.asULL() - 1), - std::numeric_limits<int64_t>::max()))); + RecordIdBound maxRecordId = RecordIdBound( + toRecordId(ChangeStreamPreImageId(*currentCollectionUUID, + maxRecordIdTimestamp, + std::numeric_limits<int64_t>::max()))); auto exec = InternalPlanner::deleteWithCollectionScan( opCtx, - &preImagesColl, + &preImageColl, std::move(params), PlanYieldPolicy::YieldPolicy::YIELD_AUTO, InternalPlanner::Direction::FORWARD, @@ -312,12 +270,83 @@ size_t deleteExpiredChangeStreamPreImages(OperationContext* opCtx, CollectionScanParams::ScanBoundInclusion::kIncludeBothStartAndEndRecords, std::move(batchedDeleteParams), filterPtr, - shouldReturnEofOnFilterMismatch); + filterPtr != nullptr); numberOfRemovals += exec->executeDelete(); }); } return numberOfRemovals; } + +size_t deleteExpiredChangeStreamPreImages(OperationContext* opCtx, + Date_t currentTimeForTimeBasedExpiration) { + // Acquire intent-exclusive lock on the change collection. 
+ AutoGetCollection preImageColl( + opCtx, NamespaceString::makePreImageCollectionNSS(boost::none), MODE_IX); + + // Early exit if the collection doesn't exist or running on a secondary. + if (!preImageColl || + !repl::ReplicationCoordinator::get(opCtx)->canAcceptWritesForDatabase( + opCtx, NamespaceString::kConfigDb)) { + return 0; + } + + // Get the timestamp of the earliest oplog entry. + const auto currentEarliestOplogEntryTs = + repl::StorageInterface::get(opCtx->getServiceContext())->getEarliestOplogTimestamp(opCtx); + + const auto preImageExpirationTime = change_stream_pre_image_helpers::getPreImageExpirationTime( + opCtx, currentTimeForTimeBasedExpiration); + + // Configure the filter for the case when expiration parameter is set. + if (preImageExpirationTime) { + OrMatchExpression filter; + filter.add( + std::make_unique<LTMatchExpression>("_id.ts"_sd, Value(currentEarliestOplogEntryTs))); + filter.add(std::make_unique<LTEMatchExpression>("operationTime"_sd, + Value(*preImageExpirationTime))); + // If 'preImageExpirationTime' is set, set 'maxRecordIdTimestamp' is set to the maximum + // RecordId for this collection. Whether the pre-image has to be deleted will be determined + // by the 'filter' parameter. + return _deleteExpiredChangeStreamPreImagesCommon( + opCtx, *preImageColl, &filter, Timestamp::max() /* maxRecordIdTimestamp */); + } + + // 'preImageExpirationTime' is not set, so the last expired pre-image timestamp is less than + // 'currentEarliestOplogEntryTs'. + return _deleteExpiredChangeStreamPreImagesCommon( + opCtx, + *preImageColl, + nullptr /* filterPtr */, + Timestamp(currentEarliestOplogEntryTs.asULL() - 1) /* maxRecordIdTimestamp */); +} + +size_t deleteExpiredChangeStreamPreImagesForTenants(OperationContext* opCtx, + const TenantId& tenantId, + Date_t currentTimeForTimeBasedExpiration) { + + // Acquire intent-exclusive lock on the change collection. + AutoGetCollection preImageColl(opCtx, + NamespaceString::makePreImageCollectionNSS( + change_stream_serverless_helpers::resolveTenantId(tenantId)), + MODE_IX); + + // Early exit if the collection doesn't exist or running on a secondary. + if (!preImageColl || + !repl::ReplicationCoordinator::get(opCtx)->canAcceptWritesForDatabase( + opCtx, NamespaceString::kConfigDb)) { + return 0; + } + + auto expiredAfterSeconds = change_stream_serverless_helpers::getExpireAfterSeconds(tenantId); + LTEMatchExpression filter{ + "operationTime"_sd, + Value(currentTimeForTimeBasedExpiration - Seconds(expiredAfterSeconds))}; + + // Set the 'maxRecordIdTimestamp' parameter (upper scan boundary) to maximum possible. Whether + // the pre-image has to be deleted will be determined by the 'filter' parameter. 
+ return _deleteExpiredChangeStreamPreImagesCommon( + opCtx, *preImageColl, &filter, Timestamp::max() /* maxRecordIdTimestamp */); +} } // namespace void ChangeStreamPreImagesCollectionManager::performExpiredChangeStreamPreImagesRemovalPass( @@ -342,9 +371,20 @@ void ChangeStreamPreImagesCollectionManager::performExpiredChangeStreamPreImages ServiceContext::UniqueOperationContext opCtx; try { opCtx = client->makeOperationContext(); + size_t numberOfRemovals = 0; + + if (change_stream_serverless_helpers::isChangeCollectionsModeActive()) { + const auto tenantIds = + change_stream_serverless_helpers::getConfigDbTenants(opCtx.get()); + for (const auto& tenantId : tenantIds) { + numberOfRemovals += deleteExpiredChangeStreamPreImagesForTenants( + opCtx.get(), tenantId, currentTimeForTimeBasedExpiration); + } + } else { + numberOfRemovals = + deleteExpiredChangeStreamPreImages(opCtx.get(), currentTimeForTimeBasedExpiration); + } - auto numberOfRemovals = - deleteExpiredChangeStreamPreImages(opCtx.get(), currentTimeForTimeBasedExpiration); if (numberOfRemovals > 0) { LOGV2_DEBUG(5869104, 3, diff --git a/src/mongo/db/clientcursor.cpp b/src/mongo/db/clientcursor.cpp index 2a471eb1288..c9994c1ea7f 100644 --- a/src/mongo/db/clientcursor.cpp +++ b/src/mongo/db/clientcursor.cpp @@ -40,6 +40,7 @@ #include "mongo/db/auth/privilege.h" #include "mongo/db/client.h" #include "mongo/db/commands.h" +#include "mongo/db/commands/external_data_source_scope_guard.h" #include "mongo/db/commands/server_status.h" #include "mongo/db/commands/server_status_metric.h" #include "mongo/db/curop.h" @@ -65,6 +66,10 @@ static CounterMetric cursorStatsTimedOut{"cursor.timedOut"}; static CounterMetric cursorStatsTotalOpened{"cursor.totalOpened"}; static CounterMetric cursorStatsMoreThanOneBatch{"cursor.moreThanOneBatch"}; +const ClientCursor::Decoration<std::shared_ptr<ExternalDataSourceScopeGuard>> + ExternalDataSourceScopeGuard::get = + ClientCursor::declareDecoration<std::shared_ptr<ExternalDataSourceScopeGuard>>(); + ClientCursor::ClientCursor(ClientCursorParams params, CursorId cursorId, OperationContext* operationUsingCursor, @@ -133,6 +138,9 @@ void ClientCursor::dispose(OperationContext* opCtx) { } _exec->dispose(opCtx); + // Update opCtx of the decorated ExternalDataSourceScopeGuard object so that it can drop virtual + // collections in the new 'opCtx'. + ExternalDataSourceScopeGuard::updateOperationContext(this, opCtx); _disposed = true; } diff --git a/src/mongo/db/clientcursor.h b/src/mongo/db/clientcursor.h index 89239e6230d..4e92e26d066 100644 --- a/src/mongo/db/clientcursor.h +++ b/src/mongo/db/clientcursor.h @@ -117,7 +117,7 @@ struct ClientCursorParams { * caller as "no timeout", it will be automatically destroyed by its cursor manager after a period * of inactivity. 
*/ -class ClientCursor { +class ClientCursor : public Decorable<ClientCursor> { ClientCursor(const ClientCursor&) = delete; ClientCursor& operator=(const ClientCursor&) = delete; diff --git a/src/mongo/db/commands/SConscript b/src/mongo/db/commands/SConscript index 0074dc859c6..19de67d69b6 100644 --- a/src/mongo/db/commands/SConscript +++ b/src/mongo/db/commands/SConscript @@ -367,6 +367,7 @@ env.Library( '$BUILD_DIR/mongo/db/repl/replica_set_messages', '$BUILD_DIR/mongo/db/repl/tenant_migration_access_blocker', '$BUILD_DIR/mongo/db/rw_concern_d', + '$BUILD_DIR/mongo/db/s/query_analysis_writer', '$BUILD_DIR/mongo/db/server_base', '$BUILD_DIR/mongo/db/server_feature_flags', '$BUILD_DIR/mongo/db/session/session_catalog_mongod', @@ -789,6 +790,7 @@ env.CppUnitTest( target="db_commands_test", source=[ "create_indexes_test.cpp", + "external_data_source_commands_test.cpp", "index_filter_commands_test.cpp", "fle_compact_test.cpp", "list_collections_filter_test.cpp", @@ -819,7 +821,9 @@ env.CppUnitTest( "$BUILD_DIR/mongo/db/repl/replmocks", "$BUILD_DIR/mongo/db/repl/storage_interface_impl", "$BUILD_DIR/mongo/db/service_context_d_test_fixture", + "$BUILD_DIR/mongo/db/storage/record_store_base", '$BUILD_DIR/mongo/idl/idl_parser', + '$BUILD_DIR/mongo/util/version_impl', "cluster_server_parameter_commands_invocation", "core", "create_command", diff --git a/src/mongo/db/commands/count_cmd.cpp b/src/mongo/db/commands/count_cmd.cpp index 21ec11df16f..9a961c1ed0b 100644 --- a/src/mongo/db/commands/count_cmd.cpp +++ b/src/mongo/db/commands/count_cmd.cpp @@ -44,6 +44,7 @@ #include "mongo/db/query/plan_summary_stats.h" #include "mongo/db/query/view_response_formatter.h" #include "mongo/db/s/collection_sharding_state.h" +#include "mongo/db/s/query_analysis_writer.h" #include "mongo/logv2/log.h" #include "mongo/util/database_name_util.h" @@ -249,6 +250,15 @@ public: invocation->markMirrored(); } + if (analyze_shard_key::supportsPersistingSampledQueries() && request.getSampleId()) { + analyze_shard_key::QueryAnalysisWriter::get(opCtx) + .addCountQuery(*request.getSampleId(), + nss, + request.getQuery(), + request.getCollation().value_or(BSONObj())) + .getAsync([](auto) {}); + } + if (ctx->getView()) { auto viewAggregation = countCommandAsAggregationCommand(request, nss); diff --git a/src/mongo/db/commands/dbcheck.cpp b/src/mongo/db/commands/dbcheck.cpp index 1b1110abb29..f8aafcae484 100644 --- a/src/mongo/db/commands/dbcheck.cpp +++ b/src/mongo/db/commands/dbcheck.cpp @@ -65,9 +65,8 @@ repl::OpTime _logOp(OperationContext* opCtx, repl::MutableOplogEntry oplogEntry; oplogEntry.setOpType(repl::OpTypeEnum::kCommand); oplogEntry.setNss(nss); - if (uuid) { - oplogEntry.setUuid(*uuid); - } + oplogEntry.setTid(nss.tenantId()); + oplogEntry.setUuid(uuid); oplogEntry.setObject(obj); AutoGetOplog oplogWrite(opCtx, OplogAccessMode::kWrite); return writeConflictRetry( @@ -144,7 +143,6 @@ struct DbCheckCollectionInfo { int64_t maxDocsPerBatch; int64_t maxBytesPerBatch; int64_t maxBatchTimeMillis; - bool snapshotRead; WriteConcernOptions writeConcern; }; @@ -167,6 +165,8 @@ std::unique_ptr<DbCheckRun> singleCollectionRun(OperationContext* opCtx, "Cannot run dbCheck on " + nss.toString() + " because it is not replicated", nss.isReplicated()); + uassert(6769500, "dbCheck no longer supports snapshotRead:false", invocation.getSnapshotRead()); + const auto start = invocation.getMinKey(); const auto end = invocation.getMaxKey(); const auto maxCount = invocation.getMaxCount(); @@ -184,7 +184,6 @@ std::unique_ptr<DbCheckRun> 
singleCollectionRun(OperationContext* opCtx, maxDocsPerBatch, maxBytesPerBatch, maxBatchTimeMillis, - invocation.getSnapshotRead(), invocation.getBatchWriteConcern()}; auto result = std::make_unique<DbCheckRun>(); result->push_back(info); @@ -201,6 +200,8 @@ std::unique_ptr<DbCheckRun> fullDatabaseRun(OperationContext* opCtx, AutoGetDb agd(opCtx, dbName, MODE_IS); uassert(ErrorCodes::NamespaceNotFound, "Database " + dbName.db() + " not found", agd.getDb()); + uassert(6769501, "dbCheck no longer supports snapshotRead:false", invocation.getSnapshotRead()); + const int64_t max = std::numeric_limits<int64_t>::max(); const auto rate = invocation.getMaxCountPerSecond(); const auto maxDocsPerBatch = invocation.getMaxDocsPerBatch(); @@ -220,7 +221,6 @@ std::unique_ptr<DbCheckRun> fullDatabaseRun(OperationContext* opCtx, maxDocsPerBatch, maxBytesPerBatch, maxBatchTimeMillis, - invocation.getSnapshotRead(), invocation.getBatchWriteConcern()}; result->push_back(info); return true; @@ -241,7 +241,8 @@ std::unique_ptr<DbCheckRun> getRun(OperationContext* opCtx, // Get rid of generic command fields. for (const auto& elem : obj) { - if (!isGenericArgument(elem.fieldNameStringData())) { + const auto& fieldName = elem.fieldNameStringData(); + if (!isGenericArgument(fieldName)) { builder.append(elem); } } @@ -251,11 +252,17 @@ std::unique_ptr<DbCheckRun> getRun(OperationContext* opCtx, // If the dbCheck argument is a string, this is the per-collection form. if (toParse["dbCheck"].type() == BSONType::String) { return singleCollectionRun( - opCtx, dbName, DbCheckSingleInvocation::parse(IDLParserContext(""), toParse)); + opCtx, + dbName, + DbCheckSingleInvocation::parse( + IDLParserContext("", false /*apiStrict*/, dbName.tenantId()), toParse)); } else { // Otherwise, it's the database-wide form. return fullDatabaseRun( - opCtx, dbName, DbCheckAllInvocation::parse(IDLParserContext(""), toParse)); + opCtx, + dbName, + DbCheckAllInvocation::parse( + IDLParserContext("", false /*apiStrict*/, dbName.tenantId()), toParse)); } } @@ -486,122 +493,77 @@ private: const BSONKey& first, int64_t batchDocs, int64_t batchBytes) { - auto lockMode = MODE_S; - if (info.snapshotRead) { - // Each batch will read at the latest no-overlap point, which is the all_durable - // timestamp on primaries. We assume that the history window on secondaries is always - // longer than the time it takes between starting and replicating a batch on the - // primary. Otherwise, the readTimestamp will not be available on a secondary by the - // time it processes the oplog entry. - lockMode = MODE_IS; - opCtx->recoveryUnit()->setTimestampReadSource(RecoveryUnit::ReadSource::kNoOverlap); + // Each batch will read at the latest no-overlap point, which is the all_durable timestamp + // on primaries. We assume that the history window on secondaries is always longer than the + // time it takes between starting and replicating a batch on the primary. Otherwise, the + // readTimestamp will not be available on a secondary by the time it processes the oplog + // entry. + opCtx->recoveryUnit()->setTimestampReadSource(RecoveryUnit::ReadSource::kNoOverlap); + + // dbCheck writes to the oplog, so we need to take an IX lock. We don't need to write to the + // collection, however, so we only take an intent lock on it. 
+ Lock::GlobalLock glob(opCtx, MODE_IX); + AutoGetCollection collection(opCtx, info.nss, MODE_IS); + + if (_stepdownHasOccurred(opCtx, info.nss)) { + _done = true; + return Status(ErrorCodes::PrimarySteppedDown, "dbCheck terminated due to stepdown"); } - BatchStats result; - auto timeoutMs = Milliseconds(gDbCheckCollectionTryLockTimeoutMillis.load()); - const auto initialBackoffMs = - Milliseconds(gDbCheckCollectionTryLockMinBackoffMillis.load()); - auto backoffMs = initialBackoffMs; - for (int attempt = 1;; attempt++) { - try { - // Try to acquire collection lock with increasing timeout and bounded exponential - // backoff. - auto const lockDeadline = Date_t::now() + timeoutMs; - timeoutMs *= 2; - - AutoGetCollection agc( - opCtx, info.nss, lockMode, AutoGetCollection::Options{}.deadline(lockDeadline)); - - if (_stepdownHasOccurred(opCtx, info.nss)) { - _done = true; - return Status(ErrorCodes::PrimarySteppedDown, - "dbCheck terminated due to stepdown"); - } + if (!collection) { + const auto msg = "Collection under dbCheck no longer exists"; + return {ErrorCodes::NamespaceNotFound, msg}; + } - const auto& collection = - CollectionCatalog::get(opCtx)->lookupCollectionByNamespace(opCtx, info.nss); - if (!collection) { - const auto msg = "Collection under dbCheck no longer exists"; - return {ErrorCodes::NamespaceNotFound, msg}; - } + auto readTimestamp = opCtx->recoveryUnit()->getPointInTimeReadTimestamp(opCtx); + uassert(ErrorCodes::SnapshotUnavailable, + "No snapshot available yet for dbCheck", + readTimestamp); + auto minVisible = collection->getMinimumVisibleSnapshot(); + if (minVisible && *readTimestamp < *collection->getMinimumVisibleSnapshot()) { + return {ErrorCodes::SnapshotUnavailable, + str::stream() << "Unable to read from collection " << info.nss + << " due to pending catalog changes"}; + } - auto readTimestamp = opCtx->recoveryUnit()->getPointInTimeReadTimestamp(opCtx); - auto minVisible = collection->getMinimumVisibleSnapshot(); - if (readTimestamp && minVisible && - *readTimestamp < *collection->getMinimumVisibleSnapshot()) { - return {ErrorCodes::SnapshotUnavailable, - str::stream() << "Unable to read from collection " << info.nss - << " due to pending catalog changes"}; - } + boost::optional<DbCheckHasher> hasher; + try { + hasher.emplace(opCtx, + *collection, + first, + info.end, + std::min(batchDocs, info.maxCount), + std::min(batchBytes, info.maxSize)); + } catch (const DBException& e) { + return e.toStatus(); + } - boost::optional<DbCheckHasher> hasher; - try { - hasher.emplace(opCtx, - collection, - first, - info.end, - std::min(batchDocs, info.maxCount), - std::min(batchBytes, info.maxSize)); - } catch (const DBException& e) { - return e.toStatus(); - } + const auto batchDeadline = Date_t::now() + Milliseconds(info.maxBatchTimeMillis); + Status status = hasher->hashAll(opCtx, batchDeadline); - const auto batchDeadline = Date_t::now() + Milliseconds(info.maxBatchTimeMillis); - Status status = hasher->hashAll(opCtx, batchDeadline); + if (!status.isOK()) { + return status; + } - if (!status.isOK()) { - return status; - } + std::string md5 = hasher->total(); - std::string md5 = hasher->total(); - - DbCheckOplogBatch batch; - batch.setType(OplogEntriesEnum::Batch); - batch.setNss(info.nss); - batch.setMd5(md5); - batch.setMinKey(first); - batch.setMaxKey(BSONKey(hasher->lastKey())); - batch.setReadTimestamp(readTimestamp); - - // Send information on this batch over the oplog. 
- result.time = _logOp(opCtx, info.nss, collection->uuid(), batch.toBSON()); - result.readTimestamp = readTimestamp; - - result.nDocs = hasher->docsSeen(); - result.nBytes = hasher->bytesSeen(); - result.lastKey = hasher->lastKey(); - result.md5 = md5; - - break; - } catch (const ExceptionFor<ErrorCodes::LockTimeout>& e) { - if (attempt > gDbCheckCollectionTryLockMaxAttempts.load()) { - return StatusWith<BatchStats>(e.code(), - "Unable to acquire the collection lock"); - } + DbCheckOplogBatch batch; + batch.setType(OplogEntriesEnum::Batch); + batch.setNss(info.nss); + batch.setMd5(md5); + batch.setMinKey(first); + batch.setMaxKey(BSONKey(hasher->lastKey())); + batch.setReadTimestamp(readTimestamp); - // Bounded exponential backoff between tryLocks. - opCtx->sleepFor(backoffMs); - const auto maxBackoffMillis = - Milliseconds(gDbCheckCollectionTryLockMaxBackoffMillis.load()); - if (backoffMs < maxBackoffMillis) { - auto backoff = durationCount<Milliseconds>(backoffMs); - auto initialBackoff = durationCount<Milliseconds>(initialBackoffMs); - backoff *= initialBackoff; - backoffMs = Milliseconds(backoff); - } - if (backoffMs > maxBackoffMillis) { - backoffMs = maxBackoffMillis; - } - LOGV2_DEBUG(6175700, - 1, - "Could not acquire collection lock, retrying", - "ns"_attr = info.nss.ns(), - "batchRangeMin"_attr = info.start.obj(), - "batchRangeMax"_attr = info.end.obj(), - "attempt"_attr = attempt, - "backoff"_attr = backoffMs); - } - } + // Send information on this batch over the oplog. + BatchStats result; + result.time = _logOp(opCtx, info.nss, collection->uuid(), batch.toBSON()); + result.readTimestamp = readTimestamp; + + result.nDocs = hasher->docsSeen(); + result.nBytes = hasher->bytesSeen(); + result.lastKey = hasher->lastKey(); + result.md5 = md5; return result; } @@ -663,7 +625,6 @@ public: " maxDocsPerBatch: <max number of docs/batch>\n" " maxBytesPerBatch: <try to keep a batch within max bytes/batch>\n" " maxBatchTimeMillis: <max time processing a batch in milliseconds>\n" - " readTimestamp: <bool, read at a timestamp without strong locks> }\n" "to check a collection.\n" "Invoke with {dbCheck: 1} to check all collections in the database."; } diff --git a/src/mongo/db/commands/dbcommands.cpp b/src/mongo/db/commands/dbcommands.cpp index 6374a093936..0b6f90d0203 100644 --- a/src/mongo/db/commands/dbcommands.cpp +++ b/src/mongo/db/commands/dbcommands.cpp @@ -485,6 +485,10 @@ public: return Request::kCommandDescription.toString(); } + bool allowedWithSecurityToken() const final { + return true; + } + // Assume that appendCollectionStorageStats() gives us a valid response. 
void validateResult(const BSONObj& resultObj) final {} @@ -587,6 +591,10 @@ public: CmdDbStats() : TypedCommand(Request::kCommandName, Request::kCommandAlias) {} + bool allowedWithSecurityToken() const final { + return true; + } + class Invocation final : public InvocationBase { public: using InvocationBase::InvocationBase; diff --git a/src/mongo/db/commands/dbcommands_d.cpp b/src/mongo/db/commands/dbcommands_d.cpp index 5f42cb45010..5ad0adf46c5 100644 --- a/src/mongo/db/commands/dbcommands_d.cpp +++ b/src/mongo/db/commands/dbcommands_d.cpp @@ -248,6 +248,10 @@ public: return Status::OK(); } + bool allowedWithSecurityToken() const final { + return true; + } + bool run(OperationContext* opCtx, const DatabaseName& dbName, const BSONObj& jsobj, diff --git a/src/mongo/db/commands/distinct.cpp b/src/mongo/db/commands/distinct.cpp index 070da8d285c..eb599424001 100644 --- a/src/mongo/db/commands/distinct.cpp +++ b/src/mongo/db/commands/distinct.cpp @@ -57,6 +57,7 @@ #include "mongo/db/query/query_planner_common.h" #include "mongo/db/query/view_response_formatter.h" #include "mongo/db/s/collection_sharding_state.h" +#include "mongo/db/s/query_analysis_writer.h" #include "mongo/db/views/resolved_view.h" #include "mongo/logv2/log.h" #include "mongo/util/database_name_util.h" @@ -244,6 +245,16 @@ public: invocation->markMirrored(); } + if (analyze_shard_key::supportsPersistingSampledQueries() && parsedDistinct.getSampleId()) { + auto cq = parsedDistinct.getQuery(); + analyze_shard_key::QueryAnalysisWriter::get(opCtx) + .addDistinctQuery(*parsedDistinct.getSampleId(), + nss, + cq->getQueryObj(), + cq->getFindCommandRequest().getCollation()) + .getAsync([](auto) {}); + } + if (ctx->getView()) { // Relinquish locks. The aggregation command will re-acquire them. ctx.reset(); diff --git a/src/mongo/db/commands/external_data_source_commands_test.cpp b/src/mongo/db/commands/external_data_source_commands_test.cpp new file mode 100644 index 00000000000..113495efdc8 --- /dev/null +++ b/src/mongo/db/commands/external_data_source_commands_test.cpp @@ -0,0 +1,953 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include <fmt/format.h> + +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/client/dbclient_cursor.h" +#include "mongo/db/dbdirectclient.h" +#include "mongo/db/pipeline/aggregation_request_helper.h" +#include "mongo/db/query/query_knobs_gen.h" +#include "mongo/db/repl/replication_coordinator_mock.h" +#include "mongo/db/service_context_d_test_fixture.h" +#include "mongo/db/storage/named_pipe.h" +#include "mongo/util/assert_util.h" +#include "mongo/util/scopeguard.h" + +namespace mongo { +namespace { +using namespace fmt::literals; + +class PipeWaiter { +public: + void notify() { + { + stdx::unique_lock lk(m); + pipeCreated = true; + } + cv.notify_one(); + } + + void wait() { + stdx::unique_lock lk(m); + cv.wait(lk, [&] { return pipeCreated; }); + } + +private: + Mutex m; + stdx::condition_variable cv; + bool pipeCreated = false; +}; + +class ExternalDataSourceCommandsTest : public ServiceContextMongoDTest { +protected: + void setUp() override { + ServiceContextMongoDTest::setUp(); + + std::srand(std::time(0)); + + const auto service = getServiceContext(); + auto replCoord = + std::make_unique<repl::ReplicationCoordinatorMock>(service, repl::ReplSettings{}); + ASSERT_OK(replCoord->setFollowerMode(repl::MemberState::RS_PRIMARY)); + repl::ReplicationCoordinator::set(service, std::move(replCoord)); + repl::createOplog(_opCtx); + + computeModeEnabled = true; + } + + void tearDown() override { + computeModeEnabled = false; + ServiceContextMongoDTest::tearDown(); + } + + std::vector<BSONObj> generateRandomSimpleDocs(int count) { + std::vector<BSONObj> docs; + for (int i = 0; i < count; ++i) { + docs.emplace_back(BSON("a" << std::rand() % 10)); + } + + return docs; + } + + // Generates a large readable random string to aid debugging. + std::string getRandomReadableLargeString() { + int count = std::rand() % 100 + 2024; + std::string str(count, '\0'); + for (int i = 0; i < count; ++i) { + str[i] = static_cast<char>(std::rand() % 26) + 'a'; + } + + return str; + } + + std::vector<BSONObj> generateRandomLargeDocs(int count) { + std::vector<BSONObj> docs; + for (int i = 0; i < count; ++i) { + docs.emplace_back(BSON("a" << getRandomReadableLargeString())); + } + + return docs; + } + + // This verifies that a simple explain aggregate command works. Virtual collections are created + // even for explain aggregate command. + void verifyExplainAggCommand(DBDirectClient& client, const BSONObj& explainAggCmdObj) { + // The first request. + BSONObj res; + ASSERT_TRUE(client.runCommand(kDatabaseName, explainAggCmdObj.getOwned(), res)) + << "Expected to succeed but failed. result = {}"_format(res.toString()); + // Sanity checks of result. + ASSERT_EQ(res["ok"].Number(), 1.0) + << "Expected to succeed but failed. 
result = {}"_format(res.toString()); + } + + ServiceContext::UniqueOperationContext _uniqueOpCtx{makeOperationContext()}; + OperationContext* _opCtx{_uniqueOpCtx.get()}; + + BSONObj explainSingleNamedPipeAggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [], + explain: true, + $_externalDataSources: [{ + collName: "coll", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + BSONObj explainMultipleNamedPipesAggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [], + explain: true, + $_externalDataSources: [{ + collName: "coll", + dataSources: [ + {url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}, + {url: "file://named_pipe2", storageType: "pipe", fileType: "bson"} + ] + }] +} + )"); + + static constexpr auto kDatabaseName = "external_data_source"; +}; + +TEST_F(ExternalDataSourceCommandsTest, SimpleScanAggRequest) { + const auto nDocs = std::rand() % 100 + 1; + std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs); + PipeWaiter pw; + + stdx::thread producer([&] { + NamedPipeOutput pipeWriter("named_pipe1"); + pw.notify(); + pipeWriter.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. + pw.wait(); + + DBDirectClient client(_opCtx); + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [], + cursor: {}, + $_externalDataSources: [{ + collName: "coll", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + + // The first request. + BSONObj res; + ASSERT_TRUE(client.runCommand(kDatabaseName, aggCmdObj.getOwned(), res)); + + // Sanity checks of result. + ASSERT_EQ(res["ok"].Number(), 1.0); + ASSERT_TRUE(res.hasField("cursor") && res["cursor"].Obj().hasField("firstBatch")); + + // The default batch size is 101 and so all data must be contained in the first batch. cursor.id + // == 0 means that no cursor is necessary. + ASSERT_TRUE(res["cursor"].Obj().hasField("id") && res["cursor"]["id"].Long() == 0); + auto resDocs = res["cursor"]["firstBatch"].Array(); + ASSERT_EQ(resDocs.size(), nDocs); + for (int i = 0; i < nDocs; ++i) { + ASSERT_BSONOBJ_EQ(resDocs[i].Obj(), srcDocs[i]); + } + + // The second request. This verifies that virtual collections are cleaned up after the + // aggregation request is done. + verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj); +} + +TEST_F(ExternalDataSourceCommandsTest, SimpleScanOverMultipleNamedPipesAggRequest) { + // This data set fits into the first batch. + const auto nDocs = std::rand() % 50; + std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs); + PipeWaiter pw; + + // Pushes data into multiple named pipes. We can't push data into multiple named pipes + // simultaneously because writers will be blocked until the reader consumes data. So, we push + // data into one named pipe after another. 
+ stdx::thread producer([&] { + NamedPipeOutput pipeWriter1("named_pipe1"); + NamedPipeOutput pipeWriter2("named_pipe2"); + pw.notify(); + pipeWriter1.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter1.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter1.close(); + + pipeWriter2.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter2.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter2.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. + pw.wait(); + + DBDirectClient client(_opCtx); + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [], + cursor: {}, + $_externalDataSources: [{ + collName: "coll", + dataSources: [ + {url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}, + {url: "file://named_pipe2", storageType: "pipe", fileType: "bson"} + ] + }] +} + )"); + + // The first request. + BSONObj res; + ASSERT_TRUE(client.runCommand(kDatabaseName, aggCmdObj.getOwned(), res)); + + // Sanity checks of result. + ASSERT_EQ(res["ok"].Number(), 1.0); + ASSERT_TRUE(res.hasField("cursor") && res["cursor"].Obj().hasField("firstBatch")); + + // The default batch size is 101 and so all data must be contained in the first batch. cursor.id + // == 0 means that no cursor is necessary. + ASSERT_TRUE(res["cursor"].Obj().hasField("id") && res["cursor"]["id"].Long() == 0); + auto resDocs = res["cursor"]["firstBatch"].Array(); + ASSERT_EQ(resDocs.size(), nDocs * 2); + for (int i = 0; i < nDocs; ++i) { + ASSERT_BSONOBJ_EQ(resDocs[i].Obj(), srcDocs[i % nDocs]); + } + + // The second request. This verifies that virtual collections are cleaned up after the + // aggregation request is done. + verifyExplainAggCommand(client, explainMultipleNamedPipesAggCmdObj); +} + +TEST_F(ExternalDataSourceCommandsTest, SimpleScanOverLargeObjectsAggRequest) { + // MultiBsonStreamCursor's default buffer size is 8K and 2K (at minimum) * 20 would be enough to + // exceed the initial read. This data set is highly likely to span multiple reads. + const auto nDocs = std::rand() % 80 + 20; + std::vector<BSONObj> srcDocs = generateRandomLargeDocs(nDocs); + PipeWaiter pw; + + stdx::thread producer([&] { + NamedPipeOutput pipeWriter("named_pipe1"); + pw.notify(); + pipeWriter.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. + pw.wait(); + + DBDirectClient client(_opCtx); + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [], + cursor: {}, + $_externalDataSources: [{ + collName: "coll", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + + // The first request. + BSONObj res; + ASSERT_TRUE(client.runCommand(kDatabaseName, aggCmdObj.getOwned(), res)); + + // Sanity checks of result. + ASSERT_EQ(res["ok"].Number(), 1.0); + ASSERT_TRUE(res.hasField("cursor") && res["cursor"].Obj().hasField("firstBatch")); + + // The default batch size is 101 and so all data must be contained in the first batch. cursor.id + // == 0 means that no cursor is necessary. + ASSERT_TRUE(res["cursor"].Obj().hasField("id") && res["cursor"]["id"].Long() == 0); + auto resDocs = res["cursor"]["firstBatch"].Array(); + ASSERT_EQ(resDocs.size(), nDocs); + for (int i = 0; i < nDocs; ++i) { + ASSERT_BSONOBJ_EQ(resDocs[i].Obj(), srcDocs[i]); + } + + // The second request. 
This verifies that virtual collections are cleaned up after the + // aggregation request is done. + verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj); +} + +// Tests that 'explain' flag works and also tests that the same aggregation request works with the +// same $_externalDataSources again to see whether there are no remaining virtual collections left +// behind after the aggregation request is done. +TEST_F(ExternalDataSourceCommandsTest, ExplainAggRequest) { + DBDirectClient client(_opCtx); + // The first request. + verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj); + + // The second request. + verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj); +} + +TEST_F(ExternalDataSourceCommandsTest, SimpleScanMultiBatchAggRequest) { + // This 'nDocs' causes a cursor to be created for a simple scan aggregate command. + const auto nDocs = std::rand() % 100 + 102; + std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs); + PipeWaiter pw; + + stdx::thread producer([&] { + NamedPipeOutput pipeWriter("named_pipe1"); + pw.notify(); + pipeWriter.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. + pw.wait(); + + DBDirectClient client(_opCtx); + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [], + cursor: {}, + $_externalDataSources: [{ + collName: "coll", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + + auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj); + ASSERT_OK(swAggReq.getStatus()); + auto swCursor = DBClientCursor::fromAggregationRequest( + &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false); + ASSERT_OK(swCursor.getStatus()); + + auto cursor = std::move(swCursor.getValue()); + int resCnt = 0; + // While iterating over the cursor, getMore() request(s) will be sent and the server-side cursor + // will be destroyed after all data is exhausted. + while (cursor->more()) { + auto doc = cursor->next(); + ASSERT_BSONOBJ_EQ(doc, srcDocs[resCnt]); + ++resCnt; + } + ASSERT_EQ(resCnt, nDocs); + + // The second explain request. This verifies that virtual collections are cleaned up after + // multi-batch result for an aggregation request. + verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj); +} + +TEST_F(ExternalDataSourceCommandsTest, SimpleMatchAggRequest) { + const auto nDocs = std::rand() % 100 + 1; + std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs); + // Expected results for {$match: {a: {$lt: 5}}}. + std::vector<BSONObj> expectedDocs; + std::for_each(srcDocs.begin(), srcDocs.end(), [&](const BSONObj& doc) { + if (doc["a"].Int() < 5) { + expectedDocs.emplace_back(doc); + } + }); + PipeWaiter pw; + + stdx::thread producer([&] { + NamedPipeOutput pipeWriter("named_pipe1"); + pw.notify(); + pipeWriter.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. 
+ pw.wait(); + + DBDirectClient client(_opCtx); + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [{$match: {a: {$lt: 5}}}], + cursor: {}, + $_externalDataSources: [{ + collName: "coll", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + + auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj); + ASSERT_OK(swAggReq.getStatus()); + auto swCursor = DBClientCursor::fromAggregationRequest( + &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false); + ASSERT_OK(swCursor.getStatus()); + + auto cursor = std::move(swCursor.getValue()); + int resCnt = 0; + while (cursor->more()) { + auto doc = cursor->next(); + ASSERT_BSONOBJ_EQ(doc, expectedDocs[resCnt]); + ++resCnt; + } + ASSERT_EQ(resCnt, expectedDocs.size()); + + // The second explain request. This verifies that virtual collections are cleaned up after + // the aggregation request is done. + verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj); +} + +TEST_F(ExternalDataSourceCommandsTest, ScanOverRandomInvalidDataAggRequest) { + const auto nDocs = std::rand() % 100 + 1; + std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs); + PipeWaiter pw; + + stdx::thread producer([&] { + NamedPipeOutput pipeWriter("named_pipe1"); + pw.notify(); + const size_t failPoint = std::rand() % nDocs; + pipeWriter.open(); + for (size_t i = 0; i < srcDocs.size(); ++i) { + if (i == failPoint) { + // Intentionally pushes invalid data at the fail point so that an error happens at + // the reader-side + pipeWriter.write(srcDocs[i].objdata(), srcDocs[i].objsize() / 2); + } else { + pipeWriter.write(srcDocs[i].objdata(), srcDocs[i].objsize()); + } + } + pipeWriter.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. + pw.wait(); + + DBDirectClient client(_opCtx); + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [{$match: {a: {$lt: 5}}}], + cursor: {}, + $_externalDataSources: [{ + collName: "coll", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + + BSONObj res; + ASSERT_FALSE(client.runCommand(kDatabaseName, aggCmdObj.getOwned(), res)); + ASSERT_EQ(res["ok"].Number(), 0.0); + // The fail point is randomly chosen and different error codes are expected, depending on the + // chosen fail point. + ASSERT_NE(ErrorCodes::Error(res["code"].Int()), ErrorCodes::OK); + + // The second explain request. This verifies that virtual collections are cleaned up after + // the aggregation request fails. + verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj); +} + +TEST_F(ExternalDataSourceCommandsTest, ScanOverRandomInvalidDataAtSecondBatchAggRequest) { + // This 'nDocs' causes a cursor to be created for a simple scan aggregate command. + const auto nDocs = std::rand() % 100 + 102; // 201 >= nDocs >= 102 + std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs); + PipeWaiter pw; + + stdx::thread producer([&] { + NamedPipeOutput pipeWriter("named_pipe1"); + pw.notify(); + // The fail point occurs at the second batch. 
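The invalid-data injection in these tests relies on the BSON wire layout: a document begins with a little-endian int32 holding its total size, so writing only half of the bytes leaves the reader with a length prefix it cannot satisfy, and the batch fails. A tiny standalone illustration of that check, with {a: 1} hand-encoded; it assumes a little-endian host and is not the server's actual validation code:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Returns true if the buffer holds at least as many bytes as its BSON length
// prefix declares; a truncated write fails this check on the reader side.
bool looksComplete(const std::vector<char>& bytes) {
    if (bytes.size() < sizeof(int32_t)) {
        return false;
    }
    int32_t declaredSize = 0;
    std::memcpy(&declaredSize, bytes.data(), sizeof(declaredSize));
    return declaredSize > 0 && static_cast<std::size_t>(declaredSize) <= bytes.size();
}

int main() {
    // {a: 1} as BSON: size=12, element type 0x10 (int32), name "a", value 1, terminator.
    std::vector<char> full = {12, 0, 0, 0, 0x10, 'a', 0, 1, 0, 0, 0, 0};
    std::vector<char> truncated(full.begin(), full.begin() + full.size() / 2);
    std::cout << looksComplete(full) << " " << looksComplete(truncated) << "\n";  // prints "1 0"
    return 0;
}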
+ const size_t failPoint = 101 + std::rand() % (nDocs - 101); // 200 >= failPoint >= 101 + pipeWriter.open(); + for (size_t i = 0; i < srcDocs.size(); ++i) { + if (i == failPoint) { + // Intentionally pushes invalid data at the fail point so that an error happens at + // the reader-side + pipeWriter.write(srcDocs[i].objdata(), srcDocs[i].objsize() / 2); + } else { + pipeWriter.write(srcDocs[i].objdata(), srcDocs[i].objsize()); + } + } + pipeWriter.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. + pw.wait(); + + DBDirectClient client(_opCtx); + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [], + cursor: {}, + $_externalDataSources: [{ + collName: "coll", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + + auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj); + ASSERT_OK(swAggReq.getStatus()); + auto swCursor = DBClientCursor::fromAggregationRequest( + &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false); + ASSERT_OK(swCursor.getStatus()); + + auto cursor = std::move(swCursor.getValue()); + int resCnt = 0; + bool errorOccurred = false; + try { + while (cursor->more()) { + auto doc = cursor->next(); + ASSERT_BSONOBJ_EQ(doc, srcDocs[resCnt]); + ++resCnt; + } + } catch (const DBException& ex) { + errorOccurred = true; + ASSERT_NE(ex.code(), ErrorCodes::OK); + } + ASSERT_TRUE(errorOccurred); + + // The second explain request. This verifies that virtual collections are cleaned up after + // the getMore request for the aggregation results fails. + verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj); +} + +TEST_F(ExternalDataSourceCommandsTest, KillCursorAfterAggRequest) { + // This 'nDocs' causes a cursor to be created for a simple scan aggregate command. + const auto nDocs = std::rand() % 100 + 102; + std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs); + PipeWaiter pw; + + stdx::thread producer([&] { + NamedPipeOutput pipeWriter("named_pipe1"); + pw.notify(); + pipeWriter.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. + pw.wait(); + + DBDirectClient client(_opCtx); + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [], + cursor: {}, + $_externalDataSources: [{ + collName: "coll", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + + // The first request. + BSONObj res; + ASSERT_TRUE(client.runCommand(kDatabaseName, aggCmdObj.getOwned(), res)); + + // Sanity checks of result. + ASSERT_EQ(res["ok"].Number(), 1.0); + ASSERT_TRUE(res.hasField("cursor") && res["cursor"].Obj().hasField("firstBatch")); + + // The default batch size is 101 and results can be returned through multiple batches. cursor.id + // != 0 means that a cursor is created. + auto cursorId = res["cursor"]["id"].Long(); + ASSERT_TRUE(res["cursor"].Obj().hasField("id") && cursorId != 0); + + // Kills the cursor. 
+ auto killCursorCmdObj = BSON("killCursors" + << "coll" + << "cursors" << BSON_ARRAY(cursorId)); + ASSERT_TRUE(client.runCommand(kDatabaseName, killCursorCmdObj.getOwned(), res)); + ASSERT_EQ(res["ok"].Number(), 1.0); + auto cursorsKilled = res["cursorsKilled"].Array(); + ASSERT_TRUE(cursorsKilled.size() == 1 && cursorsKilled[0].Long() == cursorId); + + // The second explain request. This verifies that virtual collections are cleaned up after + // the cursor for the aggregate request is killed. + verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj); +} + +TEST_F(ExternalDataSourceCommandsTest, SimpleScanAndUnionWithMultipleSourcesAggRequest) { + const auto nDocs = std::rand() % 100 + 1; + std::vector<BSONObj> srcDocs = generateRandomSimpleDocs(nDocs); + PipeWaiter pw; + + stdx::thread producer([&] { + NamedPipeOutput pipeWriter1("named_pipe1"); + pw.notify(); + pipeWriter1.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter1.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter1.close(); + + NamedPipeOutput pipeWriter2("named_pipe2"); + pipeWriter2.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter2.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter2.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. + pw.wait(); + + DBDirectClient client(_opCtx); + // An aggregate request with a simple scan and $unionWith stage. $_externalDataSources option + // defines multiple data sources. + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll1", + pipeline: [{$unionWith: "coll2"}], + cursor: {}, + $_externalDataSources: [{ + collName: "coll1", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }, { + collName: "coll2", + dataSources: [{url: "file://named_pipe2", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + + auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj); + ASSERT_OK(swAggReq.getStatus()); + auto swCursor = DBClientCursor::fromAggregationRequest( + &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false); + ASSERT_OK(swCursor.getStatus()); + + auto cursor = std::move(swCursor.getValue()); + int resCnt = 0; + while (cursor->more()) { + auto doc = cursor->next(); + // Simple scan from 'coll1' first and then $unionWith from 'coll2'. + ASSERT_BSONOBJ_EQ(doc, srcDocs[resCnt % nDocs]); + ++resCnt; + } + ASSERT_EQ(resCnt, nDocs * 2); + + auto explainAggCmdObj = fromjson(R"( +{ + aggregate: "coll1", + pipeline: [{$unionWith: "coll2"}], + explain: true, + $_externalDataSources: [{ + collName: "coll1", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }, { + collName: "coll2", + dataSources: [{url: "file://named_pipe2", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + + // The second explain request. This verifies that virtual collections are cleaned up after + // the aggregation request is done. 
+ verifyExplainAggCommand(client, explainAggCmdObj); +} + +TEST_F(ExternalDataSourceCommandsTest, GroupAggRequest) { + std::vector<BSONObj> srcDocs = { + fromjson(R"( + { + "_id" : 1, + "item" : "a", + "quantity" : 2 + })"), + fromjson(R"( + { + "_id" : 2, + "item" : "b", + "quantity" : 1 + })"), + fromjson(R"( + { + "_id" : 3, + "item" : "a", + "quantity" : 5 + })"), + fromjson(R"( + { + "_id" : 4, + "item" : "b", + "quantity" : 10 + })"), + fromjson(R"( + { + "_id" : 5, + "item" : "c", + "quantity" : 10 + })"), + }; + PipeWaiter pw; + + stdx::thread producer([&] { + NamedPipeOutput pipeWriter("named_pipe1"); + pw.notify(); + pipeWriter.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. + pw.wait(); + + DBDirectClient client(_opCtx); + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll", + pipeline: [{$group: {_id: "$item", o: {$sum: "$quantity"}}}], + cursor: {}, + $_externalDataSources: [{ + collName: "coll", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + std::vector<BSONObj> expectedRes = { + fromjson(R"( + { + "_id" : "a", + "o" : 7 + })"), + fromjson(R"( + { + "_id" : "b", + "o" : 11 + })"), + fromjson(R"( + { + "_id" : "c", + "o" : 10 + })"), + }; + + auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj); + ASSERT_OK(swAggReq.getStatus()); + auto swCursor = DBClientCursor::fromAggregationRequest( + &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false); + ASSERT_OK(swCursor.getStatus()); + + auto cursor = std::move(swCursor.getValue()); + int resCnt = 0; + while (cursor->more()) { + auto doc = cursor->next(); + // Result set is pretty small and so we use linear search of vector. + ASSERT_TRUE( + std::find_if(expectedRes.begin(), expectedRes.end(), [&](const BSONObj& expectedObj) { + return expectedObj.objsize() == doc.objsize() && + std::memcmp(expectedObj.objdata(), doc.objdata(), expectedObj.objsize()) == 0; + }) != expectedRes.end()); + ++resCnt; + } + ASSERT_EQ(resCnt, expectedRes.size()); + + // The second explain request. This verifies that virtual collections are cleaned up after + // the aggregation request is done. + verifyExplainAggCommand(client, explainSingleNamedPipeAggCmdObj); +} + +TEST_F(ExternalDataSourceCommandsTest, LookupAggRequest) { + std::vector<BSONObj> srcDocs = { + fromjson(R"( + { + "a" : 1, + "data" : "abcd" + })"), + fromjson(R"( + { + "a" : 2, + "data" : "efgh" + })"), + fromjson(R"( + { + "a" : 3, + "data" : "ijkl" + })"), + }; + PipeWaiter pw; + + // For the $lookup stage, we need data to be available for both named pipes simultaneously + // because $lookup would read data from both collections and so we use two different named + // pipes and pushes data into the inner side first. To avoid racy condition, notify the reader + // side after both named pipes are created. This order is geared toward hash join algorithm. + stdx::thread producer([&] { + NamedPipeOutput pipeWriter2("named_pipe2"); + NamedPipeOutput pipeWriter1("named_pipe1"); + pw.notify(); + + // Pushes data into the inner side (== coll2 with named_pipe2) first because the hash join + // builds the inner (or build) side first. 
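The write order described in the comment above mirrors how a hash join executes: the inner (build) side must be consumed in full to build the hash table before the outer (probe) side is streamed past it, which is why coll2's pipe is filled first. A generic build-then-probe sketch in standard C++ (the textbook algorithm, not SBE's actual $lookup implementation):

#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Build phase: hash the inner side on the join key.
// Probe phase: stream the outer side and emit matching pairs.
std::vector<std::pair<std::string, std::string>> hashJoin(
    const std::vector<std::pair<int, std::string>>& inner,
    const std::vector<std::pair<int, std::string>>& outer) {
    std::unordered_multimap<int, std::string> buildTable;
    for (const auto& [key, value] : inner) {  // the inner input must be fully available first
        buildTable.emplace(key, value);
    }
    std::vector<std::pair<std::string, std::string>> joined;
    for (const auto& [key, value] : outer) {  // the outer input can then be streamed
        auto [lo, hi] = buildTable.equal_range(key);
        for (auto it = lo; it != hi; ++it) {
            joined.emplace_back(value, it->second);
        }
    }
    return joined;
}

int main() {
    auto rows = hashJoin({{1, "inner-a"}, {2, "inner-b"}}, {{1, "outer-a"}, {3, "outer-c"}});
    for (const auto& [left, right] : rows) {
        std::cout << left << " <-> " << right << "\n";  // prints: outer-a <-> inner-a
    }
    return 0;
}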
+ pipeWriter2.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter2.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter2.close(); + + pipeWriter1.open(); + for (auto&& srcDoc : srcDocs) { + pipeWriter1.write(srcDoc.objdata(), srcDoc.objsize()); + } + pipeWriter1.close(); + }); + ON_BLOCK_EXIT([&] { producer.join(); }); + + // Gives some time to the producer so that it can initialize a named pipe. + pw.wait(); + + DBDirectClient client(_opCtx); + auto aggCmdObj = fromjson(R"( +{ + aggregate: "coll1", + pipeline: [{$lookup: {from: "coll2", localField: "a", foreignField: "a", as: "o"}}], + cursor: {}, + $_externalDataSources: [{ + collName: "coll1", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }, { + collName: "coll2", + dataSources: [{url: "file://named_pipe2", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + std::vector<BSONObj> expectedRes = { + fromjson(R"( + { + "a" : 1, + "data" : "abcd", + "o" : [{"a": 1, "data": "abcd"}] + })"), + fromjson(R"( + { + "a" : 2, + "data" : "efgh", + "o" : [{"a": 2, "data": "efgh"}] + })"), + fromjson(R"( + { + "a" : 3, + "data" : "ijkl", + "o" : [{"a": 3, "data": "ijkl"}] + })"), + }; + + auto swAggReq = aggregation_request_helper::parseFromBSONForTests(kDatabaseName, aggCmdObj); + ASSERT_OK(swAggReq.getStatus()); + auto swCursor = DBClientCursor::fromAggregationRequest( + &client, swAggReq.getValue(), /*secondaryOk*/ false, /*useExhaust*/ false); + ASSERT_OK(swCursor.getStatus()); + + auto cursor = std::move(swCursor.getValue()); + int resCnt = 0; + while (cursor->more()) { + auto doc = cursor->next(); + // Result set is pretty small and so we use linear search of vector. + ASSERT_TRUE( + std::find_if(expectedRes.begin(), expectedRes.end(), [&](const BSONObj& expectedObj) { + return expectedObj.objsize() == doc.objsize() && + std::memcmp(expectedObj.objdata(), doc.objdata(), expectedObj.objsize()) == 0; + }) != expectedRes.end()); + ++resCnt; + } + ASSERT_EQ(resCnt, expectedRes.size()); + + auto explainAggCmdObj = fromjson(R"( +{ + aggregate: "coll1", + pipeline: [{$lookup: {from: "coll2", localField: "a", foreignField: "a", as: "o"}}], + explain: true, + $_externalDataSources: [{ + collName: "coll1", + dataSources: [{url: "file://named_pipe1", storageType: "pipe", fileType: "bson"}] + }, { + collName: "coll2", + dataSources: [{url: "file://named_pipe2", storageType: "pipe", fileType: "bson"}] + }] +} + )"); + + // The second explain request. This verifies that virtual collections are cleaned up after + // the aggregation request is done. + verifyExplainAggCommand(client, explainAggCmdObj); +} +} // namespace +} // namespace mongo diff --git a/src/mongo/db/commands/external_data_source_scope_guard.cpp b/src/mongo/db/commands/external_data_source_scope_guard.cpp new file mode 100644 index 00000000000..f9a29fcba67 --- /dev/null +++ b/src/mongo/db/commands/external_data_source_scope_guard.cpp @@ -0,0 +1,87 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. 
If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/commands/external_data_source_scope_guard.h" + +#include "mongo/db/catalog/create_collection.h" +#include "mongo/db/catalog/drop_collection.h" +#include "mongo/db/catalog/virtual_collection_options.h" +#include "mongo/db/drop_gen.h" +#include "mongo/db/namespace_string.h" +#include "mongo/db/pipeline/external_data_source_option_gen.h" +#include "mongo/logv2/log.h" +#include "mongo/util/destructor_guard.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery + +namespace mongo { +ExternalDataSourceScopeGuard::ExternalDataSourceScopeGuard( + OperationContext* opCtx, + const std::vector<std::pair<NamespaceString, std::vector<ExternalDataSourceInfo>>>& + usedExternalDataSources) + : _opCtx(opCtx) { + // Just in case that any virtual collection could not be created, when dtor does not have a + // chance to be executed, cleans up collections that has already been created at that + // moment. + ScopeGuard dropVcollGuard([&] { dropVirtualCollections(); }); + + for (auto&& [extDataSourceNss, dataSources] : usedExternalDataSources) { + VirtualCollectionOptions vopts(dataSources); + uassertStatusOK(createVirtualCollection(opCtx, extDataSourceNss, vopts)); + _toBeDroppedVirtualCollections.emplace_back(extDataSourceNss); + } + + dropVcollGuard.dismiss(); +} + +void ExternalDataSourceScopeGuard::dropVirtualCollections() noexcept { + // The move constructor sets '_opCtx' to null when ownership is moved to the other object which + // means this object must not try to drop collections. There's nothing to drop if '_opCtx' is + // null. + if (!_opCtx) { + return; + } + + // This function is called in a context of destructor or exception and so guard this against any + // exceptions. + DESTRUCTOR_GUARD({ + for (auto&& nss : _toBeDroppedVirtualCollections) { + DropReply reply; + auto status = + dropCollection(_opCtx, + nss, + &reply, + DropCollectionSystemCollectionMode::kDisallowSystemCollectionDrops); + if (!status.isOK()) { + LOGV2_ERROR(6968700, "Failed to drop an external data source", "coll"_attr = nss); + } + } + }); +} +} // namespace mongo diff --git a/src/mongo/db/commands/external_data_source_scope_guard.h b/src/mongo/db/commands/external_data_source_scope_guard.h new file mode 100644 index 00000000000..21b48890432 --- /dev/null +++ b/src/mongo/db/commands/external_data_source_scope_guard.h @@ -0,0 +1,84 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/clientcursor.h" +#include "mongo/db/namespace_string.h" +#include "mongo/db/operation_context.h" +#include "mongo/db/pipeline/external_data_source_option_gen.h" + +namespace mongo { +/** + * This class makes sure that virtual collections that are created for external data sources are + * dropped when it's destroyed. + */ +class ExternalDataSourceScopeGuard { +public: + // Makes ExternalDataSourceScopeGuard a decoration of ClientCursor. + static const ClientCursor::Decoration<std::shared_ptr<ExternalDataSourceScopeGuard>> get; + + // Updates the operation context of decorated ExternalDataSourceScopeGuard object of 'cursor' + // so that it can drop virtual collections in the new 'opCtx'. + static void updateOperationContext(const ClientCursor* cursor, OperationContext* opCtx) { + if (auto self = get(cursor); self) { + get(cursor)->_opCtx = opCtx; + } + } + + ExternalDataSourceScopeGuard() : _opCtx(nullptr), _toBeDroppedVirtualCollections() {} + + ExternalDataSourceScopeGuard( + OperationContext* opCtx, + const std::vector<std::pair<NamespaceString, std::vector<ExternalDataSourceInfo>>>& + usedExternalDataSources); + + // It does not make sense to support copy ctor because this object must drop created virtual + // collections. + ExternalDataSourceScopeGuard(const ExternalDataSourceScopeGuard&) = delete; + + ExternalDataSourceScopeGuard(ExternalDataSourceScopeGuard&& other) noexcept + : _opCtx(other._opCtx), + _toBeDroppedVirtualCollections(std::move(other._toBeDroppedVirtualCollections)) { + // Ownership of created virtual collections are moved to this object and the other object + // must not try to drop them any more. 
+ other._opCtx = nullptr; + } + + ~ExternalDataSourceScopeGuard() { + dropVirtualCollections(); + } + +private: + void dropVirtualCollections() noexcept; + + OperationContext* _opCtx; + std::vector<NamespaceString> _toBeDroppedVirtualCollections; +}; +} // namespace mongo diff --git a/src/mongo/db/commands/fail_point_cmd.cpp b/src/mongo/db/commands/fail_point_cmd.cpp index 7e868735f5b..a067522d6fe 100644 --- a/src/mongo/db/commands/fail_point_cmd.cpp +++ b/src/mongo/db/commands/fail_point_cmd.cpp @@ -93,6 +93,10 @@ public: return Status::OK(); } + bool allowedWithSecurityToken() const final { + return true; + } + std::string help() const override { return "modifies the settings of a fail point"; } diff --git a/src/mongo/db/commands/find_and_modify.cpp b/src/mongo/db/commands/find_and_modify.cpp index 1ba9efa0a75..ec70924da40 100644 --- a/src/mongo/db/commands/find_and_modify.cpp +++ b/src/mongo/db/commands/find_and_modify.cpp @@ -60,6 +60,7 @@ #include "mongo/db/repl/replication_coordinator.h" #include "mongo/db/s/collection_sharding_state.h" #include "mongo/db/s/operation_sharding_state.h" +#include "mongo/db/s/query_analysis_writer.h" #include "mongo/db/stats/counters.h" #include "mongo/db/stats/resource_consumption_metrics.h" #include "mongo/db/stats/top.h" @@ -76,6 +77,7 @@ namespace mongo { namespace { +MONGO_FAIL_POINT_DEFINE(failAllFindAndModify); MONGO_FAIL_POINT_DEFINE(hangBeforeFindAndModifyPerformsUpdate); /** @@ -682,6 +684,16 @@ write_ops::FindAndModifyCommandReply CmdFindAndModify::Invocation::typedRun( } } + if (analyze_shard_key::supportsPersistingSampledQueries() && request().getSampleId()) { + analyze_shard_key::QueryAnalysisWriter::get(opCtx) + .addFindAndModifyQuery(request()) + .getAsync([](auto) {}); + } + + if (MONGO_unlikely(failAllFindAndModify.shouldFail())) { + uasserted(ErrorCodes::InternalError, "failAllFindAndModify failpoint active!"); + } + const bool inTransaction = opCtx->inMultiDocumentTransaction(); // Although usually the PlanExecutor handles WCE internally, it will throw WCEs when it @@ -711,6 +723,7 @@ write_ops::FindAndModifyCommandReply CmdFindAndModify::Invocation::typedRun( if (opCtx->getTxnNumber()) { updateRequest.setStmtIds({stmtId}); } + updateRequest.setSampleId(req.getSampleId()); const ExtensionsCallbackReal extensionsCallback( opCtx, &updateRequest.getNamespaceString()); diff --git a/src/mongo/db/commands/find_cmd.cpp b/src/mongo/db/commands/find_cmd.cpp index d048a2ec616..1188981c991 100644 --- a/src/mongo/db/commands/find_cmd.cpp +++ b/src/mongo/db/commands/find_cmd.cpp @@ -56,6 +56,7 @@ #include "mongo/db/query/get_executor.h" #include "mongo/db/query/query_knobs_gen.h" #include "mongo/db/repl/replication_coordinator.h" +#include "mongo/db/s/query_analysis_writer.h" #include "mongo/db/service_context.h" #include "mongo/db/stats/counters.h" #include "mongo/db/stats/resource_consumption_metrics.h" @@ -65,6 +66,7 @@ #include "mongo/db/transaction/transaction_participant.h" #include "mongo/logv2/log.h" #include "mongo/rpc/get_status_from_command_result.h" +#include "mongo/util/assert_util.h" #include "mongo/util/database_name_util.h" #include "mongo/util/fail_point.h" @@ -362,6 +364,8 @@ public: // execution tree with an EOFStage. const auto& collection = ctx->getCollection(); + cq->setUseCqfIfEligible(true); + // Get the execution plan for the query. bool permitYield = true; auto exec = @@ -433,10 +437,10 @@ public: // The presence of a term in the request indicates that this is an internal replication // oplog read request. 
if (term && isOplogNss) { - // We do not want to take tickets for internal (replication) oplog reads. Stalling - // on ticket acquisition can cause complicated deadlocks. Primaries may depend on - // data reaching secondaries in order to proceed; and secondaries may get stalled - // replicating because of an inability to acquire a read ticket. + // We do not want to wait to take tickets for internal (replication) oplog reads. + // Stalling on ticket acquisition can cause complicated deadlocks. Primaries may + // depend on data reaching secondaries in order to proceed; and secondaries may get + // stalled replicating because of an inability to acquire a read ticket. opCtx->lockState()->setAdmissionPriority(AdmissionContext::Priority::kImmediate); } @@ -481,6 +485,16 @@ public: .expectedUUID(findCommand->getCollectionUUID())); const auto& nss = ctx->getNss(); + if (analyze_shard_key::supportsPersistingSampledQueries() && + findCommand->getSampleId()) { + analyze_shard_key::QueryAnalysisWriter::get(opCtx) + .addFindQuery(*findCommand->getSampleId(), + nss, + findCommand->getFilter(), + findCommand->getCollation()) + .getAsync([](auto) {}); + } + // Going forward this operation must never ignore interrupt signals while waiting for // lock acquisition. This InterruptibleLockGuard will ensure that waiting for lock // re-acquisition after yielding will not ignore interrupt signals. This is necessary to @@ -542,10 +556,10 @@ public: const auto& findCommand = cq->getFindCommandRequest(); auto viewAggregationCommand = uassertStatusOK(query_request_helper::asAggregationCommand(findCommand)); - - BSONObj aggResult = CommandHelpers::runCommandDirectly( - opCtx, - OpMsgRequest::fromDBAndBody(_dbName.db(), std::move(viewAggregationCommand))); + auto aggRequest = + OpMsgRequestBuilder::create(_dbName, std::move(viewAggregationCommand)); + aggRequest.validatedTenancyScope = _request.validatedTenancyScope; + BSONObj aggResult = CommandHelpers::runCommandDirectly(opCtx, aggRequest); auto status = getStatusFromCommandResult(aggResult); if (status.code() == ErrorCodes::InvalidPipelineOperator) { uasserted(ErrorCodes::InvalidPipelineOperator, @@ -572,6 +586,8 @@ public: opCtx->recoveryUnit()->setReadOnce(true); } + cq->setUseCqfIfEligible(true); + // Get the execution plan for the query. bool permitYield = true; auto exec = diff --git a/src/mongo/db/commands/generic.cpp b/src/mongo/db/commands/generic.cpp index c5f8dde7e74..9bc1717f6e8 100644 --- a/src/mongo/db/commands/generic.cpp +++ b/src/mongo/db/commands/generic.cpp @@ -173,6 +173,10 @@ public: return false; } + bool allowedWithSecurityToken() const final { + return true; + } + bool run(OperationContext* opCtx, const DatabaseName&, const BSONObj& cmdObj, diff --git a/src/mongo/db/commands/getmore_cmd.cpp b/src/mongo/db/commands/getmore_cmd.cpp index 3d6c73db66f..09aa73cc315 100644 --- a/src/mongo/db/commands/getmore_cmd.cpp +++ b/src/mongo/db/commands/getmore_cmd.cpp @@ -40,6 +40,7 @@ #include "mongo/db/client.h" #include "mongo/db/clientcursor.h" #include "mongo/db/commands.h" +#include "mongo/db/commands/external_data_source_scope_guard.h" #include "mongo/db/curop.h" #include "mongo/db/curop_failpoint_helpers.h" #include "mongo/db/cursor_manager.h" @@ -485,6 +486,10 @@ public: setUpOperationContextStateForGetMore( opCtx, *cursorPin.getCursor(), _cmd, disableAwaitDataFailpointActive); + // Update opCtx of the decorated ExternalDataSourceScopeGuard object so that it can drop + // virtual collections in the new 'opCtx'. 
+ ExternalDataSourceScopeGuard::updateOperationContext(cursorPin.getCursor(), opCtx); + // On early return, typically due to a failed assertion, delete the cursor. ScopeGuard cursorDeleter([&] { cursorPin.deleteUnderlying(); }); @@ -717,10 +722,10 @@ public: // internal clients (see checkAuthForGetMore). curOp->debug().isReplOplogGetMore = true; - // We do not want to take tickets for internal (replication) oplog reads. Stalling - // on ticket acquisition can cause complicated deadlocks. Primaries may depend on - // data reaching secondaries in order to proceed; and secondaries may get stalled - // replicating because of an inability to acquire a read ticket. + // We do not want to wait to take tickets for internal (replication) oplog reads. + // Stalling on ticket acquisition can cause complicated deadlocks. Primaries may + // depend on data reaching secondaries in order to proceed; and secondaries may get + // stalled replicating because of an inability to acquire a read ticket. opCtx->lockState()->setAdmissionPriority(AdmissionContext::Priority::kImmediate); } diff --git a/src/mongo/db/commands/parameters.cpp b/src/mongo/db/commands/parameters.cpp index 0cfb0f6dcf2..0829fcd0e32 100644 --- a/src/mongo/db/commands/parameters.cpp +++ b/src/mongo/db/commands/parameters.cpp @@ -238,6 +238,11 @@ public: h += "{ getParameter:'*' } or { getParameter:{allParameters: true} } to get everything\n"; return h; } + + bool allowedWithSecurityToken() const final { + return true; + } + bool run(OperationContext* opCtx, const DatabaseName& dbName, const BSONObj& cmdObj, diff --git a/src/mongo/db/commands/pipeline_command.cpp b/src/mongo/db/commands/pipeline_command.cpp index b2fb0341568..13037602698 100644 --- a/src/mongo/db/commands/pipeline_command.cpp +++ b/src/mongo/db/commands/pipeline_command.cpp @@ -35,8 +35,8 @@ #include "mongo/db/auth/authorization_checks.h" #include "mongo/db/auth/authorization_session.h" #include "mongo/db/catalog/create_collection.h" -#include "mongo/db/catalog/virtual_collection_options.h" #include "mongo/db/commands.h" +#include "mongo/db/commands/external_data_source_scope_guard.h" #include "mongo/db/commands/run_aggregate.h" #include "mongo/db/namespace_string.h" #include "mongo/db/pipeline/aggregate_command_gen.h" @@ -46,7 +46,7 @@ #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/query/query_knobs_gen.h" #include "mongo/idl/idl_parser.h" -#include "mongo/stdx/unordered_set.h" +#include "mongo/util/assert_util.h" #include "mongo/util/database_name_util.h" namespace mongo { @@ -202,18 +202,20 @@ public: CommandHelpers::handleMarkKillOnClientDisconnect( opCtx, !Pipeline::aggHasWriteStage(_request.body)); - // TODO SERVER-69687 Create a virtual collection per each used external data source. - for (auto&& [collName, dataSources] : _usedExternalDataSources) { - VirtualCollectionOptions vopts(dataSources); - } - + // Create virtual collections and drop them when aggregate command is done. Conceptually + // ownership of virtual collections are moved to runAggregate() function together with + // 'dropVcollGuard' so that it can clean up virtual collections when it's done with + // them. ExternalDataSourceScopeGuard will take care of the situation when any + // collection could not be created. 
+ ExternalDataSourceScopeGuard dropVcollGuard(opCtx, _usedExternalDataSources); uassertStatusOK(runAggregate(opCtx, _aggregationRequest.getNamespace(), _aggregationRequest, _liteParsedPipeline, _request.body, _privileges, - reply)); + reply, + std::move(dropVcollGuard))); // The aggregate command's response is unstable when 'explain' or 'exchange' fields are // set. @@ -231,14 +233,16 @@ public: void explain(OperationContext* opCtx, ExplainOptions::Verbosity verbosity, rpc::ReplyBuilderInterface* result) override { - + // See run() method for details. + ExternalDataSourceScopeGuard dropVcollGuard(opCtx, _usedExternalDataSources); uassertStatusOK(runAggregate(opCtx, _aggregationRequest.getNamespace(), _aggregationRequest, _liteParsedPipeline, _request.body, _privileges, - result)); + result, + std::move(dropVcollGuard))); } void doCheckAuthorization(OperationContext* opCtx) const override { diff --git a/src/mongo/db/commands/profile_common.cpp b/src/mongo/db/commands/profile_common.cpp index 54223b8f5a7..ba546e62067 100644 --- a/src/mongo/db/commands/profile_common.cpp +++ b/src/mongo/db/commands/profile_common.cpp @@ -133,6 +133,7 @@ bool ProfileCmdBase::run(OperationContext* opCtx, newState.append("filter"_sd, newSettings.filter->serialize()); } attrs.add("to", newState.obj()); + attrs.add("db", dbName.db()); LOGV2(48742, "Profiler settings changed", attrs); } diff --git a/src/mongo/db/commands/run_aggregate.cpp b/src/mongo/db/commands/run_aggregate.cpp index eb32a6c220b..cb15cf82247 100644 --- a/src/mongo/db/commands/run_aggregate.cpp +++ b/src/mongo/db/commands/run_aggregate.cpp @@ -42,6 +42,7 @@ #include "mongo/db/change_stream_change_collection_manager.h" #include "mongo/db/change_stream_pre_images_collection_manager.h" #include "mongo/db/change_stream_serverless_helpers.h" +#include "mongo/db/commands/external_data_source_scope_guard.h" #include "mongo/db/curop.h" #include "mongo/db/cursor_manager.h" #include "mongo/db/db_raii.h" @@ -81,6 +82,7 @@ #include "mongo/db/repl/speculative_majority_read_info.h" #include "mongo/db/s/collection_sharding_state.h" #include "mongo/db/s/operation_sharding_state.h" +#include "mongo/db/s/query_analysis_writer.h" #include "mongo/db/s/sharding_state.h" #include "mongo/db/service_context.h" #include "mongo/db/stats/resource_consumption_metrics.h" @@ -659,7 +661,7 @@ Status runAggregate(OperationContext* opCtx, const BSONObj& cmdObj, const PrivilegeVector& privileges, rpc::ReplyBuilderInterface* result) { - return runAggregate(opCtx, nss, request, {request}, cmdObj, privileges, result); + return runAggregate(opCtx, nss, request, {request}, cmdObj, privileges, result, {}); } Status runAggregate(OperationContext* opCtx, @@ -668,7 +670,8 @@ Status runAggregate(OperationContext* opCtx, const LiteParsedPipeline& liteParsedPipeline, const BSONObj& cmdObj, const PrivilegeVector& privileges, - rpc::ReplyBuilderInterface* result) { + rpc::ReplyBuilderInterface* result, + ExternalDataSourceScopeGuard externalDataSourceGuard) { // Perform some validations on the LiteParsedPipeline and request before continuing with the // aggregation command. 
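The hand-off above follows a move-then-share ownership pattern: the guard that owns the virtual collections is constructed in the command, moved into runAggregate(), and, as the run_aggregate.cpp changes show, ultimately wrapped in a shared_ptr co-owned by every returned cursor, so the collections are dropped exactly once, when the last owner goes away. A standalone sketch of that lifecycle with illustrative names rather than the server's classes:

#include <iostream>
#include <memory>
#include <utility>
#include <vector>

// Move-only guard: drops the "virtual collections" exactly once, from whichever
// object ends up owning it last. A moved-from guard is disarmed and does nothing.
class CollectionDropGuard {
public:
    CollectionDropGuard() = default;
    CollectionDropGuard(CollectionDropGuard&& other) noexcept
        : _armed(std::exchange(other._armed, false)) {}
    CollectionDropGuard(const CollectionDropGuard&) = delete;
    ~CollectionDropGuard() {
        if (_armed) {
            std::cout << "dropping virtual collections\n";
        }
    }

private:
    bool _armed = true;
};

struct Cursor {
    std::shared_ptr<CollectionDropGuard> guard;  // every cursor co-owns the guard
};

// Takes ownership by value (moved in) and shares it across the returned cursors.
std::vector<Cursor> runQuery(CollectionDropGuard guard) {
    auto shared = std::make_shared<CollectionDropGuard>(std::move(guard));
    return {Cursor{shared}, Cursor{shared}};
}

int main() {
    CollectionDropGuard guard;                  // created by the command
    auto cursors = runQuery(std::move(guard));  // ownership moved into the callee
    cursors.pop_back();                         // destroying one cursor: nothing happens yet
    cursors.clear();                            // last owner destroyed: collections dropped once
    return 0;
}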
@@ -945,6 +948,15 @@ Status runAggregate(OperationContext* opCtx, } } + if (analyze_shard_key::supportsPersistingSampledQueries() && request.getSampleId()) { + analyze_shard_key::QueryAnalysisWriter::get(opCtx) + .addAggregateQuery(*request.getSampleId(), + expCtx->ns, + pipeline->getInitialQuery(), + expCtx->getCollatorBSON()) + .getAsync([](auto) {}); + } + // If the aggregate command supports encrypted collections, do rewrites of the pipeline to // support querying against encrypted fields. if (shouldDoFLERewrite(request)) { @@ -1021,6 +1033,8 @@ Status runAggregate(OperationContext* opCtx, p.deleteUnderlying(); } }); + auto extDataSrcGuard = + std::make_shared<ExternalDataSourceScopeGuard>(std::move(externalDataSourceGuard)); for (auto&& exec : execs) { ClientCursorParams cursorParams( std::move(exec), @@ -1038,6 +1052,10 @@ Status runAggregate(OperationContext* opCtx, pin->incNBatches(); cursors.emplace_back(pin.getCursor()); + // All cursors share the ownership to 'extDataSrcGuard' and if the last cursor is destroyed, + // 'extDataSrcGuard' is also destroyed and created virtual collections are dropped by the + // destructor of ExternalDataSourceScopeGuard. + ExternalDataSourceScopeGuard::get(pin.getCursor()) = extDataSrcGuard; pins.emplace_back(std::move(pin)); } diff --git a/src/mongo/db/commands/run_aggregate.h b/src/mongo/db/commands/run_aggregate.h index 0bb86ac91b0..33c03e80cdb 100644 --- a/src/mongo/db/commands/run_aggregate.h +++ b/src/mongo/db/commands/run_aggregate.h @@ -32,6 +32,7 @@ #include "mongo/bson/bsonobj.h" #include "mongo/bson/bsonobjbuilder.h" #include "mongo/db/auth/privilege.h" +#include "mongo/db/commands/external_data_source_scope_guard.h" #include "mongo/db/namespace_string.h" #include "mongo/db/operation_context.h" #include "mongo/db/pipeline/aggregate_command_gen.h" @@ -57,7 +58,8 @@ Status runAggregate(OperationContext* opCtx, const LiteParsedPipeline& liteParsedPipeline, const BSONObj& cmdObj, const PrivilegeVector& privileges, - rpc::ReplyBuilderInterface* result); + rpc::ReplyBuilderInterface* result, + ExternalDataSourceScopeGuard externalDataSourceGuard); /** * Convenience version that internally constructs the LiteParsedPipeline. diff --git a/src/mongo/db/commands/server_status_command.cpp b/src/mongo/db/commands/server_status_command.cpp index 62433e57e62..521e4a49010 100644 --- a/src/mongo/db/commands/server_status_command.cpp +++ b/src/mongo/db/commands/server_status_command.cpp @@ -78,6 +78,10 @@ public: return Status::OK(); } + bool allowedWithSecurityToken() const final { + return true; + } + bool run(OperationContext* opCtx, const DatabaseName& dbName, const BSONObj& cmdObj, diff --git a/src/mongo/db/commands/tenant_migration_recipient_cmds.idl b/src/mongo/db/commands/tenant_migration_recipient_cmds.idl index 96bfcd775cd..c8f58df0c0c 100644 --- a/src/mongo/db/commands/tenant_migration_recipient_cmds.idl +++ b/src/mongo/db/commands/tenant_migration_recipient_cmds.idl @@ -97,6 +97,7 @@ commands: strict: true namespace: ignored api_version: "" + reply_type: recipientSyncDataResponse inline_chained_structs: true chained_structs: MigrationRecipientCommonData: MigrationRecipientCommonData @@ -144,6 +145,7 @@ commands: recipientForgetMigration: description: "Parser for the 'recipientForgetMigration' command." 
command_name: recipientForgetMigration + reply_type: OkReply strict: true namespace: ignored api_version: "" diff --git a/src/mongo/db/concurrency/lock_state.h b/src/mongo/db/concurrency/lock_state.h index 9d37c7e0c50..9dd80b25110 100644 --- a/src/mongo/db/concurrency/lock_state.h +++ b/src/mongo/db/concurrency/lock_state.h @@ -421,16 +421,15 @@ public: }; /** - * RAII-style class to set the priority for the ticket acquisition mechanism when acquiring a global + * RAII-style class to set the priority for the ticket admission mechanism when acquiring a global * lock. */ -class SetTicketAquisitionPriorityForLock { +class SetAdmissionPriorityForLock { public: - SetTicketAquisitionPriorityForLock(const SetTicketAquisitionPriorityForLock&) = delete; - SetTicketAquisitionPriorityForLock& operator=(const SetTicketAquisitionPriorityForLock&) = - delete; - explicit SetTicketAquisitionPriorityForLock(OperationContext* opCtx, - AdmissionContext::Priority priority) + SetAdmissionPriorityForLock(const SetAdmissionPriorityForLock&) = delete; + SetAdmissionPriorityForLock& operator=(const SetAdmissionPriorityForLock&) = delete; + explicit SetAdmissionPriorityForLock(OperationContext* opCtx, + AdmissionContext::Priority priority) : _opCtx(opCtx), _originalPriority(opCtx->lockState()->getAdmissionPriority()) { uassert(ErrorCodes::IllegalOperation, "It is illegal for an operation to demote a high priority to a lower priority " @@ -440,7 +439,7 @@ public: _opCtx->lockState()->setAdmissionPriority(priority); } - ~SetTicketAquisitionPriorityForLock() { + ~SetAdmissionPriorityForLock() { _opCtx->lockState()->setAdmissionPriority(_originalPriority); } diff --git a/src/mongo/db/concurrency/lock_state_test.cpp b/src/mongo/db/concurrency/lock_state_test.cpp index 607c2c076f7..0062fe990ae 100644 --- a/src/mongo/db/concurrency/lock_state_test.cpp +++ b/src/mongo/db/concurrency/lock_state_test.cpp @@ -1244,20 +1244,19 @@ TEST_F(LockerImplTest, SetTicketAcquisitionForLockRAIIType) { ASSERT_TRUE(opCtx->lockState()->shouldWaitForTicket()); { - SetTicketAquisitionPriorityForLock setTicketAquisition( - opCtx.get(), AdmissionContext::Priority::kImmediate); + SetAdmissionPriorityForLock setTicketAquisition(opCtx.get(), + AdmissionContext::Priority::kImmediate); ASSERT_FALSE(opCtx->lockState()->shouldWaitForTicket()); } ASSERT_TRUE(opCtx->lockState()->shouldWaitForTicket()); - // If ticket acquisitions are disabled on the lock state, the RAII type has no effect. 
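The renamed SetAdmissionPriorityForLock exercised by this test is a save-and-restore RAII helper: its constructor records the current admission priority before overriding it, and its destructor puts the original value back even on early return or exception. A generic sketch of that idiom in standard C++ (the types here are stand-ins, not the server's Locker/AdmissionContext API):

#include <cassert>

enum class Priority { kNormal, kImmediate };

struct Context {
    Priority priority = Priority::kNormal;
};

// Overrides a setting for the enclosing scope and restores the previous value
// when the guard is destroyed.
class ScopedPriority {
public:
    ScopedPriority(Context* ctx, Priority p) : _ctx(ctx), _original(ctx->priority) {
        _ctx->priority = p;
    }
    ScopedPriority(const ScopedPriority&) = delete;
    ScopedPriority& operator=(const ScopedPriority&) = delete;
    ~ScopedPriority() {
        _ctx->priority = _original;
    }

private:
    Context* _ctx;
    Priority _original;
};

int main() {
    Context ctx;
    {
        ScopedPriority guard(&ctx, Priority::kImmediate);
        assert(ctx.priority == Priority::kImmediate);
    }
    assert(ctx.priority == Priority::kNormal);  // restored when the guard leaves scope
    return 0;
}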
opCtx->lockState()->setAdmissionPriority(AdmissionContext::Priority::kImmediate); ASSERT_FALSE(opCtx->lockState()->shouldWaitForTicket()); { - SetTicketAquisitionPriorityForLock setTicketAquisition( - opCtx.get(), AdmissionContext::Priority::kImmediate); + SetAdmissionPriorityForLock setTicketAquisition(opCtx.get(), + AdmissionContext::Priority::kImmediate); ASSERT_FALSE(opCtx->lockState()->shouldWaitForTicket()); } diff --git a/src/mongo/db/curop.cpp b/src/mongo/db/curop.cpp index 73ddcb6827e..9e6b89c1af4 100644 --- a/src/mongo/db/curop.cpp +++ b/src/mongo/db/curop.cpp @@ -278,7 +278,7 @@ void CurOp::setGenericCursor_inlock(GenericCursor gc) { void CurOp::_finishInit(OperationContext* opCtx, CurOpStack* stack) { _stack = stack; - _tickSource = SystemTickSource::get(); + _tickSource = globalSystemTickSource(); if (opCtx) { _stack->push(opCtx, this); diff --git a/src/mongo/db/db_raii.cpp b/src/mongo/db/db_raii.cpp index 5976628a7f1..7bc3b31f24b 100644 --- a/src/mongo/db/db_raii.cpp +++ b/src/mongo/db/db_raii.cpp @@ -38,6 +38,7 @@ #include "mongo/db/s/collection_sharding_state.h" #include "mongo/db/s/operation_sharding_state.h" #include "mongo/db/storage/snapshot_helper.h" +#include "mongo/db/storage/storage_parameters_gen.h" #include "mongo/logv2/log.h" #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage @@ -641,7 +642,7 @@ AutoGetCollectionForRead::AutoGetCollectionForRead(OperationContext* opCtx, } } -AutoGetCollectionForReadLockFree::EmplaceHelper::EmplaceHelper( +AutoGetCollectionForReadLockFreeLegacy::EmplaceHelper::EmplaceHelper( OperationContext* opCtx, CollectionCatalogStasher& catalogStasher, const NamespaceStringOrUUID& nsOrUUID, @@ -653,7 +654,7 @@ AutoGetCollectionForReadLockFree::EmplaceHelper::EmplaceHelper( _options(std::move(options)), _isLockFreeReadSubOperation(isLockFreeReadSubOperation) {} -void AutoGetCollectionForReadLockFree::EmplaceHelper::emplace( +void AutoGetCollectionForReadLockFreeLegacy::EmplaceHelper::emplace( boost::optional<AutoGetCollectionLockFree>& autoColl) const { autoColl.emplace( _opCtx, @@ -706,7 +707,7 @@ void AutoGetCollectionForReadLockFree::EmplaceHelper::emplace( _options); } -AutoGetCollectionForReadLockFree::AutoGetCollectionForReadLockFree( +AutoGetCollectionForReadLockFreeLegacy::AutoGetCollectionForReadLockFreeLegacy( OperationContext* opCtx, const NamespaceStringOrUUID& nsOrUUID, AutoGetCollection::Options options) @@ -749,11 +750,22 @@ AutoGetCollectionForReadLockFree::AutoGetCollectionForReadLockFree( options._secondaryNssOrUUIDs); } -AutoGetCollectionForReadMaybeLockFree::AutoGetCollectionForReadMaybeLockFree( +AutoGetCollectionForReadLockFree::AutoGetCollectionForReadLockFree( OperationContext* opCtx, const NamespaceStringOrUUID& nsOrUUID, AutoGetCollection::Options options) { + if (feature_flags::gPointInTimeCatalogLookups.isEnabledAndIgnoreFCV()) { + _impl.emplace<AutoGetCollectionForReadLockFreePITCatalog>( + opCtx, nsOrUUID, std::move(options)); + } else { + _impl.emplace<AutoGetCollectionForReadLockFreeLegacy>(opCtx, nsOrUUID, std::move(options)); + } +} +AutoGetCollectionForReadMaybeLockFree::AutoGetCollectionForReadMaybeLockFree( + OperationContext* opCtx, + const NamespaceStringOrUUID& nsOrUUID, + AutoGetCollection::Options options) { if (supportsLockFreeRead(opCtx)) { _autoGetLockFree.emplace(opCtx, nsOrUUID, std::move(options)); } else { @@ -808,7 +820,6 @@ AutoGetCollectionForReadCommandBase<AutoGetCollectionForReadType>:: _autoCollForRead.getNss().dbName()), options._deadline, 
options._secondaryNssOrUUIDs) { - hangBeforeAutoGetShardVersionCheck.executeIf( [&](auto&) { hangBeforeAutoGetShardVersionCheck.pauseWhileSet(opCtx); }, [&](const BSONObj& data) { @@ -1050,7 +1061,7 @@ BlockSecondaryReadsDuringBatchApplication_DONT_USE:: template class AutoGetCollectionForReadBase<AutoGetCollection, EmplaceAutoGetCollectionForRead>; template class AutoGetCollectionForReadCommandBase<AutoGetCollectionForRead>; template class AutoGetCollectionForReadBase<AutoGetCollectionLockFree, - AutoGetCollectionForReadLockFree::EmplaceHelper>; + AutoGetCollectionForReadLockFreeLegacy::EmplaceHelper>; template class AutoGetCollectionForReadCommandBase<AutoGetCollectionForReadLockFree>; } // namespace mongo diff --git a/src/mongo/db/db_raii.h b/src/mongo/db/db_raii.h index 6e5c2874ff5..0ae4b1c24bc 100644 --- a/src/mongo/db/db_raii.h +++ b/src/mongo/db/db_raii.h @@ -33,6 +33,8 @@ #include "mongo/db/catalog_raii.h" #include "mongo/db/stats/top.h" +#include "mongo/stdx/variant.h" +#include "mongo/util/overloaded_visitor.h" #include "mongo/util/timer.h" namespace mongo { @@ -195,24 +197,14 @@ private: * Same as AutoGetCollectionForRead above except does not take collection, database or rstl locks. * Takes the global lock and may take the PBWM, same as AutoGetCollectionForRead. Ensures a * consistent in-memory and on-disk view of the storage catalog. + * + * This implementation does not use the PIT catalog. */ -class AutoGetCollectionForReadLockFree { +class AutoGetCollectionForReadLockFreeLegacy { public: - AutoGetCollectionForReadLockFree(OperationContext* opCtx, - const NamespaceStringOrUUID& nsOrUUID, - AutoGetCollection::Options = {}); - - explicit operator bool() const { - return static_cast<bool>(getCollection()); - } - - const Collection* operator->() const { - return getCollection().get(); - } - - const CollectionPtr& operator*() const { - return getCollection(); - } + AutoGetCollectionForReadLockFreeLegacy(OperationContext* opCtx, + const NamespaceStringOrUUID& nsOrUUID, + AutoGetCollection::Options = {}); const CollectionPtr& getCollection() const { return _autoGetCollectionForReadBase->getCollection(); @@ -226,11 +218,6 @@ public: return _autoGetCollectionForReadBase->getNss(); } - /** - * Indicates whether any namespace in 'secondaryNssOrUUIDs' is a view or sharded. - * - * The secondary namespaces won't be checked if getCollection() returns nullptr. - */ bool isAnySecondaryNamespaceAViewOrSharded() const { return _secondaryNssIsAViewOrSharded; } @@ -272,6 +259,112 @@ private: }; /** + * Same as AutoGetCollectionForRead above except does not take collection, database or rstl locks. + * Takes the global lock and may take the PBWM, same as AutoGetCollectionForRead. Ensures a + * consistent in-memory and on-disk view of the storage catalog. + * + * This implementation uses the point-in-time (PIT) catalog. 
+ */ +class AutoGetCollectionForReadLockFreePITCatalog { +public: + AutoGetCollectionForReadLockFreePITCatalog(OperationContext* opCtx, + const NamespaceStringOrUUID& nsOrUUID, + AutoGetCollection::Options options = {}) + : _impl(opCtx, nsOrUUID, std::move(options)) {} + + const CollectionPtr& getCollection() const { + return _impl.getCollection(); + } + + const ViewDefinition* getView() const { + return _impl.getView(); + } + + const NamespaceString& getNss() const { + return _impl.getNss(); + } + + bool isAnySecondaryNamespaceAViewOrSharded() const { + return _impl.isAnySecondaryNamespaceAViewOrSharded(); + } + +private: + // TODO (SERVER-68271): Replace with new implementation using the PIT catalog. + AutoGetCollectionForReadLockFreeLegacy _impl; +}; + +/** + * Same as AutoGetCollectionForRead above except does not take collection, database or rstl locks. + * Takes the global lock and may take the PBWM, same as AutoGetCollectionForRead. Ensures a + * consistent in-memory and on-disk view of the storage catalog. + */ +class AutoGetCollectionForReadLockFree { +public: + AutoGetCollectionForReadLockFree(OperationContext* opCtx, + const NamespaceStringOrUUID& nsOrUUID, + AutoGetCollection::Options = {}); + + explicit operator bool() const { + return static_cast<bool>(getCollection()); + } + + const Collection* operator->() const { + return getCollection().get(); + } + + const CollectionPtr& operator*() const { + return getCollection(); + } + + const CollectionPtr& getCollection() const { + return stdx::visit( + OverloadedVisitor{ + [](auto&& impl) -> const CollectionPtr& { return impl.getCollection(); }, + [](stdx::monostate) -> const CollectionPtr& { MONGO_UNREACHABLE; }, + }, + _impl); + } + + const ViewDefinition* getView() const { + return stdx::visit( + OverloadedVisitor{[](auto&& impl) { return impl.getView(); }, + [](stdx::monostate) -> const ViewDefinition* { MONGO_UNREACHABLE; }}, + _impl); + } + + const NamespaceString& getNss() const { + return stdx::visit( + OverloadedVisitor{[](auto&& impl) -> const NamespaceString& { return impl.getNss(); }, + [](stdx::monostate) -> const NamespaceString& { MONGO_UNREACHABLE; }}, + _impl); + } + + /** + * Indicates whether any namespace in 'secondaryNssOrUUIDs' is a view or sharded. + * + * The secondary namespaces won't be checked if getCollection() returns nullptr. + */ + bool isAnySecondaryNamespaceAViewOrSharded() const { + return stdx::visit( + OverloadedVisitor{ + [](auto&& impl) { return impl.isAnySecondaryNamespaceAViewOrSharded(); }, + [](stdx::monostate) -> bool { MONGO_UNREACHABLE; }}, + _impl); + } + +private: + // If the gPointInTimeCatalogLookups feature flag is enabled, this will contain an instance of + // AutoGetCollectionForReadLockFreePITCatalog. Otherwise, it will contain an instance of + // AutoGetCollectionForReadLockFreeLegacy. Note that stdx::monostate is required for default + // construction, since these other types are not movable, but after construction, the value + // should never be set to stdx::monostate. + stdx::variant<stdx::monostate, + AutoGetCollectionForReadLockFreeLegacy, + AutoGetCollectionForReadLockFreePITCatalog> + _impl; +}; + +/** * Creates either an AutoGetCollectionForRead or AutoGetCollectionForReadLockFree depending on * whether a lock-free read is supported. 
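The getters above all dispatch through stdx::visit with an OverloadedVisitor, where the stdx::monostate arm exists only so the variant is default-constructible and is never expected to run after construction. A self-contained sketch of the same idiom, with hypothetical LegacyImpl/PitImpl stand-ins and plain std::variant/std::visit used in place of the stdx wrappers:

    #include <cstdlib>
    #include <iostream>
    #include <variant>

    struct LegacyImpl { int value() const { return 1; } };
    struct PitImpl    { int value() const { return 2; } };

    // Build an overload set out of lambdas, the same trick OverloadedVisitor uses.
    template <typename... Ts> struct Overloaded : Ts... { using Ts::operator()...; };
    template <typename... Ts> Overloaded(Ts...) -> Overloaded<Ts...>;

    int main() {
        std::variant<std::monostate, LegacyImpl, PitImpl> impl = PitImpl{};
        int v = std::visit(Overloaded{
                               [](const auto& i) { return i.value(); },     // LegacyImpl or PitImpl
                               [](std::monostate) -> int { std::abort(); }  // never reached once set
                           },
                           impl);
        std::cout << v << '\n';  // prints 2
        return 0;
    }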
*/ diff --git a/src/mongo/db/dbdirectclient.h b/src/mongo/db/dbdirectclient.h index 315c17a8bdc..c0f5bf23400 100644 --- a/src/mongo/db/dbdirectclient.h +++ b/src/mongo/db/dbdirectclient.h @@ -55,6 +55,7 @@ public: using DBClientBase::find; using DBClientBase::insert; using DBClientBase::remove; + using DBClientBase::runCommand; using DBClientBase::update; std::unique_ptr<DBClientCursor> find(FindCommandRequest findRequest, diff --git a/src/mongo/db/exec/batched_delete_stage.cpp b/src/mongo/db/exec/batched_delete_stage.cpp index b12a5c0d73b..7c1141e0f99 100644 --- a/src/mongo/db/exec/batched_delete_stage.cpp +++ b/src/mongo/db/exec/batched_delete_stage.cpp @@ -325,7 +325,8 @@ long long BatchedDeleteStage::_commitBatch(WorkingSetID* out, // Start a WUOW with 'groupOplogEntries' which groups a delete batch into a single timestamp // and oplog entry. - WriteUnitOfWork wuow(opCtx(), true /* groupOplogEntries */); + WriteUnitOfWork wuow(opCtx(), + _stagedDeletesBuffer.size() > 1U ? true : false /* groupOplogEntries */); for (; *bufferOffset < _stagedDeletesBuffer.size(); ++*bufferOffset) { if (MONGO_unlikely(throwWriteConflictExceptionInBatchedDeleteStage.shouldFail())) { throwWriteConflictException( diff --git a/src/mongo/db/exec/bucket_unpacker.cpp b/src/mongo/db/exec/bucket_unpacker.cpp index d2ad79f0e03..6a819181df9 100644 --- a/src/mongo/db/exec/bucket_unpacker.cpp +++ b/src/mongo/db/exec/bucket_unpacker.cpp @@ -41,6 +41,7 @@ #include "mongo/db/matcher/expression_parser.h" #include "mongo/db/matcher/expression_tree.h" #include "mongo/db/matcher/extensions_callback_noop.h" +#include "mongo/db/matcher/rewrite_expr.h" #include "mongo/db/pipeline/expression.h" #include "mongo/db/timeseries/timeseries_options.h" @@ -130,9 +131,9 @@ std::unique_ptr<MatchExpression> makeOr(std::vector<std::unique_ptr<MatchExpress return std::make_unique<OrMatchExpression>(std::move(nontrivial)); } -std::unique_ptr<MatchExpression> handleIneligible(IneligiblePredicatePolicy policy, - const MatchExpression* matchExpr, - StringData message) { +BucketSpec::BucketPredicate handleIneligible(IneligiblePredicatePolicy policy, + const MatchExpression* matchExpr, + StringData message) { switch (policy) { case IneligiblePredicatePolicy::kError: uasserted( @@ -140,7 +141,7 @@ std::unique_ptr<MatchExpression> handleIneligible(IneligiblePredicatePolicy poli "Error translating non-metadata time-series predicate to operate on buckets: " + message + ": " + matchExpr->serialize().toString()); case IneligiblePredicatePolicy::kIgnore: - return nullptr; + return {}; } MONGO_UNREACHABLE_TASSERT(5916307); } @@ -204,40 +205,32 @@ std::unique_ptr<MatchExpression> createTypeEqualityPredicate( return makeOr(std::move(typeEqualityPredicates)); } -std::unique_ptr<MatchExpression> createComparisonPredicate( - const ComparisonMatchExpressionBase* matchExpr, +boost::optional<StringData> checkComparisonPredicateErrors( + const MatchExpression* matchExpr, + const StringData matchExprPath, + const BSONElement& matchExprData, const BucketSpec& bucketSpec, - int bucketMaxSpanSeconds, - ExpressionContext::CollationMatchesDefault collationMatchesDefault, - boost::intrusive_ptr<ExpressionContext> pExpCtx, - bool haveComputedMetaField, - bool includeMetaField, - bool assumeNoMixedSchemaData, - IneligiblePredicatePolicy policy) { + ExpressionContext::CollationMatchesDefault collationMatchesDefault) { using namespace timeseries; - const auto matchExprPath = matchExpr->path(); - const auto matchExprData = matchExpr->getData(); - // The control field's min 
and max are chosen using a field-order insensitive comparator, while // MatchExpressions use a comparator that treats field-order as significant. Because of this we // will not perform this optimization on queries with operands of compound types. if (matchExprData.type() == BSONType::Object || matchExprData.type() == BSONType::Array) - return handleIneligible(policy, matchExpr, "operand can't be an object or array"_sd); + return "operand can't be an object or array"_sd; // MatchExpressions have special comparison semantics regarding null, in that {$eq: null} will // match all documents where the field is either null or missing. Because this is different // from both the comparison semantics that InternalExprComparison expressions and the control's // min and max fields use, we will not perform this optimization on queries with null operands. if (matchExprData.type() == BSONType::jstNULL) - return handleIneligible(policy, matchExpr, "can't handle {$eq: null}"_sd); + return "can't handle {$eq: null}"_sd; // The control field's min and max are chosen based on the collation of the collection. If the // query's collation does not match the collection's collation and the query operand is a // string or compound type (skipped above) we will not perform this optimization. if (collationMatchesDefault == ExpressionContext::CollationMatchesDefault::kNo && matchExprData.type() == BSONType::String) { - return handleIneligible( - policy, matchExpr, "can't handle string comparison with a non-default collation"_sd); + return "can't handle string comparison with a non-default collation"_sd; } // This function only handles time and measurement predicates--not metadata. @@ -252,19 +245,45 @@ std::unique_ptr<MatchExpression> createComparisonPredicate( // We must avoid mapping predicates on fields computed via $addFields or a computed $project. if (bucketSpec.fieldIsComputed(matchExprPath.toString())) { - return handleIneligible(policy, matchExpr, "can't handle a computed field"); + return "can't handle a computed field"_sd; } const auto isTimeField = (matchExprPath == bucketSpec.timeField()); if (isTimeField && matchExprData.type() != BSONType::Date) { // Users are not allowed to insert non-date measurements into time field. So this query // would not match anything. We do not need to optimize for this case. 
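For readers unfamiliar with the bucket format these checks refer to: each bucket document carries a control block with per-field min/max summaries, and the rewrites below compare query operands against those summaries. A minimal illustration of the loose idea (hypothetical helper, not the server's implementation; fromjson comes from mongo/db/json.h, which the benchmark fixture added later in this diff also uses):

    #include "mongo/bson/bsonobj.h"
    #include "mongo/db/json.h"

    // Hypothetical: can a bucket possibly contain a measurement matching {a: {$gt: queryValue}}?
    // Only the control.max.a summary is needed for this loose, bucket-level check.
    bool bucketMightMatchGt(const mongo::BSONObj& bucket, double queryValue) {
        auto maxA = bucket.getObjectField("control").getObjectField("max")["a"];
        return !maxA.eoo() && maxA.numberDouble() > queryValue;
    }

    // Usage:
    //   auto bucket = mongo::fromjson("{control: {min: {a: 1}, max: {a: 9}}, data: {}}");
    //   bucketMightMatchGt(bucket, 5);   // true: some measurement may satisfy a > 5
    //   bucketMightMatchGt(bucket, 12);  // false: no measurement can satisfy a > 12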
- return handleIneligible( - policy, - matchExpr, - "This predicate will never be true, because the time field always contains a Date"); + return "This predicate will never be true, because the time field always contains a Date"_sd; + } + + return boost::none; +} + +std::unique_ptr<MatchExpression> createComparisonPredicate( + const ComparisonMatchExpressionBase* matchExpr, + const BucketSpec& bucketSpec, + int bucketMaxSpanSeconds, + ExpressionContext::CollationMatchesDefault collationMatchesDefault, + boost::intrusive_ptr<ExpressionContext> pExpCtx, + bool haveComputedMetaField, + bool includeMetaField, + bool assumeNoMixedSchemaData, + IneligiblePredicatePolicy policy) { + using namespace timeseries; + const auto matchExprPath = matchExpr->path(); + const auto matchExprData = matchExpr->getData(); + + const auto error = checkComparisonPredicateErrors( + matchExpr, matchExprPath, matchExprData, bucketSpec, collationMatchesDefault); + if (error) { + return handleIneligible(policy, matchExpr, *error).loosePredicate; } + const auto isTimeField = (matchExprPath == bucketSpec.timeField()); + auto minPath = std::string{kControlMinFieldNamePrefix} + matchExprPath; + const StringData minPathStringData(minPath); + auto maxPath = std::string{kControlMaxFieldNamePrefix} + matchExprPath; + const StringData maxPathStringData(maxPath); + BSONObj minTime; BSONObj maxTime; if (isTimeField) { @@ -273,11 +292,6 @@ std::unique_ptr<MatchExpression> createComparisonPredicate( maxTime = BSON("" << timeField + Seconds(bucketMaxSpanSeconds)); } - const auto minPath = std::string{kControlMinFieldNamePrefix} + matchExprPath; - const StringData minPathStringData(minPath); - const auto maxPath = std::string{kControlMaxFieldNamePrefix} + matchExprPath; - const StringData maxPathStringData(maxPath); - switch (matchExpr->matchType()) { case MatchExpression::EQ: case MatchExpression::INTERNAL_EXPR_EQ: @@ -481,9 +495,108 @@ std::unique_ptr<MatchExpression> createComparisonPredicate( MONGO_UNREACHABLE_TASSERT(5348303); } +std::unique_ptr<MatchExpression> createTightComparisonPredicate( + const ComparisonMatchExpressionBase* matchExpr, + const BucketSpec& bucketSpec, + ExpressionContext::CollationMatchesDefault collationMatchesDefault) { + using namespace timeseries; + const auto matchExprPath = matchExpr->path(); + const auto matchExprData = matchExpr->getData(); + + const auto error = checkComparisonPredicateErrors( + matchExpr, matchExprPath, matchExprData, bucketSpec, collationMatchesDefault); + if (error) { + return handleIneligible(BucketSpec::IneligiblePredicatePolicy::kIgnore, matchExpr, *error) + .loosePredicate; + } + + // We have to disable the tight predicate for the measurement field. There might be missing + // values in the measurements and the control fields ignore them on insertion. So we cannot use + // bucket min and max to determine the property of all events in the bucket. For measurement + // fields, there's a further problem that if the control field is an array, we cannot generate + // the tight predicate because the predicate will be implicitly mapped over the array elements. 
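To make the restriction above concrete: control.min/control.max only summarize values that are actually present, so for a measurement field they cannot prove that every event in the bucket matches. A small self-contained illustration (plain std types, hypothetical data):

    #include <algorithm>
    #include <iostream>
    #include <optional>
    #include <vector>

    int main() {
        // Three measurements; 'a' is missing in the second one. The bucket's control block
        // would record min.a = 4 and max.a = 7, because missing values are ignored on insert.
        std::vector<std::optional<int>> a = {4, std::nullopt, 7};
        const int controlMinA = 4;

        // Tight rewrite of {a: {$gte: 4}}: control.min.a >= 4, which claims "all events match".
        bool tightSaysAllMatch = controlMinA >= 4;

        // Ground truth: the event with a missing 'a' does not match {a: {$gte: 4}}.
        bool allReallyMatch = std::all_of(a.begin(), a.end(), [](const auto& v) {
            return v.has_value() && *v >= 4;
        });

        std::cout << std::boolalpha << tightSaysAllMatch << " vs " << allReallyMatch
                  << '\n';  // true vs false: the tight rewrite is unsound for measurement fields
        return 0;
    }

The time field is exempt because every measurement is required to carry a Date there, which is why the tight rewrite below is restricted to bucketSpec.timeField().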
+ if (matchExprPath != bucketSpec.timeField()) { + return handleIneligible(BucketSpec::IneligiblePredicatePolicy::kIgnore, + matchExpr, + "can't create tight predicate on non-time field") + .tightPredicate; + } + + auto minPath = std::string{kControlMinFieldNamePrefix} + matchExprPath; + const StringData minPathStringData(minPath); + auto maxPath = std::string{kControlMaxFieldNamePrefix} + matchExprPath; + const StringData maxPathStringData(maxPath); + + switch (matchExpr->matchType()) { + // All events satisfy $eq if bucket min and max both satisfy $eq. + case MatchExpression::EQ: + return makePredicate( + MatchExprPredicate<EqualityMatchExpression>(minPathStringData, matchExprData), + MatchExprPredicate<EqualityMatchExpression>(maxPathStringData, matchExprData)); + case MatchExpression::INTERNAL_EXPR_EQ: + return makePredicate( + MatchExprPredicate<InternalExprEqMatchExpression>(minPathStringData, matchExprData), + MatchExprPredicate<InternalExprEqMatchExpression>(maxPathStringData, + matchExprData)); + + // All events satisfy $gt if bucket min satisfies $gt. + case MatchExpression::GT: + return std::make_unique<GTMatchExpression>(minPathStringData, matchExprData); + case MatchExpression::INTERNAL_EXPR_GT: + return std::make_unique<InternalExprGTMatchExpression>(minPathStringData, + matchExprData); + + // All events satisfy $gte if bucket min satisfies $gte. + case MatchExpression::GTE: + return std::make_unique<GTEMatchExpression>(minPathStringData, matchExprData); + case MatchExpression::INTERNAL_EXPR_GTE: + return std::make_unique<InternalExprGTEMatchExpression>(minPathStringData, + matchExprData); + + // All events satisfy $lt if bucket max satisfies $lt. + case MatchExpression::LT: + return std::make_unique<LTMatchExpression>(maxPathStringData, matchExprData); + case MatchExpression::INTERNAL_EXPR_LT: + return std::make_unique<InternalExprLTMatchExpression>(maxPathStringData, + matchExprData); + + // All events satisfy $lte if bucket max satisfies $lte.
+ case MatchExpression::LTE: + return std::make_unique<LTEMatchExpression>(maxPathStringData, matchExprData); + case MatchExpression::INTERNAL_EXPR_LTE: + return std::make_unique<InternalExprLTEMatchExpression>(maxPathStringData, + matchExprData); + + default: + MONGO_UNREACHABLE_TASSERT(7026901); + } +} + +std::unique_ptr<MatchExpression> createTightExprComparisonPredicate( + const ExprMatchExpression* matchExpr, + const BucketSpec& bucketSpec, + ExpressionContext::CollationMatchesDefault collationMatchesDefault, + boost::intrusive_ptr<ExpressionContext> pExpCtx) { + using namespace timeseries; + auto rewriteMatchExpr = RewriteExpr::rewrite(matchExpr->getExpression(), pExpCtx->getCollator()) + .releaseMatchExpression(); + if (rewriteMatchExpr && + ComparisonMatchExpressionBase::isInternalExprComparison(rewriteMatchExpr->matchType())) { + auto compareMatchExpr = + checked_cast<const ComparisonMatchExpressionBase*>(rewriteMatchExpr.get()); + return createTightComparisonPredicate( + compareMatchExpr, bucketSpec, collationMatchesDefault); + } + + return handleIneligible(BucketSpec::IneligiblePredicatePolicy::kIgnore, + matchExpr, + "can't handle non-comparison $expr match expression") + .tightPredicate; +} + } // namespace -std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField( +BucketSpec::BucketPredicate BucketSpec::createPredicatesOnBucketLevelField( const MatchExpression* matchExpr, const BucketSpec& bucketSpec, int bucketMaxSpanSeconds, @@ -516,39 +629,61 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField( if (!includeMetaField) return handleIneligible(policy, matchExpr, "cannot handle an excluded meta field"); - auto result = matchExpr->shallowClone(); + auto looseResult = matchExpr->shallowClone(); expression::applyRenamesToExpression( - result.get(), + looseResult.get(), {{bucketSpec.metaField().value(), timeseries::kBucketMetaFieldName.toString()}}); - return result; + auto tightResult = looseResult->shallowClone(); + return {std::move(looseResult), std::move(tightResult)}; } if (matchExpr->matchType() == MatchExpression::AND) { auto nextAnd = static_cast<const AndMatchExpression*>(matchExpr); - auto andMatchExpr = std::make_unique<AndMatchExpression>(); - + auto looseAndExpression = std::make_unique<AndMatchExpression>(); + auto tightAndExpression = std::make_unique<AndMatchExpression>(); for (size_t i = 0; i < nextAnd->numChildren(); i++) { - if (auto child = createPredicatesOnBucketLevelField(nextAnd->getChild(i), - bucketSpec, - bucketMaxSpanSeconds, - collationMatchesDefault, - pExpCtx, - haveComputedMetaField, - includeMetaField, - assumeNoMixedSchemaData, - policy)) { - andMatchExpr->add(std::move(child)); + auto child = createPredicatesOnBucketLevelField(nextAnd->getChild(i), + bucketSpec, + bucketMaxSpanSeconds, + collationMatchesDefault, + pExpCtx, + haveComputedMetaField, + includeMetaField, + assumeNoMixedSchemaData, + policy); + if (child.loosePredicate) { + looseAndExpression->add(std::move(child.loosePredicate)); + } + + if (tightAndExpression && child.tightPredicate) { + tightAndExpression->add(std::move(child.tightPredicate)); + } else { + // For tight expression, null means always false, we can short circuit here. + tightAndExpression = nullptr; } } - if (andMatchExpr->numChildren() == 1) { - return andMatchExpr->releaseChild(0); + + // For a loose predicate, if we are unable to generate an expression we can just treat it as + // always true or an empty AND. 
This is because we are trying to generate a predicate that + // will match the superset of our actual results. + std::unique_ptr<MatchExpression> looseExpression = nullptr; + if (looseAndExpression->numChildren() == 1) { + looseExpression = looseAndExpression->releaseChild(0); + } else if (looseAndExpression->numChildren() > 1) { + looseExpression = std::move(looseAndExpression); } - if (andMatchExpr->numChildren() > 0) { - return andMatchExpr; + + // For a tight predicate, if we are unable to generate an expression we can just treat it as + // always false. This is because we are trying to generate a predicate that will match the + // subset of our actual results. + std::unique_ptr<MatchExpression> tightExpression = nullptr; + if (tightAndExpression && tightAndExpression->numChildren() == 1) { + tightExpression = tightAndExpression->releaseChild(0); + } else { + tightExpression = std::move(tightAndExpression); } - // No error message here: an empty AND is valid. - return nullptr; + return {std::move(looseExpression), std::move(tightExpression)}; } else if (matchExpr->matchType() == MatchExpression::OR) { // Given {$or: [A, B]}, suppose A, B can be pushed down as A', B'. // If an event matches {$or: [A, B]} then either: @@ -556,9 +691,9 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField( // - it matches B, which means any bucket containing it matches B' // So {$or: [A', B']} will capture all the buckets we need to satisfy {$or: [A, B]}. auto nextOr = static_cast<const OrMatchExpression*>(matchExpr); - auto result = std::make_unique<OrMatchExpression>(); + auto looseOrExpression = std::make_unique<OrMatchExpression>(); + auto tightOrExpression = std::make_unique<OrMatchExpression>(); - bool alwaysTrue = false; for (size_t i = 0; i < nextOr->numChildren(); i++) { auto child = createPredicatesOnBucketLevelField(nextOr->getChild(i), bucketSpec, @@ -569,41 +704,76 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField( includeMetaField, assumeNoMixedSchemaData, policy); - if (child) { - result->add(std::move(child)); + if (looseOrExpression && child.loosePredicate) { + looseOrExpression->add(std::move(child.loosePredicate)); } else { - // Since this argument is always-true, the entire OR is always-true. - alwaysTrue = true; + // For loose expression, null means always true, we can short circuit here. + looseOrExpression = nullptr; + } - // Only short circuit if we're uninterested in reporting errors. - if (policy == IneligiblePredicatePolicy::kIgnore) - break; + // For tight predicate, we give a tighter bound so that all events in the bucket + // either all matches A or all matches B. + if (child.tightPredicate) { + tightOrExpression->add(std::move(child.tightPredicate)); } } - if (alwaysTrue) - return nullptr; - // No special case for an empty OR: returning nullptr would be incorrect because it - // means 'always-true', here. - return result; + // For a loose predicate, if we are unable to generate an expression we can just treat it as + // always true. This is because we are trying to generate a predicate that will match the + // superset of our actual results. 
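The same asymmetry drives the code that follows: a child without a loose part can simply be dropped from the loose $and (dropping a conjunct only widens the superset), while a child without a tight part forces the whole tight $and to be abandoned, since its absence would have to mean "always false". A tiny standalone sketch of that combination rule, with booleans standing in for the rewritten children:

    #include <iostream>
    #include <vector>

    struct ChildRewrite {
        bool hasLoose;  // child produced a loose (superset) predicate
        bool hasTight;  // child produced a tight (subset) predicate
    };

    int main() {
        std::vector<ChildRewrite> children = {{true, true}, {true, false}, {false, true}};

        int looseParts = 0;
        bool tightUsable = true;
        for (const auto& c : children) {
            if (c.hasLoose)
                ++looseParts;                         // keep whatever loose parts exist
            tightUsable = tightUsable && c.hasTight;  // one missing tight part kills the $and
        }
        std::cout << "loose keeps " << looseParts << " of 3 children; tight usable: "
                  << std::boolalpha << tightUsable << '\n';  // loose keeps 2 of 3; tight usable: false
        return 0;
    }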
+ std::unique_ptr<MatchExpression> looseExpression = nullptr; + if (looseOrExpression && looseOrExpression->numChildren() == 1) { + looseExpression = looseOrExpression->releaseChild(0); + } else { + looseExpression = std::move(looseOrExpression); + } + + // For a tight predicate, if we are unable to generate an expression we can just treat it as + // always false or an empty OR. This is because we are trying to generate a predicate that + // will match the subset of our actual results. + std::unique_ptr<MatchExpression> tightExpression = nullptr; + if (tightOrExpression->numChildren() == 1) { + tightExpression = tightOrExpression->releaseChild(0); + } else if (tightOrExpression->numChildren() > 1) { + tightExpression = std::move(tightOrExpression); + } + + return {std::move(looseExpression), std::move(tightExpression)}; } else if (ComparisonMatchExpression::isComparisonMatchExpression(matchExpr) || ComparisonMatchExpressionBase::isInternalExprComparison(matchExpr->matchType())) { - return createComparisonPredicate( - checked_cast<const ComparisonMatchExpressionBase*>(matchExpr), - bucketSpec, - bucketMaxSpanSeconds, - collationMatchesDefault, - pExpCtx, - haveComputedMetaField, - includeMetaField, - assumeNoMixedSchemaData, - policy); + return { + createComparisonPredicate(checked_cast<const ComparisonMatchExpressionBase*>(matchExpr), + bucketSpec, + bucketMaxSpanSeconds, + collationMatchesDefault, + pExpCtx, + haveComputedMetaField, + includeMetaField, + assumeNoMixedSchemaData, + policy), + createTightComparisonPredicate( + checked_cast<const ComparisonMatchExpressionBase*>(matchExpr), + bucketSpec, + collationMatchesDefault)}; + } else if (matchExpr->matchType() == MatchExpression::EXPRESSION) { + return { + // The loose predicate will be pushed before the unpacking stage so that it can be + // inspected by the query planner. Since the classic planner doesn't handle the $expr + // expression, we don't generate the loose predicate.
+ nullptr, + createTightExprComparisonPredicate(checked_cast<const ExprMatchExpression*>(matchExpr), + bucketSpec, + collationMatchesDefault, + pExpCtx)}; } else if (matchExpr->matchType() == MatchExpression::GEO) { auto& geoExpr = static_cast<const GeoMatchExpression*>(matchExpr)->getGeoExpression(); if (geoExpr.getPred() == GeoExpression::WITHIN || geoExpr.getPred() == GeoExpression::INTERSECT) { - return std::make_unique<InternalBucketGeoWithinMatchExpression>( - geoExpr.getGeometryPtr(), geoExpr.getField()); + return {std::make_unique<InternalBucketGeoWithinMatchExpression>( + geoExpr.getGeometryPtr(), geoExpr.getField()), + nullptr}; } } else if (matchExpr->matchType() == MatchExpression::EXISTS) { if (assumeNoMixedSchemaData) { @@ -613,7 +783,7 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField( std::string{timeseries::kControlMinFieldNamePrefix} + matchExpr->path()))); result->add(std::make_unique<ExistsMatchExpression>(StringData( std::string{timeseries::kControlMaxFieldNamePrefix} + matchExpr->path()))); - return result; + return {std::move(result), nullptr}; } else { // At time of writing, we only pass 'kError' when creating a partial index, and // we know the collection will have no mixed-schema buckets by the time the index is @@ -622,7 +792,7 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField( "Can't push down {$exists: true} when the collection may have mixed-schema " "buckets.", policy != IneligiblePredicatePolicy::kError); - return nullptr; + return {}; } } else if (matchExpr->matchType() == MatchExpression::MATCH_IN) { // {a: {$in: [X, Y]}} is equivalent to {$or: [ {a: X}, {a: Y} ]}. @@ -664,11 +834,11 @@ std::unique_ptr<MatchExpression> BucketSpec::createPredicatesOnBucketLevelField( } } if (alwaysTrue) - return nullptr; + return {}; // As above, no special case for an empty IN: returning nullptr would be incorrect because // it means 'always-true', here. - return result; + return {std::move(result), nullptr}; } return handleIneligible(policy, matchExpr, "can't handle this predicate"); } @@ -713,9 +883,9 @@ std::pair<bool, BSONObj> BucketSpec::pushdownPredicate( BucketSpec{ tsOptions.getTimeField().toString(), metaField.map([](StringData s) { return s.toString(); }), - // Since we are operating on a collection, not a query-result, there are no - // inclusion/exclusion projections we need to apply to the buckets before - // unpacking. + // Since we are operating on a collection, not a query-result, + // there are no inclusion/exclusion projections we need to apply + // to the buckets before unpacking. {}, // And there are no computed projections. 
{}, @@ -727,6 +897,7 @@ std::pair<bool, BSONObj> BucketSpec::pushdownPredicate( includeMetaField, assumeNoMixedSchemaData, policy) + .loosePredicate : nullptr; BSONObjBuilder result; @@ -1230,9 +1401,10 @@ Document BucketUnpacker::extractSingleMeasurement(int j) { return measurement.freeze(); } -void BucketUnpacker::reset(BSONObj&& bucket) { +void BucketUnpacker::reset(BSONObj&& bucket, bool bucketMatchedQuery) { _unpackingImpl.reset(); _bucket = std::move(bucket); + _bucketMatchedQuery = bucketMatchedQuery; uassert(5346510, "An empty bucket cannot be unpacked", !_bucket.isEmpty()); auto&& dataRegion = _bucket.getField(timeseries::kBucketDataFieldName).Obj(); diff --git a/src/mongo/db/exec/bucket_unpacker.h b/src/mongo/db/exec/bucket_unpacker.h index 8f7f8210618..3a9813eb87a 100644 --- a/src/mongo/db/exec/bucket_unpacker.h +++ b/src/mongo/db/exec/bucket_unpacker.h @@ -127,14 +127,29 @@ public: kError, }; + struct BucketPredicate { + // A loose predicate is a predicate which returns true when any measures of a bucket + // matches. + std::unique_ptr<MatchExpression> loosePredicate; + + // A tight predicate is a predicate which returns true when all measures of a bucket + // matches. + std::unique_ptr<MatchExpression> tightPredicate; + }; + /** * Takes a predicate after $_internalUnpackBucket on a bucketed field as an argument and - * attempts to map it to a new predicate on the 'control' field. For example, the predicate - * {a: {$gt: 5}} will generate the predicate {control.max.a: {$_internalExprGt: 5}}, which will - * be added before the $_internalUnpackBucket stage. + * attempts to map it to new predicates on the 'control' field. There will be a 'loose' + * predicate that will match if some of the event field matches, also a 'tight' predicate that + * will match if all of the event field matches. For example, the event level predicate {a: + * {$gt: 5}} will generate the loose predicate {control.max.a: {$_internalExprGt: 5}}, and the + * tight predicate {control.min.a: {$_internalExprGt: 5}}. The loose predicate will be added + * before the + * $_internalUnpackBucket stage to filter out buckets with no match. The tight predicate will + * be used to evaluate predicate on bucket level to avoid unnecessary event level evaluation. * - * If the original predicate is on the bucket's timeField we may also create a new predicate - * on the '_id' field to assist in index utilization. For example, the predicate + * If the original predicate is on the bucket's timeField we may also create a new loose + * predicate on the '_id' field to assist in index utilization. For example, the predicate * {time: {$lt: new Date(...)}} will generate the following predicate: * {$and: [ * {_id: {$lt: ObjectId(...)}}, @@ -147,7 +162,7 @@ public: * When using IneligiblePredicatePolicy::kIgnore, if the predicate can't be pushed down, it * returns null. When using IneligiblePredicatePolicy::kError it raises a user error. */ - static std::unique_ptr<MatchExpression> createPredicatesOnBucketLevelField( + static BucketPredicate createPredicatesOnBucketLevelField( const MatchExpression* matchExpr, const BucketSpec& bucketSpec, int bucketMaxSpanSeconds, @@ -269,7 +284,7 @@ public: /** * This resets the unpacker to prepare to unpack a new bucket described by the given document. 
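Putting the two halves of the new BucketPredicate together: the loose half is what pushdownPredicate above extracts and places before the unpack stage, while the tight half is meant to be evaluated against the whole bucket document so that per-event filtering can be skipped via the reset() flag introduced just below. A hypothetical caller (only BucketPredicate, reset() and MatchExpression::matchesBSON() come from the server; the wiring is illustrative):

    #include "mongo/db/exec/bucket_unpacker.h"
    #include "mongo/db/matcher/expression.h"

    // Hypothetical glue code, not taken from the server sources.
    void prepareBucket(mongo::BucketUnpacker& unpacker,
                       const mongo::MatchExpression* tightPredicate,
                       mongo::BSONObj bucket) {
        // If the tight predicate matches the bucket-level document, every unpacked event is
        // already known to satisfy the original $match, so it does not need to be re-checked.
        const bool wholeBucketMatches = tightPredicate && tightPredicate->matchesBSON(bucket);
        unpacker.reset(std::move(bucket), wholeBucketMatches);
    }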
*/ - void reset(BSONObj&& bucket); + void reset(BSONObj&& bucket, bool bucketMatchedQuery = false); Behavior behavior() const { return _unpackerBehavior; @@ -283,6 +298,10 @@ public: return _bucket; } + bool bucketMatchedQuery() const { + return _bucketMatchedQuery; + } + bool includeMetaField() const { return _includeMetaField; } @@ -350,6 +369,9 @@ private: bool _hasNext = false; + // A flag used to mark that the entire bucket matches the following $match predicate. + bool _bucketMatchedQuery = false; + // A flag used to mark that the timestamp value should be materialized in measurements. bool _includeTimeField{false}; diff --git a/src/mongo/db/exec/collection_scan.cpp b/src/mongo/db/exec/collection_scan.cpp index 1bee034cd0c..6cbacb0c997 100644 --- a/src/mongo/db/exec/collection_scan.cpp +++ b/src/mongo/db/exec/collection_scan.cpp @@ -206,8 +206,6 @@ PlanStage::StageState CollectionScan::doWork(WorkingSetID* out) { << "recordId: " << recordIdToSeek); } } - - return PlanStage::NEED_TIME; } if (_lastSeenId.isNull() && _params.direction == CollectionScanParams::FORWARD && diff --git a/src/mongo/db/exec/exclusion_projection_executor.cpp b/src/mongo/db/exec/exclusion_projection_executor.cpp index 9823ed1b125..ad7d7ca9ddb 100644 --- a/src/mongo/db/exec/exclusion_projection_executor.cpp +++ b/src/mongo/db/exec/exclusion_projection_executor.cpp @@ -38,9 +38,10 @@ std::pair<BSONObj, bool> ExclusionNode::extractProjectOnFieldAndRename(const Str BSONObjBuilder extractedExclusion; // Check for a projection directly on 'oldName'. For example, {oldName: 0}. - if (auto it = _projectedFields.find(oldName); it != _projectedFields.end()) { + if (auto it = _projectedFieldsSet.find(oldName); it != _projectedFieldsSet.end()) { extractedExclusion.append(newName, false); - _projectedFields.erase(it); + _projectedFieldsSet.erase(it); + _projectedFields.remove(std::string(oldName)); } // Check for a projection on subfields of 'oldName'. For example, {oldName: {a: 0, b: 0}}. diff --git a/src/mongo/db/exec/inclusion_projection_executor.cpp b/src/mongo/db/exec/inclusion_projection_executor.cpp index 84117710063..30e06a76d82 100644 --- a/src/mongo/db/exec/inclusion_projection_executor.cpp +++ b/src/mongo/db/exec/inclusion_projection_executor.cpp @@ -68,7 +68,7 @@ void FastPathEligibleInclusionNode::_applyProjections(BSONObj bson, BSONObjBuild const auto bsonElement{it.next()}; const auto fieldName{bsonElement.fieldNameStringData()}; - if (_projectedFields.find(fieldName) != _projectedFields.end()) { + if (_projectedFieldsSet.find(fieldName) != _projectedFieldsSet.end()) { bob->append(bsonElement); --nFieldsNeeded; } else if (auto childIt = _children.find(fieldName); childIt != _children.end()) { @@ -169,7 +169,8 @@ std::pair<BSONObj, bool> InclusionNode::extractComputedProjectionsInProject( if (std::get<2>(expressionSpec)) { // Replace the expression with an inclusion projected field. - _projectedFields.insert(fieldName); + auto it = _projectedFields.insert(_projectedFields.end(), fieldName); + _projectedFieldsSet.insert(StringData(*it)); _expressions.erase(fieldName); // Only computed projections at the beginning of the list were marked to become // projected fields. 
The new projected field is at the beginning of the diff --git a/src/mongo/db/exec/projection_node.cpp b/src/mongo/db/exec/projection_node.cpp index 00fcf70946f..bbd2ebcac6b 100644 --- a/src/mongo/db/exec/projection_node.cpp +++ b/src/mongo/db/exec/projection_node.cpp @@ -48,7 +48,8 @@ void ProjectionNode::addProjectionForPath(const FieldPath& path) { void ProjectionNode::_addProjectionForPath(const FieldPath& path) { makeOptimizationsStale(); if (path.getPathLength() == 1) { - _projectedFields.insert(path.fullPath()); + auto it = _projectedFields.insert(_projectedFields.end(), path.fullPath()); + _projectedFieldsSet.insert(StringData(*it)); return; } // FieldPath can't be empty, so it is safe to obtain the first path component here. @@ -141,7 +142,7 @@ void ProjectionNode::applyProjections(const Document& inputDoc, MutableDocument* while (it.more()) { auto fieldName = it.fieldName(); - if (_projectedFields.find(fieldName) != _projectedFields.end()) { + if (_projectedFieldsSet.find(fieldName) != _projectedFieldsSet.end()) { outputProjectedField( fieldName, applyLeafProjectionToValue(it.next().second), outputDoc); ++projectedFields; @@ -280,7 +281,7 @@ void ProjectionNode::serialize(boost::optional<ExplainOptions::Verbosity> explai const bool projVal = !applyLeafProjectionToValue(Value(true)).missing(); // Always put "_id" first if it was projected (implicitly or explicitly). - if (_projectedFields.find("_id") != _projectedFields.end()) { + if (_projectedFieldsSet.find("_id") != _projectedFieldsSet.end()) { output->addField("_id", Value(projVal)); } diff --git a/src/mongo/db/exec/projection_node.h b/src/mongo/db/exec/projection_node.h index 82cae8b145b..61e9b98b89d 100644 --- a/src/mongo/db/exec/projection_node.h +++ b/src/mongo/db/exec/projection_node.h @@ -29,6 +29,8 @@ #pragma once +#include <list> + #include "mongo/db/exec/projection_executor.h" #include "mongo/db/query/projection_policies.h" @@ -176,7 +178,14 @@ protected: StringMap<std::unique_ptr<ProjectionNode>> _children; StringMap<boost::intrusive_ptr<Expression>> _expressions; - StringSet _projectedFields; + + // List of the projected fields in the order in which they were specified. + std::list<std::string> _projectedFields; + + // Set of projected fields. Note that the _projectedFields list actually owns the strings, and + // this StringDataSet simply holds views of those strings. 
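The switch above from a single StringSet to an order-preserving std::list plus a StringDataSet of views works because std::list never reallocates or moves its nodes, so views into the stored strings stay valid as more fields are added. A standalone sketch of the same idiom with plain std types used in place of StringData/StringDataSet:

    #include <iostream>
    #include <list>
    #include <string>
    #include <string_view>
    #include <unordered_set>

    int main() {
        std::list<std::string> fieldsInOrder;           // owns the strings, keeps insertion order
        std::unordered_set<std::string_view> fieldSet;  // views into the list for fast lookup

        for (std::string f : {"_id", "a", "b"}) {
            auto it = fieldsInOrder.insert(fieldsInOrder.end(), std::move(f));
            fieldSet.insert(std::string_view(*it));     // safe: list nodes are never relocated
        }

        std::cout << (fieldSet.count("a") ? "projected" : "not projected") << '\n';  // projected
        return 0;
    }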
+ StringDataSet _projectedFieldsSet; + ProjectionPolicies _policies; std::string _pathToNode; diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript index 05cf5d9a505..7fc2ba50718 100644 --- a/src/mongo/db/exec/sbe/SConscript +++ b/src/mongo/db/exec/sbe/SConscript @@ -240,6 +240,7 @@ env.Library( ], LIBDEPS=[ "$BUILD_DIR/mongo/db/auth/authmocks", + "$BUILD_DIR/mongo/db/query/optimizer/unit_test_utils", '$BUILD_DIR/mongo/db/query/query_test_service_context', '$BUILD_DIR/mongo/db/query_exec', '$BUILD_DIR/mongo/db/service_context_test_fixture', diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp index 67868053dd9..55767ac4b3f 100644 --- a/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp +++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test.cpp @@ -31,12 +31,12 @@ #include "mongo/db/exec/sbe/abt/sbe_abt_test_util.h" #include "mongo/db/pipeline/abt/document_source_visitor.h" #include "mongo/db/query/optimizer/cascades/ce_heuristic.h" -#include "mongo/db/query/optimizer/cascades/cost_derivation.h" #include "mongo/db/query/optimizer/explain.h" #include "mongo/db/query/optimizer/metadata_factory.h" #include "mongo/db/query/optimizer/opt_phase_manager.h" #include "mongo/db/query/optimizer/rewrites/const_eval.h" #include "mongo/db/query/optimizer/rewrites/path_lower.h" +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" #include "mongo/unittest/unittest.h" namespace mongo::optimizer { @@ -379,8 +379,10 @@ TEST_F(NodeSBE, Lower1) { true /*hasRID*/), prefixId); - OptPhaseManager phaseManager( - OptPhaseManager::getAllRewritesSet(), prefixId, {{}}, DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager(OptPhaseManager::getAllRewritesSet(), + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); phaseManager.optimize(tree); auto env = VariableEnvironment::build(tree); @@ -464,15 +466,10 @@ TEST_F(NodeSBE, RequireRID) { true /*hasRID*/), prefixId); - OptPhaseManager phaseManager(OptPhaseManager::getAllRewritesSet(), - prefixId, - true /*requireRID*/, - {{{"test", createScanDef({}, {})}}}, - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManagerRequireRID(OptPhaseManager::getAllRewritesSet(), + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); phaseManager.optimize(tree); auto env = VariableEnvironment::build(tree); diff --git a/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp index 7cc3dcc60ea..3f9c8517878 100644 --- a/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp +++ b/src/mongo/db/exec/sbe/abt/sbe_abt_test_util.cpp @@ -37,6 +37,7 @@ #include "mongo/db/query/cqf_command_utils.h" #include "mongo/db/query/optimizer/explain.h" #include "mongo/db/query/optimizer/opt_phase_manager.h" +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" #include "mongo/db/query/plan_executor.h" #include "mongo/db/query/plan_executor_factory.h" #include "mongo/logv2/log.h" @@ -114,8 +115,9 @@ std::vector<BSONObj> runSBEAST(OperationContext* opCtx, OPTIMIZER_DEBUG_LOG( 6264807, 5, "SBE translated ABT", "explain"_attr = ExplainGenerator::explainV2(tree)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( OptPhaseManager::getAllRewritesSet(), prefixId, {{}}, DebugInfo::kDefaultForTests); + phaseManager.optimize(tree); OPTIMIZER_DEBUG_LOG( diff --git 
a/src/mongo/db/exec/update_stage.cpp b/src/mongo/db/exec/update_stage.cpp index ca7d4fb5ef5..a7ed460eb2f 100644 --- a/src/mongo/db/exec/update_stage.cpp +++ b/src/mongo/db/exec/update_stage.cpp @@ -239,6 +239,7 @@ BSONObj UpdateStage::transformAndUpdate(const Snapshotted<BSONObj>& oldObj, if (!request->explain()) { args.stmtIds = request->getStmtIds(); + args.sampleId = request->getSampleId(); args.update = logObj; if (_isUserInitiatedWrite) { auto scopedCss = CollectionShardingState::assertCollectionLockedAndAcquire( diff --git a/src/mongo/db/index_builds_coordinator.cpp b/src/mongo/db/index_builds_coordinator.cpp index bb4f0239bf1..1c4e7ff8e3c 100644 --- a/src/mongo/db/index_builds_coordinator.cpp +++ b/src/mongo/db/index_builds_coordinator.cpp @@ -2615,7 +2615,7 @@ void IndexBuildsCoordinator::_scanCollectionAndInsertSortedKeysIntoIndex( // impact on user operations. Other steps of the index builds such as the draining phase have // normal priority because index builds are required to eventually catch-up with concurrent // writers. Otherwise we risk never finishing the index build. - SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow); + SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow); { indexBuildsSSS.scanCollection.addAndFetch(1); @@ -2650,7 +2650,7 @@ void IndexBuildsCoordinator::_insertSortedKeysIntoIndexForResume( // impact on user operations. Other steps of the index builds such as the draining phase have // normal priority because index builds are required to eventually catch-up with concurrent // writers. Otherwise we risk never finishing the index build. - SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow); + SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow); { Lock::DBLock autoDb(opCtx, replState->dbName, MODE_IX); const NamespaceStringOrUUID dbAndUUID(replState->dbName, replState->collectionUUID); diff --git a/src/mongo/db/mirror_maestro.cpp b/src/mongo/db/mirror_maestro.cpp index 172b7ce6bda..cf801a8947a 100644 --- a/src/mongo/db/mirror_maestro.cpp +++ b/src/mongo/db/mirror_maestro.cpp @@ -32,6 +32,7 @@ #include "mongo/db/mirror_maestro.h" +#include "mongo/rpc/get_status_from_command_result.h" #include <cmath> #include <cstdlib> #include <utility> @@ -72,9 +73,10 @@ constexpr auto kMirroredReadsParamName = "mirrorReads"_sd; constexpr auto kMirroredReadsSeenKey = "seen"_sd; constexpr auto kMirroredReadsSentKey = "sent"_sd; -constexpr auto kMirroredReadsReceivedKey = "received"_sd; +constexpr auto kMirroredReadsProcessedAsSecondaryKey = "processedAsSecondary"_sd; constexpr auto kMirroredReadsResolvedKey = "resolved"_sd; constexpr auto kMirroredReadsResolvedBreakdownKey = "resolvedBreakdown"_sd; +constexpr auto kMirroredReadsSucceededKey = "succeeded"_sd; constexpr auto kMirroredReadsPendingKey = "pending"_sd; MONGO_FAIL_POINT_DEFINE(mirrorMaestroExpectsResponse); @@ -185,12 +187,13 @@ public: BSONObjBuilder section; section.append(kMirroredReadsSeenKey, seen.loadRelaxed()); section.append(kMirroredReadsSentKey, sent.loadRelaxed()); - section.append(kMirroredReadsReceivedKey, received.loadRelaxed()); + section.append(kMirroredReadsProcessedAsSecondaryKey, processedAsSecondary.loadRelaxed()); if (MONGO_unlikely(mirrorMaestroExpectsResponse.shouldFail())) { // We only can see if the command resolved if we got a response section.append(kMirroredReadsResolvedKey, resolved.loadRelaxed()); section.append(kMirroredReadsResolvedBreakdownKey, 
resolvedBreakdown.toBSON()); + section.append(kMirroredReadsSucceededKey, succeeded.loadRelaxed()); } if (MONGO_unlikely(mirrorMaestroTracksPending.shouldFail())) { section.append(kMirroredReadsPendingKey, pending.loadRelaxed()); @@ -232,13 +235,22 @@ public: ResolvedBreakdownByHost resolvedBreakdown; + // Counts the number of operations (as primary) recognized as "to be mirrored". AtomicWord<CounterT> seen; + // Counts the number of remote requests (for mirroring as primary) sent over the network. AtomicWord<CounterT> sent; + // Counts the number of responses (as primary) from secondaries after mirrored operations. AtomicWord<CounterT> resolved; - // Counts the number of operations that are scheduled to be mirrored, but haven't yet been sent. + // Counts the number of responses (as primary) of successful mirrored operations. Disabled by + // default, hidden behind the mirrorMaestroExpectsResponse fail point. + AtomicWord<CounterT> succeeded; + // Counts the number of operations (as primary) that are scheduled to be mirrored, but + // haven't yet been sent. Disabled by default, hidden behind the mirrorMaestroTracksPending + // fail point. AtomicWord<CounterT> pending; - // Counts the number of mirrored operations received by this node as a secondary. - AtomicWord<CounterT> received; + // Counts the number of mirrored operations processed successfully by this node as a + // secondary. Disabled by default, hidden behind the mirrorMaestroExpectsResponse fail point. + AtomicWord<CounterT> processedAsSecondary; } gMirroredReadsSection; auto parseMirroredReadsParameters(const BSONObj& obj) { @@ -306,7 +318,7 @@ void MirrorMaestro::tryMirrorRequest(OperationContext* opCtx) noexcept { void MirrorMaestro::onReceiveMirroredRead(OperationContext* opCtx) noexcept { const auto& invocation = CommandInvocation::get(opCtx); if (MONGO_unlikely(invocation->isMirrored())) { - gMirroredReadsSection.received.fetchAndAddRelaxed(1); + gMirroredReadsSection.processedAsSecondary.fetchAndAddRelaxed(1); } } @@ -418,6 +430,11 @@ void MirrorMaestroImpl::_mirror(const std::vector<HostAndPort>& hosts, // Count both failed and successful reads as resolved gMirroredReadsSection.resolved.fetchAndAdd(1); gMirroredReadsSection.resolvedBreakdown.onResponseReceived(host); + + if (getStatusFromCommandResult(args.response.data).isOK()) { + gMirroredReadsSection.succeeded.fetchAndAdd(1); + } + LOGV2_DEBUG( 31457, 4, "Response received", "host"_attr = host, "response"_attr = args.response); diff --git a/src/mongo/db/mongod_main.cpp b/src/mongo/db/mongod_main.cpp index 65b6d06f112..127f694b9d5 100644 --- a/src/mongo/db/mongod_main.cpp +++ b/src/mongo/db/mongod_main.cpp @@ -142,6 +142,7 @@ #include "mongo/db/s/op_observer_sharding_impl.h" #include "mongo/db/s/periodic_sharded_index_consistency_checker.h" #include "mongo/db/s/query_analysis_op_observer.h" +#include "mongo/db/s/query_analysis_writer.h" #include "mongo/db/s/rename_collection_participant_service.h" #include "mongo/db/s/resharding/resharding_coordinator_service.h" #include "mongo/db/s/resharding/resharding_donor_service.h" @@ -856,6 +857,10 @@ ExitCode _initAndListen(ServiceContext* serviceContext, int listenPort) { auto catalog = std::make_unique<StatsCatalog>(serviceContext, std::move(cacheLoader)); StatsCatalog::set(serviceContext, std::move(catalog)); + if (analyze_shard_key::supportsPersistingSampledQueriesIgnoreFCV()) { + analyze_shard_key::QueryAnalysisWriter::get(serviceContext).onStartup(); + } + // MessageServer::run will return when exit code closes its socket 
and we don't need the // operation context anymore startupOpCtx.reset(); @@ -1174,6 +1179,8 @@ void setUpObservers(ServiceContext* serviceContext) { std::make_unique<OplogWriterTransactionProxy>(std::make_unique<OplogWriterImpl>()))); opObserverRegistry->addObserver(std::make_unique<ShardServerOpObserver>()); opObserverRegistry->addObserver(std::make_unique<ReshardingOpObserver>()); + opObserverRegistry->addObserver( + std::make_unique<analyze_shard_key::QueryAnalysisOpObserver>()); opObserverRegistry->addObserver(std::make_unique<repl::TenantMigrationDonorOpObserver>()); opObserverRegistry->addObserver( std::make_unique<repl::TenantMigrationRecipientOpObserver>()); @@ -1314,6 +1321,11 @@ void shutdownTask(const ShutdownTaskArgs& shutdownArgs) { lsc->joinOnShutDown(); } + if (analyze_shard_key::supportsPersistingSampledQueriesIgnoreFCV()) { + LOGV2(7047303, "Shutting down the QueryAnalysisWriter"); + analyze_shard_key::QueryAnalysisWriter::get(serviceContext).onShutdown(); + } + // Shutdown the TransportLayer so that new connections aren't accepted if (auto tl = serviceContext->getTransportLayer()) { LOGV2_OPTIONS( diff --git a/src/mongo/db/namespace_string.cpp b/src/mongo/db/namespace_string.cpp index a45f99552f5..dab24a3c482 100644 --- a/src/mongo/db/namespace_string.cpp +++ b/src/mongo/db/namespace_string.cpp @@ -195,6 +195,12 @@ const NamespaceString NamespaceString::kGlobalIndexClonerNamespace( const NamespaceString NamespaceString::kConfigQueryAnalyzersNamespace(NamespaceString::kConfigDb, "queryAnalyzers"); +const NamespaceString NamespaceString::kConfigSampledQueriesNamespace(NamespaceString::kConfigDb, + "sampledQueries"); + +const NamespaceString NamespaceString::kConfigSampledQueriesDiffNamespace( + NamespaceString::kConfigDb, "sampledQueriesDiff"); + NamespaceString NamespaceString::parseFromStringExpectTenantIdInMultitenancyMode(StringData ns) { if (!gMultitenancySupport) { return NamespaceString(boost::none, ns); @@ -495,12 +501,12 @@ bool NamespaceString::isSystemStatsCollection() const { } NamespaceString NamespaceString::makeTimeseriesBucketsNamespace() const { - return {db(), kTimeseriesBucketsCollectionPrefix.toString() + coll()}; + return {dbName(), kTimeseriesBucketsCollectionPrefix.toString() + coll()}; } NamespaceString NamespaceString::getTimeseriesViewNamespace() const { invariant(isTimeseriesBucketsCollection(), ns()); - return {db(), coll().substr(kTimeseriesBucketsCollectionPrefix.size())}; + return {dbName(), coll().substr(kTimeseriesBucketsCollectionPrefix.size())}; } bool NamespaceString::isImplicitlyReplicated() const { diff --git a/src/mongo/db/namespace_string.h b/src/mongo/db/namespace_string.h index 904bac9f99d..3c746ad0ddf 100644 --- a/src/mongo/db/namespace_string.h +++ b/src/mongo/db/namespace_string.h @@ -273,6 +273,12 @@ public: // Namespace used for storing query analyzer settings. static const NamespaceString kConfigQueryAnalyzersNamespace; + // Namespace used for storing sampled queries. + static const NamespaceString kConfigSampledQueriesNamespace; + + // Namespace used for storing the diffs for sampled update queries. + static const NamespaceString kConfigSampledQueriesDiffNamespace; + /** * Constructs an empty NamespaceString. 
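The QueryAnalysisWriter registered and torn down above is consumed fire-and-forget by the write paths later in this diff: they call addUpdateQuery()/addDeleteQuery() and discard the resulting future with .getAsync([](auto) {}), so a sampling failure can never fail the user's write. A minimal standalone sketch of that consumption pattern using the server's future utilities (hypothetical work; the writer methods' exact return type is not shown in this diff):

    #include "mongo/util/future.h"

    void fireAndForgetDemo() {
        auto pf = mongo::makePromiseFuture<void>();

        // Attach a no-op continuation and deliberately ignore the outcome, the same shape as
        // the sampling calls in performUpdates()/performDeletes() below.
        std::move(pf.future).getAsync([](mongo::Status) {});

        pf.promise.emplaceValue();  // completing the promise runs the (no-op) continuation
    }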
*/ diff --git a/src/mongo/db/operation_context.cpp b/src/mongo/db/operation_context.cpp index c17d9c2d7b3..24c92e19f6b 100644 --- a/src/mongo/db/operation_context.cpp +++ b/src/mongo/db/operation_context.cpp @@ -80,7 +80,7 @@ OperationContext::OperationContext(Client* client, OperationIdSlot&& opIdSlot) : _client(client), _opId(std::move(opIdSlot)), _elapsedTime(client ? client->getServiceContext()->getTickSource() - : SystemTickSource::get()) {} + : globalSystemTickSource()) {} OperationContext::~OperationContext() { releaseOperationKey(); diff --git a/src/mongo/db/ops/SConscript b/src/mongo/db/ops/SConscript index 1c0819eec70..5f3d894c3db 100644 --- a/src/mongo/db/ops/SConscript +++ b/src/mongo/db/ops/SConscript @@ -48,6 +48,7 @@ env.Library( '$BUILD_DIR/mongo/db/record_id_helpers', '$BUILD_DIR/mongo/db/repl/oplog', '$BUILD_DIR/mongo/db/repl/repl_coordinator_interface', + '$BUILD_DIR/mongo/db/s/query_analysis_writer', '$BUILD_DIR/mongo/db/shard_role', '$BUILD_DIR/mongo/db/stats/counters', '$BUILD_DIR/mongo/db/stats/server_read_concern_write_concern_metrics', diff --git a/src/mongo/db/ops/update_request.h b/src/mongo/db/ops/update_request.h index ed775c9e15e..e635ce7f156 100644 --- a/src/mongo/db/ops/update_request.h +++ b/src/mongo/db/ops/update_request.h @@ -76,7 +76,7 @@ public: }; UpdateRequest(const write_ops::UpdateOpEntry& updateOp = write_ops::UpdateOpEntry()) - : _updateOp(updateOp) {} + : _updateOp(updateOp), _sampleId(updateOp.getSampleId()) {} void setNamespaceString(const NamespaceString& nsString) { _nsString = nsString; @@ -257,6 +257,14 @@ public: return _stmtIds; } + void setSampleId(boost::optional<UUID> sampleId) { + _sampleId = sampleId; + } + + const boost::optional<UUID>& getSampleId() const { + return _sampleId; + } + std::string toString() const { StringBuilder builder; builder << " query: " << getQuery(); @@ -314,6 +322,9 @@ private: // The statement ids of this request. std::vector<StmtId> _stmtIds = {kUninitializedStmtId}; + // The unique sample id for this request if it has been chosen for sampling. + boost::optional<UUID> _sampleId; + // Flags controlling the update. // God bypasses _id checking and index generation. It is only used on behalf of system diff --git a/src/mongo/db/ops/write_ops.idl b/src/mongo/db/ops/write_ops.idl index a0567a957f8..be734dae3a6 100644 --- a/src/mongo/db/ops/write_ops.idl +++ b/src/mongo/db/ops/write_ops.idl @@ -259,6 +259,11 @@ structs: type: object optional: true stability: stable + sampleId: + description: "The unique sample id for the operation if it has been chosen for sampling." + type: uuid + optional: true + stability: unstable DeleteOpEntry: description: "Parser for the entries in the 'deletes' array of a delete command." @@ -285,6 +290,11 @@ structs: type: object optional: true stability: stable + sampleId: + description: "The unique sample id for the operation if it has been chosen for sampling." + type: uuid + optional: true + stability: unstable FindAndModifyLastError: description: "Contains execution details for the findAndModify command" @@ -543,3 +553,8 @@ commands: description: "Indicates whether the operation is a mirrored read" type: optionalBool stability: unstable + sampleId: + description: "The unique sample id for the operation if it has been chosen for sampling." 
+ type: uuid + optional: true + stability: unstable diff --git a/src/mongo/db/ops/write_ops_exec.cpp b/src/mongo/db/ops/write_ops_exec.cpp index f3a37117b00..e98a5931dd8 100644 --- a/src/mongo/db/ops/write_ops_exec.cpp +++ b/src/mongo/db/ops/write_ops_exec.cpp @@ -67,6 +67,7 @@ #include "mongo/db/repl/tenant_migration_decoration.h" #include "mongo/db/s/collection_sharding_state.h" #include "mongo/db/s/operation_sharding_state.h" +#include "mongo/db/s/query_analysis_writer.h" #include "mongo/db/stats/counters.h" #include "mongo/db/stats/server_write_concern_metrics.h" #include "mongo/db/stats/top.h" @@ -677,7 +678,7 @@ WriteResult performInserts(OperationContext* opCtx, bool containsRetry = false; ON_BLOCK_EXIT([&] { updateRetryStats(opCtx, containsRetry); }); - size_t stmtIdIndex = 0; + size_t nextOpIndex = 0; size_t bytesInBatch = 0; std::vector<InsertStatement> batch; const size_t maxBatchSize = internalInsertMaxBatchSize.load(); @@ -685,10 +686,11 @@ WriteResult performInserts(OperationContext* opCtx, batch.reserve(std::min(wholeOp.getDocuments().size(), maxBatchSize)); for (auto&& doc : wholeOp.getDocuments()) { + const auto currentOpIndex = nextOpIndex++; const bool isLastDoc = (&doc == &wholeOp.getDocuments().back()); bool containsDotsAndDollarsField = false; auto fixedDoc = fixDocumentForInsert(opCtx, doc, &containsDotsAndDollarsField); - const StmtId stmtId = getStmtIdForWriteOp(opCtx, wholeOp, stmtIdIndex++); + const StmtId stmtId = getStmtIdForWriteOp(opCtx, wholeOp, currentOpIndex); const bool wasAlreadyExecuted = opCtx->isRetryableWrite() && txnParticipant.checkStatementExecutedNoOplogEntryFetch(opCtx, stmtId); @@ -1041,7 +1043,7 @@ WriteResult performUpdates(OperationContext* opCtx, bool containsRetry = false; ON_BLOCK_EXIT([&] { updateRetryStats(opCtx, containsRetry); }); - size_t stmtIdIndex = 0; + size_t nextOpIndex = 0; WriteResult out; out.results.reserve(wholeOp.getUpdates().size()); @@ -1054,7 +1056,8 @@ WriteResult performUpdates(OperationContext* opCtx, // updates. 
bool forgoOpCounterIncrements = false; for (auto&& singleOp : wholeOp.getUpdates()) { - const auto stmtId = getStmtIdForWriteOp(opCtx, wholeOp, stmtIdIndex++); + const auto currentOpIndex = nextOpIndex++; + const auto stmtId = getStmtIdForWriteOp(opCtx, wholeOp, currentOpIndex); if (opCtx->isRetryableWrite()) { if (auto entry = txnParticipant.checkStatementExecuted(opCtx, stmtId)) { containsRetry = true; @@ -1081,6 +1084,13 @@ WriteResult performUpdates(OperationContext* opCtx, finishCurOp(opCtx, &*curOp); } }); + + if (analyze_shard_key::supportsPersistingSampledQueries() && singleOp.getSampleId()) { + analyze_shard_key::QueryAnalysisWriter::get(opCtx) + .addUpdateQuery(wholeOp, currentOpIndex) + .getAsync([](auto) {}); + } + try { lastOpFixer.startingOp(); @@ -1276,7 +1286,7 @@ WriteResult performDeletes(OperationContext* opCtx, bool containsRetry = false; ON_BLOCK_EXIT([&] { updateRetryStats(opCtx, containsRetry); }); - size_t stmtIdIndex = 0; + size_t nextOpIndex = 0; WriteResult out; out.results.reserve(wholeOp.getDeletes().size()); @@ -1286,7 +1296,8 @@ WriteResult performDeletes(OperationContext* opCtx, wholeOp.getLegacyRuntimeConstants().value_or(Variables::generateRuntimeConstants(opCtx)); for (auto&& singleOp : wholeOp.getDeletes()) { - const auto stmtId = getStmtIdForWriteOp(opCtx, wholeOp, stmtIdIndex++); + const auto currentOpIndex = nextOpIndex++; + const auto stmtId = getStmtIdForWriteOp(opCtx, wholeOp, currentOpIndex); if (opCtx->isRetryableWrite() && txnParticipant.checkStatementExecutedNoOplogEntryFetch(opCtx, stmtId)) { containsRetry = true; @@ -1316,6 +1327,13 @@ WriteResult performDeletes(OperationContext* opCtx, &hangBeforeChildRemoveOpIsPopped, opCtx, "hangBeforeChildRemoveOpIsPopped"); } }); + + if (analyze_shard_key::supportsPersistingSampledQueries() && singleOp.getSampleId()) { + analyze_shard_key::QueryAnalysisWriter::get(opCtx) + .addDeleteQuery(wholeOp, currentOpIndex) + .getAsync([](auto) {}); + } + try { lastOpFixer.startingOp(); out.results.push_back(performSingleDeleteOp(opCtx, diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index 2347bd26c10..b8b17ac1ab9 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -503,6 +503,22 @@ env.Library( LIBDEPS_PRIVATE=[], ) +env.Benchmark( + target='abt_translation_bm', + source=[ + 'abt/abt_translate_bm_fixture.cpp', + 'abt/abt_translate_cq_bm.cpp', + 'abt/abt_translate_pipeline_bm.cpp', + ], + LIBDEPS=[ + '$BUILD_DIR/mongo/db/pipeline/pipeline', + '$BUILD_DIR/mongo/db/query/canonical_query', + '$BUILD_DIR/mongo/db/query/query_test_service_context', + '$BUILD_DIR/mongo/unittest/unittest', + '$BUILD_DIR/mongo/util/processinfo', + ], +) + env.CppUnitTest( target='db_pipeline_test', source=[ diff --git a/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.cpp b/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.cpp new file mode 100644 index 00000000000..6b00208dbbd --- /dev/null +++ b/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.cpp @@ -0,0 +1,211 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. 
+ * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/pipeline/abt/abt_translate_bm_fixture.h" + +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/db/json.h" + +namespace mongo { +namespace { +BSONArray buildArray(int size) { + BSONArrayBuilder builder; + for (int i = 0; i < size; i++) { + builder.append(i); + } + return builder.arr(); +} + +std::string getField(int index) { + static constexpr StringData kViableChars = "abcdefghijklmnopqrstuvwxyz"_sd; + invariant(size_t(index) < kViableChars.size()); + return std::string(1, kViableChars[index]); +} + +BSONObj buildSimpleBSONSpec(int nFields, bool isMatch, bool isExclusion = false) { + BSONObjBuilder spec; + for (auto i = 0; i < nFields; i++) { + int val = isMatch ? i : (isExclusion ? 0 : 1); + spec.append(getField(i), val); + } + return spec.obj(); +} + +/** + * Builds a filter BSON with 'nFields' simple equality predicates. + */ +BSONObj buildSimpleMatchSpec(int nFields) { + return buildSimpleBSONSpec(nFields, true /*isMatch*/); +} + +/** + * Builds a projection BSON with 'nFields' simple inclusions or exclusions, depending on the + * 'isExclusion' parameter. + */ +BSONObj buildSimpleProjectSpec(int nFields, bool isExclusion) { + return buildSimpleBSONSpec(nFields, false /*isMatch*/, isExclusion); +} + +BSONObj buildNestedBSONSpec(int depth, bool isExclusion, int offset) { + std::string field; + for (auto i = 0; i < depth - 1; i++) { + field += getField(offset + i) += "."; + } + field += getField(offset + depth); + + return BSON(field << (isExclusion ? 0 : 1)); +} + +/** + * Builds a BSON representing a predicate on one dotted path, where the field has depth 'depth'. + */ +BSONObj buildNestedMatchSpec(int depth, int offset = 0) { + return buildNestedBSONSpec(depth, false /*isExclusion*/, offset); +} + +/** + * Builds a BSON representing a projection on one dotted path, where the field has depth 'depth'. 
+ */ +BSONObj buildNestedProjectSpec(int depth, bool isExclusion, int offset = 0) { + return buildNestedBSONSpec(depth, isExclusion, offset); +} +} // namespace + +void ABTTranslateBenchmarkFixture::benchmarkMatch(benchmark::State& state) { + auto match = buildSimpleMatchSpec(1); + benchmarkABTTranslate(state, match, BSONObj()); +} +void ABTTranslateBenchmarkFixture::benchmarkMatchTwoFields(benchmark::State& state) { + auto match = buildSimpleMatchSpec(2); + benchmarkABTTranslate(state, match, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkMatchTwentyFields(benchmark::State& state) { + auto match = buildSimpleMatchSpec(20); + benchmarkABTTranslate(state, match, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkMatchDepthTwo(benchmark::State& state) { + auto match = buildNestedMatchSpec(2); + benchmarkABTTranslate(state, match, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkMatchDepthTwenty(benchmark::State& state) { + auto match = buildNestedMatchSpec(20); + benchmarkABTTranslate(state, match, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkMatchGtLt(benchmark::State& state) { + auto match = fromjson("{a: {$gt: -12, $lt: 5}}"); + benchmarkABTTranslate(state, match, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkMatchIn(benchmark::State& state) { + auto match = BSON("a" << BSON("$in" << buildArray(10))); + benchmarkABTTranslate(state, match, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkMatchInLarge(benchmark::State& state) { + auto match = BSON("a" << BSON("$in" << buildArray(1000))); + benchmarkABTTranslate(state, match, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkMatchElemMatch(benchmark::State& state) { + auto match = fromjson("{a: {$elemMatch: {b: {$eq: 2}, c: {$lt: 3}}}}"); + benchmarkABTTranslate(state, match, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkMatchComplex(benchmark::State& state) { + auto match = fromjson( + "{$and: [" + "{'a.b': {$not: {$eq: 2}}}," + "{'b.c': {$lte: {$eq: 'str'}}}," + "{$or: [{'c.d' : {$eq: 3}}, {'d.e': {$eq: 4}}]}," + "{$or: [" + "{'e.f': {$gt: 4}}," + "{$and: [" + "{'f.g': {$not: {$eq: 1}}}," + "{'g.h': {$eq: 3}}" + "]}" + "]}" + "]}}"); + benchmarkABTTranslate(state, match, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkProjectExclude(benchmark::State& state) { + auto project = buildSimpleProjectSpec(1, true /*isExclusion*/); + benchmarkABTTranslate(state, project, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkProjectInclude(benchmark::State& state) { + auto project = buildSimpleProjectSpec(1, false /*isExclusion*/); + benchmarkABTTranslate(state, project, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkProjectIncludeTwoFields(benchmark::State& state) { + auto project = buildSimpleProjectSpec(2, false /*isExclusion*/); + benchmarkABTTranslate(state, project, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkProjectIncludeTwentyFields(benchmark::State& state) { + auto project = buildSimpleProjectSpec(20, false /*isExclusion*/); + benchmarkABTTranslate(state, project, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkProjectIncludeDepthTwo(benchmark::State& state) { + auto project = buildNestedProjectSpec(2, false /*isExclusion*/); + benchmarkABTTranslate(state, project, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkProjectIncludeDepthTwenty(benchmark::State& state) { + auto project = buildNestedProjectSpec(20, false /*isExclusion*/); + 
benchmarkABTTranslate(state, project, BSONObj()); +} + +void ABTTranslateBenchmarkFixture::benchmarkTwoStages(benchmark::State& state) { + // Builds a match on a nested field and then excludes that nested field. + std::vector<BSONObj> pipeline; + pipeline.push_back(BSON("$match" << buildNestedMatchSpec(3))); + pipeline.push_back(BSON("$project" << buildNestedProjectSpec(3, true /*isExclusion*/))); + benchmarkABTTranslate(state, pipeline); +} + +void ABTTranslateBenchmarkFixture::benchmarkTwentyStages(benchmark::State& state) { + // Builds a sequence of alternating $match and $project stages which match on a nested field and + // then exclude that field. + std::vector<BSONObj> pipeline; + for (int i = 0; i < 10; i++) { + pipeline.push_back(BSON("$match" << buildNestedMatchSpec(3, i))); + pipeline.push_back(BSON("$project" << buildNestedProjectSpec(3, true /*exclusion*/, i))); + } + benchmarkABTTranslate(state, pipeline); +} + + +} // namespace mongo diff --git a/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.h b/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.h new file mode 100644 index 00000000000..0ab40ed52fe --- /dev/null +++ b/src/mongo/db/pipeline/abt/abt_translate_bm_fixture.h @@ -0,0 +1,137 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#pragma once + +#include "mongo/platform/basic.h" + +#include <benchmark/benchmark.h> + +#include "mongo/bson/bsonobj.h" + +namespace mongo { + +class ABTTranslateBenchmarkFixture : public benchmark::Fixture { +public: + virtual void benchmarkABTTranslate(benchmark::State& state, + BSONObj matchSpec, + BSONObj projectSpec) = 0; + virtual void benchmarkABTTranslate(benchmark::State& state, + const std::vector<BSONObj>& pipeline) = 0; + + void benchmarkMatch(benchmark::State& state); + void benchmarkMatchTwoFields(benchmark::State& state); + void benchmarkMatchTwentyFields(benchmark::State& state); + + void benchmarkMatchDepthTwo(benchmark::State& state); + void benchmarkMatchDepthTwenty(benchmark::State& state); + + void benchmarkMatchGtLt(benchmark::State& state); + void benchmarkMatchIn(benchmark::State& state); + void benchmarkMatchInLarge(benchmark::State& state); + void benchmarkMatchElemMatch(benchmark::State& state); + void benchmarkMatchComplex(benchmark::State& state); + + void benchmarkProjectExclude(benchmark::State& state); + void benchmarkProjectInclude(benchmark::State& state); + void benchmarkProjectIncludeTwoFields(benchmark::State& state); + void benchmarkProjectIncludeTwentyFields(benchmark::State& state); + + void benchmarkProjectIncludeDepthTwo(benchmark::State& state); + void benchmarkProjectIncludeDepthTwenty(benchmark::State& state); + + void benchmarkTwoStages(benchmark::State& state); + void benchmarkTwentyStages(benchmark::State& state); +}; + +// These benchmarks cover some simple queries which are currently CQF-eligible. As more support +// is added to CQF, more benchmarks may be added here as needed. +#define BENCHMARK_MQL_TRANSLATION(Fixture) \ + \ + BENCHMARK_F(Fixture, Match)(benchmark::State & state) { \ + benchmarkMatch(state); \ + } \ + BENCHMARK_F(Fixture, MatchTwoFields)(benchmark::State & state) { \ + benchmarkMatchTwoFields(state); \ + } \ + BENCHMARK_F(Fixture, MatchTwentyFields)(benchmark::State & state) { \ + benchmarkMatchTwentyFields(state); \ + } \ + BENCHMARK_F(Fixture, MatchDepthTwo)(benchmark::State & state) { \ + benchmarkMatchDepthTwo(state); \ + } \ + BENCHMARK_F(Fixture, MatchDepthTwenty)(benchmark::State & state) { \ + benchmarkMatchDepthTwenty(state); \ + } \ + BENCHMARK_F(Fixture, MatchGtLt)(benchmark::State & state) { \ + benchmarkMatchGtLt(state); \ + } \ + BENCHMARK_F(Fixture, MatchIn)(benchmark::State & state) { \ + benchmarkMatchIn(state); \ + } \ + BENCHMARK_F(Fixture, MatchInLarge)(benchmark::State & state) { \ + benchmarkMatchInLarge(state); \ + } \ + BENCHMARK_F(Fixture, MatchElemMatch)(benchmark::State & state) { \ + benchmarkMatchElemMatch(state); \ + } \ + BENCHMARK_F(Fixture, MatchComplex)(benchmark::State & state) { \ + benchmarkMatchComplex(state); \ + } \ + BENCHMARK_F(Fixture, ProjectExclude)(benchmark::State & state) { \ + benchmarkProjectExclude(state); \ + } \ + BENCHMARK_F(Fixture, ProjectInclude)(benchmark::State & state) { \ + benchmarkProjectInclude(state); \ + } \ + BENCHMARK_F(Fixture, ProjectIncludeTwoFields)(benchmark::State & state) { \ + benchmarkProjectIncludeTwoFields(state); \ + } \ + BENCHMARK_F(Fixture, ProjectIncludeTwentyFields)(benchmark::State & state) { \ + benchmarkProjectIncludeTwentyFields(state); \ + } \ + BENCHMARK_F(Fixture, ProjectIncludeDepthTwo)(benchmark::State & state) { \ + benchmarkProjectIncludeDepthTwo(state); \ + } \ + BENCHMARK_F(Fixture, ProjectIncludeDepthTwenty)(benchmark::State & state) { \ + benchmarkProjectIncludeDepthTwenty(state); \ + } + +// Queries which are 
expressed as pipelines should be added here because they cannot go through +// find translation. +#define BENCHMARK_MQL_PIPELINE_TRANSLATION(Fixture) \ + \ + BENCHMARK_F(Fixture, TwoStages)(benchmark::State & state) { \ + benchmarkTwoStages(state); \ + } \ + BENCHMARK_F(Fixture, TwentyStages)(benchmark::State & state) { \ + benchmarkTwentyStages(state); \ + } + +} // namespace mongo diff --git a/src/mongo/db/pipeline/abt/abt_translate_cq_bm.cpp b/src/mongo/db/pipeline/abt/abt_translate_cq_bm.cpp new file mode 100644 index 00000000000..8603ea2d7c0 --- /dev/null +++ b/src/mongo/db/pipeline/abt/abt_translate_cq_bm.cpp @@ -0,0 +1,88 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include <benchmark/benchmark.h> + +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/db/pipeline/abt/abt_translate_bm_fixture.h" +#include "mongo/db/pipeline/abt/canonical_query_translation.h" +#include "mongo/db/query/canonical_query.h" +#include "mongo/db/query/query_test_service_context.h" + +namespace mongo::optimizer { +namespace { +/** + * Benchmarks translation from CanonicalQuery to ABT. 
+ */ +class CanonicalQueryABTTranslate : public ABTTranslateBenchmarkFixture { +public: + CanonicalQueryABTTranslate() {} + + void benchmarkABTTranslate(benchmark::State& state, + const std::vector<BSONObj>& pipeline) override final { + state.SkipWithError("Find translation fixture cannot translate a pipeline"); + return; + } + + void benchmarkABTTranslate(benchmark::State& state, + BSONObj matchSpec, + BSONObj projectSpec) override final { + QueryTestServiceContext testServiceContext; + auto opCtx = testServiceContext.makeOperationContext(); + auto nss = NamespaceString("test.bm"); + + Metadata metadata{{}}; + PrefixId prefixId; + std::string scanProjName = prefixId.getNextId("scan"); + + auto findCommand = std::make_unique<FindCommandRequest>(nss); + findCommand->setFilter(matchSpec); + findCommand->setProjection(projectSpec); + auto cq = CanonicalQuery::canonicalize(opCtx.get(), std::move(findCommand)); + if (!cq.isOK()) { + state.SkipWithError("Canonical query could not be created"); + return; + } + + // This is where recording starts. + for (auto keepRunning : state) { + benchmark::DoNotOptimize( + translateCanonicalQueryToABT(metadata, + *cq.getValue(), + scanProjName, + make<ScanNode>(scanProjName, "collection"), + prefixId)); + benchmark::ClobberMemory(); + } + } +}; + +BENCHMARK_MQL_TRANSLATION(CanonicalQueryABTTranslate) +} // namespace +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/abt_translate_pipeline_bm.cpp b/src/mongo/db/pipeline/abt/abt_translate_pipeline_bm.cpp new file mode 100644 index 00000000000..f878c494c13 --- /dev/null +++ b/src/mongo/db/pipeline/abt/abt_translate_pipeline_bm.cpp @@ -0,0 +1,90 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file.
+ */ + +#include <benchmark/benchmark.h> + +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/db/pipeline/abt/abt_translate_bm_fixture.h" +#include "mongo/db/pipeline/abt/document_source_visitor.h" +#include "mongo/db/pipeline/expression_context_for_test.h" +#include "mongo/db/query/query_test_service_context.h" + +namespace mongo::optimizer { +namespace { +/** + * Benchmarks translation from optimized Pipeline to ABT. + */ +class PipelineABTTranslateBenchmark : public ABTTranslateBenchmarkFixture { +public: + PipelineABTTranslateBenchmark() {} + + void benchmarkABTTranslate(benchmark::State& state, + BSONObj matchSpec, + BSONObj projectSpec) override final { + std::vector<BSONObj> pipeline; + if (!matchSpec.isEmpty()) { + pipeline.push_back(BSON("$match" << matchSpec)); + } + if (!projectSpec.isEmpty()) { + pipeline.push_back(BSON("$project" << projectSpec)); + } + benchmarkABTTranslate(state, pipeline); + } + + void benchmarkABTTranslate(benchmark::State& state, + const std::vector<BSONObj>& pipeline) override final { + QueryTestServiceContext testServiceContext; + auto opCtx = testServiceContext.makeOperationContext(); + auto expCtx = new ExpressionContextForTest(opCtx.get(), NamespaceString("test.bm")); + + Metadata metadata{{}}; + PrefixId prefixId; + std::string scanProjName = prefixId.getNextId("scan"); + + std::unique_ptr<Pipeline, PipelineDeleter> parsedPipeline = + Pipeline::parse(pipeline, expCtx); + parsedPipeline->optimizePipeline(); + + // This is where recording starts. + for (auto keepRunning : state) { + benchmark::DoNotOptimize( + translatePipelineToABT(metadata, + *parsedPipeline, + scanProjName, + make<ScanNode>(scanProjName, "collection"), + prefixId)); + benchmark::ClobberMemory(); + } + } +}; + +BENCHMARK_MQL_TRANSLATION(PipelineABTTranslateBenchmark) +BENCHMARK_MQL_PIPELINE_TRANSLATION(PipelineABTTranslateBenchmark) +} // namespace +} // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/expr_algebrizer_context.h b/src/mongo/db/pipeline/abt/expr_algebrizer_context.h index 7cbf740c2e2..7c852281078 100644 --- a/src/mongo/db/pipeline/abt/expr_algebrizer_context.h +++ b/src/mongo/db/pipeline/abt/expr_algebrizer_context.h @@ -31,6 +31,7 @@ #include <stack> +#include "mongo/db/matcher/expression_path.h" #include "mongo/db/query/optimizer/node.h" #include "mongo/db/query/optimizer/utils/utils.h" @@ -73,17 +74,37 @@ public: */ std::string getNextId(const std::string& key); - void enterElemMatch() { - _elemMatchCount++; + void enterElemMatch(const MatchExpression::MatchType matchType) { + _elemMatchStack.push_back(matchType); } void exitElemMatch() { tassert(6809501, "Attempting to exit out of elemMatch that was not entered", inElemMatch()); - _elemMatchCount--; + _elemMatchStack.pop_back(); } bool inElemMatch() { - return _elemMatchCount > 0; + return !_elemMatchStack.empty(); + } + + /** + * Returns whether the current $elemMatch should consider its path for translation. This + * function assumes that 'enterElemMatch' has been called before visiting the current + * expression. + */ + bool shouldGeneratePathForElemMatch() const { + return _elemMatchStack.size() == 1 || + _elemMatchStack[_elemMatchStack.size() - 2] == + MatchExpression::MatchType::ELEM_MATCH_OBJECT; + } + + /** + * Returns true if the current expression should consider its path for translation based on + * whether it's contained within an ElemMatchObjectExpression. 
+ */ + bool shouldGeneratePath() const { + return _elemMatchStack.empty() || + _elemMatchStack.back() == MatchExpression::MatchType::ELEM_MATCH_OBJECT; } private: @@ -103,8 +124,9 @@ private: // child expressions. std::stack<ABT> _stack; - // Track whether the vistor is currently under an $elemMatch node. - int _elemMatchCount{0}; + // Used to track expressions contained under an $elemMatch. Each entry is either an + // ELEM_MATCH_OBJECT or ELEM_MATCH_VALUE. + std::vector<MatchExpression::MatchType> _elemMatchStack; }; } // namespace mongo::optimizer diff --git a/src/mongo/db/pipeline/abt/match_expression_visitor.cpp b/src/mongo/db/pipeline/abt/match_expression_visitor.cpp index e422dc24d63..70be6fd3651 100644 --- a/src/mongo/db/pipeline/abt/match_expression_visitor.cpp +++ b/src/mongo/db/pipeline/abt/match_expression_visitor.cpp @@ -73,11 +73,11 @@ public: ABTMatchExpressionPreVisitor(ExpressionAlgebrizerContext& ctx) : _ctx(ctx) {} void visit(const ElemMatchObjectMatchExpression* expr) override { - _ctx.enterElemMatch(); + _ctx.enterElemMatch(expr->matchType()); } void visit(const ElemMatchValueMatchExpression* expr) override { - _ctx.enterElemMatch(); + _ctx.enterElemMatch(expr->matchType()); } private: @@ -135,8 +135,8 @@ public: assertSupportedPathExpression(expr); ABT result = make<PathDefault>(Constant::boolean(false)); - if (!expr->path().empty()) { - result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result)); + if (shouldGeneratePath(expr)) { + result = translateFieldRef(*(expr->fieldRef()), std::move(result)); } _ctx.push(std::move(result)); } @@ -227,9 +227,9 @@ public: maybeComposePath<PathComposeA>(result, make<PathDefault>(Constant::boolean(true))); } - // The path can be empty if we are within an $elemMatch. In this case elemMatch would - // insert a traverse. - if (!expr->path().empty()) { + // Do not insert a traverse if within an $elemMatch; traversal will be handled by the + // $elemMatch expression itself. + if (shouldGeneratePath(expr)) { // When the path we are comparing is a path to an array, the comparison is // considered true if it evaluates to true for the array itself or for any of the // array’s elements. 'result' evaluates comparison on the array elements, and @@ -249,7 +249,7 @@ public: make<Constant>(tagArraysOnly, valArraysOnly))); arrOnlyGuard.reset(); } - result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result)); + result = translateFieldRef(*(expr->fieldRef()), std::move(result)); } _ctx.push(std::move(result)); } @@ -426,14 +426,8 @@ public: make<FunctionCall>("getArraySize", makeSeq(make<Variable>(lambdaProjName))), Constant::int64(expr->getData())))); - if (!expr->path().empty()) { - // No traverse. - result = translateFieldPath( - FieldPath(expr->path().toString()), - std::move(result), - [](const std::string& fieldName, const bool /*isLastElement*/, ABT input) { - return make<PathGet>(fieldName, std::move(input)); - }); + if (shouldGeneratePath(expr)) { + result = translateFieldRef(*(expr->fieldRef()), std::move(result)); } _ctx.push(std::move(result)); } @@ -460,9 +454,7 @@ public: makeSeq(make<Variable>(lambdaProjName), Constant::int32(expr->typeSet().getBSONTypeMask()))))); - // The path can be empty if we are within an $elemMatch. In this case elemMatch would insert - // a traverse. 
- if (!expr->path().empty()) { + if (shouldGeneratePath(expr)) { result = make<PathTraverse>(std::move(result), PathTraverse::kSingleLevel); if (expr->typeSet().hasType(BSONType::Array)) { // If we are testing against array type, insert a comparison against the @@ -470,7 +462,7 @@ result = make<PathComposeA>(make<PathArr>(), std::move(result)); } - result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result)); + result = translateFieldRef(*(expr->fieldRef()), std::move(result)); } _ctx.push(std::move(result)); } @@ -517,33 +509,13 @@ private: // Make sure we consider only arrays fields on the path. maybeComposePath(result, make<PathArr>()); - if (!expr->path().empty()) { - result = translateFieldPath( - FieldPath{expr->path().toString()}, - std::move(result), - [&](const std::string& fieldName, const bool isLastElement, ABT input) { - if (!isLastElement) { - input = make<PathTraverse>(std::move(input), PathTraverse::kSingleLevel); - } - return make<PathGet>(fieldName, std::move(input)); - }); + if (shouldGeneratePath(expr)) { + result = translateFieldRef(*(expr->fieldRef()), std::move(result)); } _ctx.push(std::move(result)); } - ABT generateFieldPath(const FieldPath& fieldPath, ABT initial) { - return translateFieldPath( - fieldPath, - std::move(initial), - [&](const std::string& fieldName, const bool isLastElement, ABT input) { - if (!isLastElement) { - input = make<PathTraverse>(std::move(input), PathTraverse::kSingleLevel); - } - return make<PathGet>(fieldName, std::move(input)); - }); - } - void assertSupportedPathExpression(const PathMatchExpression* expr) { uassert(ErrorCodes::InternalErrorNotSupported, "Expression contains a numeric path component", @@ -610,9 +582,7 @@ private: break; } - // The path can be empty if we are within an $elemMatch. In this case elemMatch would - // insert a traverse. - if (!expr->path().empty()) { + if (shouldGeneratePath(expr)) { if (tag == sbe::value::TypeTags::Array || tag == sbe::value::TypeTags::MinKey || tag == sbe::value::TypeTags::MaxKey) { // The behavior of PathTraverse when it encounters an array is to apply its subpath @@ -628,8 +598,9 @@ result = make<PathTraverse>(std::move(result), PathTraverse::kSingleLevel); } - result = generateFieldPath(FieldPath(expr->path().toString()), std::move(result)); + result = translateFieldRef(*(expr->fieldRef()), std::move(result)); } + _ctx.push(std::move(result)); } @@ -659,6 +630,23 @@ private: str::stream() << "Match expression is not supported: " << expr->matchType()); } + /** + * Returns whether the currently visiting expression should consider the path it's operating on + * and build the appropriate ABT. This can return false for expressions within an $elemMatch + * that operate against each value in an array (aka "elemMatch value"). + */ + bool shouldGeneratePath(const PathMatchExpression* expr) const { + // The only case where any expression, including $elemMatch, should ignore its path is if + // it's directly under a value $elemMatch. The 'elemMatchStack' includes 'expr' if it's an + // $elemMatch, so we need to look back an extra element. + if (expr->matchType() == MatchExpression::MatchType::ELEM_MATCH_OBJECT || + expr->matchType() == MatchExpression::MatchType::ELEM_MATCH_VALUE) { + return _ctx.shouldGeneratePathForElemMatch(); + } + + return _ctx.shouldGeneratePath(); + } + // If we are parsing a partial index filter, we don't allow agg expressions.
const bool _allowAggExpressions; diff --git a/src/mongo/db/pipeline/abt/utils.cpp b/src/mongo/db/pipeline/abt/utils.cpp index c4d6093cc1b..9911e090305 100644 --- a/src/mongo/db/pipeline/abt/utils.cpp +++ b/src/mongo/db/pipeline/abt/utils.cpp @@ -30,7 +30,7 @@ #include "mongo/db/pipeline/abt/utils.h" #include "mongo/db/exec/sbe/values/bson.h" - +#include "mongo/db/query/optimizer/utils/utils.h" namespace mongo::optimizer { @@ -81,6 +81,39 @@ ABT translateFieldPath(const FieldPath& fieldPath, return result; } +ABT translateFieldRef(const FieldRef& fieldRef, ABT initial) { + ABT result = std::move(initial); + + const size_t fieldPathLength = fieldRef.numParts(); + + // Handle empty field paths separately. + if (fieldPathLength == 0) { + return make<PathGet>("", std::move(result)); + } + + for (size_t i = fieldPathLength; i-- > 0;) { + // A single empty field path will parse to a FieldRef with 0 parts but should + // logically be considered a single part with an empty string. + if (i != fieldPathLength - 1) { + // For field paths with empty elements such as 'x.', we should traverse the + // array 'x' but not reach into any sub-objects. So a predicate such as {'x.': + // {$eq: 5}} should match {x: [5]} and {x: {"": 5}} but not {x: [{"": 5}]}. + const bool trailingEmptyPath = + (fieldPathLength >= 2u && i == fieldPathLength - 2u) && (fieldRef[i + 1] == ""_sd); + if (trailingEmptyPath) { + auto arrCase = make<PathArr>(); + maybeComposePath(arrCase, result.cast<PathGet>()->getPath()); + maybeComposePath<PathComposeA>(result, arrCase); + } else { + result = make<PathTraverse>(std::move(result), PathTraverse::kSingleLevel); + } + } + result = make<PathGet>(fieldRef[i].toString(), std::move(result)); + } + + return result; +} + std::pair<boost::optional<ABT>, bool> getMinMaxBoundForType(const bool isMin, const sbe::value::TypeTags& tag) { switch (tag) { diff --git a/src/mongo/db/pipeline/abt/utils.h b/src/mongo/db/pipeline/abt/utils.h index 3d7c2906979..54a255adff7 100644 --- a/src/mongo/db/pipeline/abt/utils.h +++ b/src/mongo/db/pipeline/abt/utils.h @@ -40,12 +40,22 @@ std::pair<sbe::value::TypeTags, sbe::value::Value> convertFrom(Value val); using ABTFieldNameFn = std::function<ABT(const std::string& fieldName, const bool isLastElement, ABT input)>; + +/** + * Translates an aggregation FieldPath by invoking the `fieldNameFn` for each path component. + */ ABT translateFieldPath(const FieldPath& fieldPath, ABT initial, const ABTFieldNameFn& fieldNameFn, size_t skipFromStart = 0); /** + * Translates a given FieldRef (typically used in a MatchExpression) with 'initial' as the input + * ABT. + */ +ABT translateFieldRef(const FieldRef& fieldRef, ABT initial); + +/** * Return the minimum or maximum value for the "class" of values represented by the input * constant. Used to support type bracketing. * Return format is <min/max value, bool inclusive> diff --git a/src/mongo/db/pipeline/aggregate_command.idl b/src/mongo/db/pipeline/aggregate_command.idl index 5e3cfc5ba5d..9cb77333eef 100644 --- a/src/mongo/db/pipeline/aggregate_command.idl +++ b/src/mongo/db/pipeline/aggregate_command.idl @@ -286,4 +286,9 @@ commands: type: array<ExternalDataSourceOption> cpp_name: externalDataSources optional: true - stability: unstable
\ No newline at end of file + stability: unstable + sampleId: + description: "The unique sample id for the operation if it has been chosen for sampling." + type: uuid + optional: true + stability: unstable diff --git a/src/mongo/db/pipeline/change_stream_expired_pre_image_remover.cpp b/src/mongo/db/pipeline/change_stream_expired_pre_image_remover.cpp index f9c59747448..cb2afcf06e2 100644 --- a/src/mongo/db/pipeline/change_stream_expired_pre_image_remover.cpp +++ b/src/mongo/db/pipeline/change_stream_expired_pre_image_remover.cpp @@ -86,6 +86,7 @@ public: } void run() { + LOGV2(7080100, "Starting Change Stream Expired Pre-images Remover thread"); ThreadClient tc(name(), getGlobalServiceContext()); AuthorizationSession::get(cc())->grantInternalAuthorization(&cc()); diff --git a/src/mongo/db/pipeline/document_source_check_resume_token_test.cpp b/src/mongo/db/pipeline/document_source_check_resume_token_test.cpp index e93e68aa0ca..d0fc5305029 100644 --- a/src/mongo/db/pipeline/document_source_check_resume_token_test.cpp +++ b/src/mongo/db/pipeline/document_source_check_resume_token_test.cpp @@ -206,10 +206,6 @@ protected: if (!_collScan) { _collScan = std::make_unique<CollectionScan>( pExpCtx.get(), _collectionPtr, _params, &_ws, _filter.get()); - // The first call to doWork will create the cursor and return NEED_TIME. But it won't - // actually scan any of the documents that are present in the mock cursor queue. - ASSERT_EQ(_collScan->doWork(nullptr), PlanStage::NEED_TIME); - ASSERT_EQ(_getNumDocsTested(), 0); } while (true) { // If the next result is a pause, return it and don't collscan. diff --git a/src/mongo/db/pipeline/document_source_internal_unpack_bucket.cpp b/src/mongo/db/pipeline/document_source_internal_unpack_bucket.cpp index 9ef6d892a51..89959bb0b49 100644 --- a/src/mongo/db/pipeline/document_source_internal_unpack_bucket.cpp +++ b/src/mongo/db/pipeline/document_source_internal_unpack_bucket.cpp @@ -46,6 +46,7 @@ #include "mongo/db/matcher/expression_geo.h" #include "mongo/db/matcher/expression_internal_bucket_geo_within.h" #include "mongo/db/matcher/expression_internal_expr_comparison.h" +#include "mongo/db/matcher/match_expression_dependencies.h" #include "mongo/db/pipeline/accumulator_multi.h" #include "mongo/db/pipeline/document_source_add_fields.h" #include "mongo/db/pipeline/document_source_geo_near.h" @@ -250,6 +251,35 @@ DocumentSourceInternalUnpackBucket::DocumentSourceInternalUnpackBucket( _bucketUnpacker(std::move(bucketUnpacker)), _bucketMaxSpanSeconds{bucketMaxSpanSeconds} {} +DocumentSourceInternalUnpackBucket::DocumentSourceInternalUnpackBucket( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BucketUnpacker bucketUnpacker, + int bucketMaxSpanSeconds, + const boost::optional<BSONObj>& eventFilterBson, + const boost::optional<BSONObj>& wholeBucketFilterBson, + bool assumeNoMixedSchemaData) + : DocumentSourceInternalUnpackBucket( + expCtx, std::move(bucketUnpacker), bucketMaxSpanSeconds, assumeNoMixedSchemaData) { + if (eventFilterBson) { + _eventFilterBson = eventFilterBson->getOwned(); + _eventFilter = + uassertStatusOK(MatchExpressionParser::parse(_eventFilterBson, + pExpCtx, + ExtensionsCallbackNoop(), + Pipeline::kAllowedMatcherFeatures)); + _eventFilterDeps = {}; + match_expression::addDependencies(_eventFilter.get(), &_eventFilterDeps); + } + if (wholeBucketFilterBson) { + _wholeBucketFilterBson = wholeBucketFilterBson->getOwned(); + _wholeBucketFilter = + uassertStatusOK(MatchExpressionParser::parse(_wholeBucketFilterBson, + pExpCtx, + 
ExtensionsCallbackNoop(), + Pipeline::kAllowedMatcherFeatures)); + } +} + boost::intrusive_ptr<DocumentSource> DocumentSourceInternalUnpackBucket::createFromBsonInternal( BSONElement specElem, const boost::intrusive_ptr<ExpressionContext>& expCtx) { uassert(5346500, @@ -267,6 +297,8 @@ boost::intrusive_ptr<DocumentSource> DocumentSourceInternalUnpackBucket::createF auto bucketMaxSpanSeconds = 0; auto assumeClean = false; std::vector<std::string> computedMetaProjFields; + boost::optional<BSONObj> eventFilterBson; + boost::optional<BSONObj> wholeBucketFilterBson; for (auto&& elem : specElem.embeddedObject()) { auto fieldName = elem.fieldNameStringData(); if (fieldName == kInclude || fieldName == kExclude) { @@ -360,6 +392,18 @@ boost::intrusive_ptr<DocumentSource> DocumentSourceInternalUnpackBucket::createF << " field must be a bool, got: " << elem.type(), elem.type() == BSONType::Bool); bucketSpec.setUsesExtendedRange(elem.boolean()); + } else if (fieldName == kEventFilter) { + uassert(7026902, + str::stream() << kEventFilter + << " field must be an object, got: " << elem.type(), + elem.type() == BSONType::Object); + eventFilterBson = elem.Obj(); + } else if (fieldName == kWholeBucketFilter) { + uassert(7026903, + str::stream() << kWholeBucketFilter + << " field must be an object, got: " << elem.type(), + elem.type() == BSONType::Object); + wholeBucketFilterBson = elem.Obj(); } else { uasserted(5346506, str::stream() @@ -378,6 +422,8 @@ boost::intrusive_ptr<DocumentSource> DocumentSourceInternalUnpackBucket::createF expCtx, BucketUnpacker{std::move(bucketSpec), unpackerBehavior}, bucketMaxSpanSeconds, + eventFilterBson, + wholeBucketFilterBson, assumeClean); } @@ -476,6 +522,13 @@ void DocumentSourceInternalUnpackBucket::serializeToArray( out.addField(kIncludeMaxTimeAsMetadata, Value{_bucketUnpacker.includeMaxTimeAsMetadata()}); } + if (_wholeBucketFilter) { + out.addField(kWholeBucketFilter, Value{_wholeBucketFilter->serialize()}); + } + if (_eventFilter) { + out.addField(kEventFilter, Value{_eventFilter->serialize()}); + } + if (!explain) { array.push_back(Value(DOC(getSourceName() << out.freeze()))); if (_sampleSize) { @@ -491,25 +544,49 @@ void DocumentSourceInternalUnpackBucket::serializeToArray( } } +boost::optional<Document> DocumentSourceInternalUnpackBucket::getNextMatchingMeasure() { + while (_bucketUnpacker.hasNext()) { + auto measure = _bucketUnpacker.getNext(); + if (_eventFilter) { + // MatchExpression only takes BSON documents, so we have to make one. As an + // optimization, only serialize the fields we need to do the match. + BSONObj measureBson = _eventFilterDeps.needWholeDocument + ? measure.toBson() + : document_path_support::documentToBsonWithPaths(measure, _eventFilterDeps.fields); + if (_bucketUnpacker.bucketMatchedQuery() || _eventFilter->matchesBSON(measureBson)) { + return measure; + } + } else { + return measure; + } + } + return {}; +} + DocumentSource::GetNextResult DocumentSourceInternalUnpackBucket::doGetNext() { tassert(5521502, "calling doGetNext() when '_sampleSize' is set is disallowed", !_sampleSize); // Otherwise, fallback to unpacking every measurement in all buckets until the child stage is // exhausted. 
- if (_bucketUnpacker.hasNext()) { - return _bucketUnpacker.getNext(); + if (auto measure = getNextMatchingMeasure()) { + return GetNextResult(std::move(*measure)); } auto nextResult = pSource->getNext(); - if (nextResult.isAdvanced()) { + while (nextResult.isAdvanced()) { auto bucket = nextResult.getDocument().toBson(); - _bucketUnpacker.reset(std::move(bucket)); + auto bucketMatchedQuery = _wholeBucketFilter && _wholeBucketFilter->matchesBSON(bucket); + _bucketUnpacker.reset(std::move(bucket), bucketMatchedQuery); + uassert(5346509, str::stream() << "A bucket with _id " << _bucketUnpacker.bucket()[timeseries::kBucketIdFieldName].toString() << " contains an empty data region", _bucketUnpacker.hasNext()); - return _bucketUnpacker.getNext(); + if (auto measure = getNextMatchingMeasure()) { + return GetNextResult(std::move(*measure)); + } + nextResult = pSource->getNext(); } return nextResult; @@ -587,7 +664,8 @@ std::pair<BSONObj, bool> DocumentSourceInternalUnpackBucket::extractOrBuildProje // Check for a viable inclusion $project after the $_internalUnpackBucket. auto [existingProj, isInclusion] = getIncludeExcludeProjectAndType(std::next(itr)->get()); - if (isInclusion && !existingProj.isEmpty() && canInternalizeProjectObj(existingProj)) { + if (!_eventFilter && isInclusion && !existingProj.isEmpty() && + canInternalizeProjectObj(existingProj)) { container->erase(std::next(itr)); return {existingProj, isInclusion}; } @@ -595,8 +673,7 @@ std::pair<BSONObj, bool> DocumentSourceInternalUnpackBucket::extractOrBuildProje // Attempt to get an inclusion $project representing the root-level dependencies of the pipeline // after the $_internalUnpackBucket. If this $project is not empty, then the dependency set was // finite. - Pipeline::SourceContainer restOfPipeline(std::next(itr), container->end()); - auto deps = Pipeline::getDependenciesForContainer(pExpCtx, restOfPipeline, boost::none); + auto deps = getRestPipelineDependencies(itr, container); if (auto dependencyProj = deps.toProjectionWithoutMetadata(DepsTracker::TruncateToRootLevel::yes); !dependencyProj.isEmpty()) { @@ -604,7 +681,7 @@ std::pair<BSONObj, bool> DocumentSourceInternalUnpackBucket::extractOrBuildProje } // Check for a viable exclusion $project after the $_internalUnpackBucket. 
- if (!existingProj.isEmpty() && canInternalizeProjectObj(existingProj)) { + if (!_eventFilter && !existingProj.isEmpty() && canInternalizeProjectObj(existingProj)) { container->erase(std::next(itr)); return {existingProj, isInclusion}; } @@ -612,8 +689,7 @@ std::pair<BSONObj, bool> DocumentSourceInternalUnpackBucket::extractOrBuildProje return {BSONObj{}, false}; } -std::unique_ptr<MatchExpression> -DocumentSourceInternalUnpackBucket::createPredicatesOnBucketLevelField( +BucketSpec::BucketPredicate DocumentSourceInternalUnpackBucket::createPredicatesOnBucketLevelField( const MatchExpression* matchExpr) const { return BucketSpec::createPredicatesOnBucketLevelField( matchExpr, @@ -1005,6 +1081,16 @@ bool findSequentialDocumentCache(Pipeline::SourceContainer::iterator start, return start != end; } +DepsTracker DocumentSourceInternalUnpackBucket::getRestPipelineDependencies( + Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) const { + auto deps = Pipeline::getDependenciesForContainer( + pExpCtx, Pipeline::SourceContainer{std::next(itr), container->end()}, boost::none); + if (_eventFilter) { + match_expression::addDependencies(_eventFilter.get(), &deps); + } + return deps; +} + Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimizeAt( Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) { invariant(*itr == this); @@ -1018,7 +1104,8 @@ Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimi bool haveComputedMetaField = this->haveComputedMetaField(); // Before any other rewrites for the current stage, consider reordering with $sort. - if (auto sortPtr = dynamic_cast<DocumentSourceSort*>(std::next(itr)->get())) { + if (auto sortPtr = dynamic_cast<DocumentSourceSort*>(std::next(itr)->get()); + sortPtr && !_eventFilter) { if (auto metaField = _bucketUnpacker.bucketSpec().metaField(); metaField && !haveComputedMetaField) { if (checkMetadataSortReorder(sortPtr->getSortKeyPattern(), metaField.value())) { @@ -1049,7 +1136,8 @@ Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimi } // Attempt to push geoNear on the metaField past $_internalUnpackBucket. - if (auto nextNear = dynamic_cast<DocumentSourceGeoNear*>(std::next(itr)->get())) { + if (auto nextNear = dynamic_cast<DocumentSourceGeoNear*>(std::next(itr)->get()); + nextNear && !_eventFilter) { // Currently we only support geo indexes on the meta field, and we enforce this by // requiring the key field to be set so we can check before we try to look up indexes. auto keyField = nextNear->getKeyField(); @@ -1130,8 +1218,7 @@ Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimi { // Check if the rest of the pipeline needs any fields. For example we might only be // interested in $count. - auto deps = Pipeline::getDependenciesForContainer( - pExpCtx, Pipeline::SourceContainer{std::next(itr), container->end()}, boost::none); + auto deps = getRestPipelineDependencies(itr, container); if (deps.hasNoRequirements()) { _bucketUnpacker.setBucketSpecAndBehavior({_bucketUnpacker.bucketSpec().timeField(), _bucketUnpacker.bucketSpec().metaField(), @@ -1151,31 +1238,65 @@ Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimi } // Attempt to optimize last-point type queries. 
- if (!_triedLastpointRewrite && optimizeLastpoint(itr, container)) { + if (!_triedLastpointRewrite && !_eventFilter && optimizeLastpoint(itr, container)) { _triedLastpointRewrite = true; // If we are able to rewrite the aggregation, give the resulting pipeline a chance to // perform further optimizations. return container->begin(); }; - // Attempt to map predicates on bucketed fields to predicates on the control field. - if (auto nextMatch = dynamic_cast<DocumentSourceMatch*>(std::next(itr)->get()); - nextMatch && !_triedBucketLevelFieldsPredicatesPushdown) { - _triedBucketLevelFieldsPredicatesPushdown = true; + // Attempt to map predicates on bucketed fields to the predicates on the control field. + if (auto nextMatch = dynamic_cast<DocumentSourceMatch*>(std::next(itr)->get())) { - if (auto match = createPredicatesOnBucketLevelField(nextMatch->getMatchExpression())) { + // Merge multiple following $match stages. + auto itrToMatch = std::next(itr); + while (std::next(itrToMatch) != container->end() && + dynamic_cast<DocumentSourceMatch*>(std::next(itrToMatch)->get())) { + nextMatch->doOptimizeAt(itrToMatch, container); + } + + auto predicates = createPredicatesOnBucketLevelField(nextMatch->getMatchExpression()); + + // Try to create a tight bucket predicate to perform bucket level matching. + if (predicates.tightPredicate) { + _wholeBucketFilterBson = predicates.tightPredicate->serialize(); + _wholeBucketFilter = + uassertStatusOK(MatchExpressionParser::parse(_wholeBucketFilterBson, + pExpCtx, + ExtensionsCallbackNoop(), + Pipeline::kAllowedMatcherFeatures)); + _wholeBucketFilter = MatchExpression::optimize(std::move(_wholeBucketFilter)); + } + + // Push the original event predicate into the unpacking stage. + _eventFilterBson = nextMatch->getQuery().getOwned(); + _eventFilter = + uassertStatusOK(MatchExpressionParser::parse(_eventFilterBson, + pExpCtx, + ExtensionsCallbackNoop(), + Pipeline::kAllowedMatcherFeatures)); + _eventFilter = MatchExpression::optimize(std::move(_eventFilter)); + _eventFilterDeps = {}; + match_expression::addDependencies(_eventFilter.get(), &_eventFilterDeps); + container->erase(std::next(itr)); + + // Create a loose bucket predicate and push it before the unpacking stage. + if (predicates.loosePredicate) { BSONObjBuilder bob; - match->serialize(&bob); + predicates.loosePredicate->serialize(&bob); container->insert(itr, DocumentSourceMatch::create(bob.obj(), pExpCtx)); // Give other stages a chance to optimize with the new $match. return std::prev(itr) == container->begin() ? std::prev(itr) : std::prev(std::prev(itr)); } + + // We have removed a $match after this stage, so we try to optimize this stage again. + return itr; } // Attempt to push down a $project on the metaField past $_internalUnpackBucket. - if (!haveComputedMetaField) { + if (!_eventFilter && !haveComputedMetaField) { if (auto [metaProject, deleteRemainder] = extractProjectForPushDown(std::next(itr)->get()); !metaProject.isEmpty()) { container->insert(itr, @@ -1194,7 +1315,7 @@ Pipeline::SourceContainer::iterator DocumentSourceInternalUnpackBucket::doOptimi // Attempt to extract computed meta projections from subsequent $project, $addFields, or $set // and push them before the $_internalunpackBucket. - if (pushDownComputedMetaProjection(itr, container)) { + if (!_eventFilter && pushDownComputedMetaProjection(itr, container)) { // We've pushed down and removed a stage after this one. Try to optimize the new stage. return std::prev(itr) == container->begin() ? 
std::prev(itr) : std::prev(std::prev(itr)); } diff --git a/src/mongo/db/pipeline/document_source_internal_unpack_bucket.h b/src/mongo/db/pipeline/document_source_internal_unpack_bucket.h index 4dd93046533..a5dab5462ad 100644 --- a/src/mongo/db/pipeline/document_source_internal_unpack_bucket.h +++ b/src/mongo/db/pipeline/document_source_internal_unpack_bucket.h @@ -51,6 +51,8 @@ public: static constexpr StringData kBucketMaxSpanSeconds = "bucketMaxSpanSeconds"_sd; static constexpr StringData kIncludeMinTimeAsMetadata = "includeMinTimeAsMetadata"_sd; static constexpr StringData kIncludeMaxTimeAsMetadata = "includeMaxTimeAsMetadata"_sd; + static constexpr StringData kWholeBucketFilter = "wholeBucketFilter"_sd; + static constexpr StringData kEventFilter = "eventFilter"_sd; static boost::intrusive_ptr<DocumentSource> createFromBsonInternal( BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx); @@ -62,6 +64,13 @@ public: int bucketMaxSpanSeconds, bool assumeNoMixedSchemaData = false); + DocumentSourceInternalUnpackBucket(const boost::intrusive_ptr<ExpressionContext>& expCtx, + BucketUnpacker bucketUnpacker, + int bucketMaxSpanSeconds, + const boost::optional<BSONObj>& eventFilterBson, + const boost::optional<BSONObj>& wholeBucketFilterBson, + bool assumeNoMixedSchemaData = false); + const char* getSourceName() const override { return kStageNameInternal.rawData(); } @@ -158,7 +167,7 @@ public: /** * Convenience wrapper around BucketSpec::createPredicatesOnBucketLevelField(). */ - std::unique_ptr<MatchExpression> createPredicatesOnBucketLevelField( + BucketSpec::BucketPredicate createPredicatesOnBucketLevelField( const MatchExpression* matchExpr) const; /** @@ -243,8 +252,14 @@ public: GetModPathsReturn getModifiedPaths() const final override; + DepsTracker getRestPipelineDependencies(Pipeline::SourceContainer::iterator itr, + Pipeline::SourceContainer* container) const; + private: GetNextResult doGetNext() final; + + boost::optional<Document> getNextMatchingMeasure(); + bool haveComputedMetaField() const; // If buckets contained a mixed type schema along some path, we have to push down special @@ -261,9 +276,13 @@ private: int _bucketMaxCount = 0; boost::optional<long long> _sampleSize; - // Used to avoid infinite loops after we step backwards to optimize a $match on bucket level - // fields, otherwise we may do an infinite number of $match pushdowns. 
- bool _triedBucketLevelFieldsPredicatesPushdown = false; + // Filters pushed from the later $match stages + std::unique_ptr<MatchExpression> _eventFilter; + BSONObj _eventFilterBson; + DepsTracker _eventFilterDeps; + std::unique_ptr<MatchExpression> _wholeBucketFilter; + BSONObj _wholeBucketFilterBson; + bool _optimizedEndOfPipeline = false; bool _triedInternalizeProject = false; bool _triedLastpointRewrite = false; diff --git a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/create_predicates_on_bucket_level_field_test.cpp b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/create_predicates_on_bucket_level_field_test.cpp index ffe291dcbd0..63d3b4a0b23 100644 --- a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/create_predicates_on_bucket_level_field_test.cpp +++ b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/create_predicates_on_bucket_level_field_test.cpp @@ -55,10 +55,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [ {'control.max.a': {$_internalExprGt: 1}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -76,10 +77,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [ {'control.max.a': {$_internalExprGte: 1}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -97,10 +99,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [ {'control.min.a': {$_internalExprLt: 1}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -118,10 +121,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [ {'control.min.a': {$_internalExprLte: 1}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -139,11 +143,12 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = 
dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [ {$and:[{'control.min.a': {$_internalExprLte: 1}}," "{'control.max.a': {$_internalExprGte: 1}}]}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -161,7 +166,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT(predicate); + ASSERT(predicate.loosePredicate); auto expected = fromjson( "{$or: [" " {$or: [" @@ -185,7 +190,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, " ]}}" " ]}" "]}"); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), expected); + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), expected); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -208,10 +214,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [ {'control.max.a': {$_internalExprGt: 1}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -234,10 +241,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [ {'control.max.a': {$_internalExprGte: 1}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -260,10 +268,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [ {'control.min.a': {$_internalExprLt: 1}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -286,10 +295,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [ {'control.min.a': {$_internalExprLte: 1}}," "{$expr: {$ne: [ {$type: [ 
\"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -312,11 +322,12 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [ {$and:[{'control.min.a': {$_internalExprLte: 1}}," "{'control.max.a': {$_internalExprGte: 1}}]}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -334,13 +345,14 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$and: [ {$or: [ {'control.max.b': {$_internalExprGt: 1}}," "{$expr: {$ne: [ {$type: [ \"$control.min.b\" ]}," "{$type: [ \"$control.max.b\" ]} ]}} ]}," "{$or: [ {'control.min.a': {$_internalExprLt: 5}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -358,7 +370,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT(predicate == nullptr); + ASSERT(predicate.loosePredicate == nullptr); + ASSERT(predicate.tightPredicate == nullptr); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -376,7 +389,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [" " {'control.max.b': {$_internalExprGt: 1}}," " {$expr: {$ne: [" @@ -384,6 +397,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, " {$type: [ \"$control.max.b\" ]}" " ]}}" "]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -402,7 +416,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$and: [ {$or: [ {'control.max.b': {$_internalExprGte: 2}}," "{$expr: {$ne: [ {$type: [ \"$control.min.b\" ]}," "{$type: [ \"$control.max.b\" ]} ]}} ]}," @@ -412,6 +426,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, "{$or: [ {'control.min.a': {$_internalExprLt: 5}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]} ]} ]}")); + ASSERT_FALSE(predicate.tightPredicate); } 
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -429,8 +444,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT(predicate); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT(predicate.loosePredicate); + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [" " {$or: [" " {'control.max.b': {$_internalExprGt: 1}}," @@ -447,6 +462,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, " ]}}" " ]}" "]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -464,7 +480,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT(predicate == nullptr); + ASSERT(predicate.loosePredicate == nullptr); + ASSERT(predicate.tightPredicate == nullptr); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -485,7 +502,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, // When a predicate can't be pushed down, it's the same as pushing down a trivially-true // predicate. So when any child of an $or can't be pushed down, we could generate something like // {$or: [ ... {$alwaysTrue: {}}, ... ]}, but then we might as well not push down the whole $or. - ASSERT(predicate == nullptr); + ASSERT(predicate.loosePredicate == nullptr); + ASSERT(predicate.tightPredicate == nullptr); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -504,7 +522,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$or: [" " {$or: [" " {'control.max.b': {$_internalExprGte: 2}}," @@ -530,6 +548,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, " ]}" " ]}" "]}")); + ASSERT_FALSE(predicate.tightPredicate); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -541,19 +560,21 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, ASSERT_EQ(pipeline->getSources().size(), 2U); pipeline->optimizePipeline(); - ASSERT_EQ(pipeline->getSources().size(), 3U); + ASSERT_EQ(pipeline->getSources().size(), 2U); // To get the optimized $match from the pipeline, we have to serialize with explain. 
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner); - ASSERT_EQ(stages.size(), 3U); + ASSERT_EQ(stages.size(), 2U); ASSERT_BSONOBJ_EQ(stages[0].getDocument().toBson(), fromjson("{$match: {$or: [ {'control.max.b': {$_internalExprGt: 1}}," "{$expr: {$ne: [ {$type: [ \"$control.min.b\" ]}," "{$type: [ \"$control.max.b\" ]} ]}} ]}}")); - ASSERT_BSONOBJ_EQ(stages[1].getDocument().toBson(), unpackBucketObj); - ASSERT_BSONOBJ_EQ(stages[2].getDocument().toBson(), - fromjson("{$match: {$and: [{b: {$gt: 1}}, {a: {$not: {$eq: 5}}}]}}")); + ASSERT_BSONOBJ_EQ( + stages[1].getDocument().toBson(), + fromjson( + "{$_internalUnpackBucket: {exclude: [], timeField: 'time', bucketMaxSpanSeconds: 3600, " + "eventFilter: { $and: [ { b: { $gt: 1 } }, { a: { $not: { $eq: 5 } } } ] }}}")); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -566,10 +587,10 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, ASSERT_EQ(pipeline->getSources().size(), 2U); pipeline->optimizePipeline(); - ASSERT_EQ(pipeline->getSources().size(), 3U); + ASSERT_EQ(pipeline->getSources().size(), 2U); auto stages = pipeline->serializeToBson(); - ASSERT_EQ(stages.size(), 3U); + ASSERT_EQ(stages.size(), 2U); ASSERT_BSONOBJ_EQ( stages[0], @@ -582,8 +603,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, "{$or: [ {'control.min.a': {$_internalExprLt: 5}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]} ]}}")); - ASSERT_BSONOBJ_EQ(stages[1], unpackBucketObj); - ASSERT_BSONOBJ_EQ(stages[2], matchObj); + ASSERT_BSONOBJ_EQ(stages[1], + fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", " + "bucketMaxSpanSeconds: 3600," + "eventFilter: { $and: [ { b: { $gte: 2 } }, { c: { $gt: 1 } }, { a: " + "{ $lt: 5 } } ] } } }")); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -601,7 +625,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT(predicate == nullptr); + ASSERT(predicate.loosePredicate == nullptr); + ASSERT(predicate.tightPredicate == nullptr); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -619,7 +644,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT(predicate == nullptr); + ASSERT(predicate.loosePredicate == nullptr); + ASSERT(predicate.tightPredicate == nullptr); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -637,7 +663,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT(predicate == nullptr); + ASSERT(predicate.loosePredicate == nullptr); + ASSERT(predicate.tightPredicate == nullptr); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -656,7 +683,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, ->createPredicatesOnBucketLevelField(original->getMatchExpression()); // Meta predicates are mapped to the meta field, not the control min/max fields. 
- ASSERT_BSONOBJ_EQ(predicate->serialize(true), fromjson("{meta: {$gt: 5}}")); + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{meta: {$gt: 5}}")); + ASSERT_BSONOBJ_EQ(predicate.tightPredicate->serialize(true), fromjson("{meta: {$gt: 5}}")); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -675,7 +703,10 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, ->createPredicatesOnBucketLevelField(original->getMatchExpression()); // Meta predicates are mapped to the meta field, not the control min/max fields. - ASSERT_BSONOBJ_EQ(predicate->serialize(true), fromjson("{'meta.foo': {$gt: 5}}")); + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), + fromjson("{'meta.foo': {$gt: 5}}")); + ASSERT_BSONOBJ_EQ(predicate.tightPredicate->serialize(true), + fromjson("{'meta.foo': {$gt: 5}}")); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, @@ -693,7 +724,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$and: [" " {$or: [" " {'control.max.a': {$_internalExprGt: 1}}," @@ -704,6 +735,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, " ]}," " {meta: {$eq: 5}}" "]}")); + ASSERT(predicate.tightPredicate == nullptr); } TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePredicatesOnId) { @@ -740,7 +772,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePre dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.get()); + auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.loosePredicate.get()); auto children = andExpr->getChildVector(); ASSERT_EQ(children->size(), 3); @@ -797,7 +829,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePre auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.get()); + auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.loosePredicate.get()); auto children = andExpr->getChildVector(); ASSERT_EQ(children->size(), 3); @@ -846,7 +878,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePre auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.get()); + auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.loosePredicate.get()); auto children = andExpr->getChildVector(); ASSERT_EQ(children->size(), 6); @@ -908,7 +940,7 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePre dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.get()); + auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.loosePredicate.get()); auto children = andExpr->getChildVector(); ASSERT_EQ(children->size(), 3); @@ -957,7 +989,7 @@ 
TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, OptimizeMapsTimePre auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.get()); + auto andExpr = dynamic_cast<AndMatchExpression*>(predicate.loosePredicate.get()); auto children = andExpr->getChildVector(); ASSERT_EQ(children->size(), 3); @@ -1000,7 +1032,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_FALSE(predicate); + ASSERT_FALSE(predicate.loosePredicate); + ASSERT_FALSE(predicate.tightPredicate); } } { @@ -1021,7 +1054,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_FALSE(predicate); + ASSERT_FALSE(predicate.loosePredicate); + ASSERT_FALSE(predicate.tightPredicate); } } { @@ -1042,7 +1076,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_FALSE(predicate); + ASSERT_FALSE(predicate.loosePredicate); + ASSERT_FALSE(predicate.tightPredicate); } } { @@ -1065,7 +1100,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_FALSE(predicate); + ASSERT_FALSE(predicate.loosePredicate); + ASSERT_FALSE(predicate.tightPredicate); } } { @@ -1086,7 +1122,8 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_FALSE(predicate); + ASSERT_FALSE(predicate.loosePredicate); + ASSERT_FALSE(predicate.tightPredicate); } } } @@ -1107,10 +1144,11 @@ TEST_F(InternalUnpackBucketPredicateMappingOptimizationTest, auto predicate = dynamic_cast<DocumentSourceInternalUnpackBucket*>(container.front().get()) ->createPredicatesOnBucketLevelField(original->getMatchExpression()); - ASSERT_BSONOBJ_EQ(predicate->serialize(true), + ASSERT_BSONOBJ_EQ(predicate.loosePredicate->serialize(true), fromjson("{$_internalBucketGeoWithin: { withinRegion: { $geometry: { type : " "\"Polygon\" ,coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 " "] ] ]}},field: \"loc\"}}")); + ASSERT_FALSE(predicate.tightPredicate); } } // namespace } // namespace mongo diff --git a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/optimize_pipeline_test.cpp b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/optimize_pipeline_test.cpp index 2001cb2cab3..aeb7ac8ec8c 100644 --- a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/optimize_pipeline_test.cpp +++ b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/optimize_pipeline_test.cpp @@ -54,7 +54,7 @@ TEST_F(OptimizePipeline, MixedMatchPushedDown) { // To get the optimized $match from the pipeline, we have to serialize with explain. 
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner); - ASSERT_EQ(3u, stages.size()); + ASSERT_EQ(2u, stages.size()); // We should push down the $match on the metaField and the predicates on the control field. // The created $match stages should be added before $_internalUnpackBucket and merged. @@ -63,8 +63,10 @@ TEST_F(OptimizePipeline, MixedMatchPushedDown) { "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ] } ] } } ] }]}}"), stages[0].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(unpack, stages[1].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), stages[2].getDocument().toBson()); + ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", " + "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, " + "eventFilter: { a: { $lte: 4 } } } }"), + stages[1].getDocument().toBson()); } TEST_F(OptimizePipeline, MetaMatchPushedDown) { @@ -103,7 +105,7 @@ TEST_F(OptimizePipeline, MixedMatchOr) { pipeline->optimizePipeline(); auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner); - ASSERT_EQ(3u, stages.size()); + ASSERT_EQ(2u, stages.size()); auto expected = fromjson( "{$match: {$and: [" // Result of pushing down {x: {$lte: 1}}. @@ -123,8 +125,11 @@ TEST_F(OptimizePipeline, MixedMatchOr) { " ]}" "]}}"); ASSERT_BSONOBJ_EQ(expected, stages[0].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(unpack, stages[1].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(match, stages[2].getDocument().toBson()); + ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"foo\", " + "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, " + "eventFilter: { $and: [ { x: { $lte: 1 } }, { $or: [ { " + "\"myMeta.a\": { $gt: 1 } }, { y: { $lt: 1 } } ] } ] } } }"), + stages[1].getDocument().toBson()); } TEST_F(OptimizePipeline, MixedMatchOnlyMetaMatchPushedDown) { @@ -142,11 +147,13 @@ TEST_F(OptimizePipeline, MixedMatchOnlyMetaMatchPushedDown) { // We should push down the $match on the metaField but not the predicate on '$a', which is // ineligible because of the $type. auto serialized = pipeline->serializeToBson(); - ASSERT_EQ(3u, serialized.size()); + ASSERT_EQ(2u, serialized.size()); ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [{meta: {$gte: 0}}, {meta: {$lte: 5}}]}}"), serialized[0]); - ASSERT_BSONOBJ_EQ(unpack, serialized[1]); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$type: [ 2 ]}}}"), serialized[2]); + ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", " + "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, " + "eventFilter: { a: { $type: [ 2 ] } } } }"), + serialized[1]); } TEST_F(OptimizePipeline, MultipleMatchesPushedDown) { @@ -164,15 +171,17 @@ TEST_F(OptimizePipeline, MultipleMatchesPushedDown) { // We should push down both the $match on the metaField and the predicates on the control field. // The created $match stages should be added before $_internalUnpackBucket and merged. 
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner); - ASSERT_EQ(3u, stages.size()); + ASSERT_EQ(2u, stages.size()); ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [ {meta: {$gte: 0}}," "{meta: {$lte: 5}}," "{$or: [ {'control.min.a': {$_internalExprLte: 4}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ]}," "{$type: [ \"$control.max.a\" ]} ]}} ]} ]}}"), stages[0].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(unpack, stages[1].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), stages[2].getDocument().toBson()); + ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", " + "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, " + "eventFilter: { a: { $lte: 4 } } } }"), + stages[1].getDocument().toBson()); } TEST_F(OptimizePipeline, MultipleMatchesPushedDownWithSort) { @@ -191,16 +200,18 @@ TEST_F(OptimizePipeline, MultipleMatchesPushedDownWithSort) { // We should push down both the $match on the metaField and the predicates on the control field. // The created $match stages should be added before $_internalUnpackBucket and merged. auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner); - ASSERT_EQ(4u, stages.size()); + ASSERT_EQ(3u, stages.size()); ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [ { meta: { $gte: 0 } }," "{meta: { $lte: 5 } }," "{$or: [ { 'control.min.a': { $_internalExprLte: 4 } }," "{$expr: { $ne: [ {$type: [ \"$control.min.a\" ] }," "{$type: [ \"$control.max.a\" ] } ] } } ] }]}}"), stages[0].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(unpack, stages[1].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), stages[2].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(fromjson("{$sort: {sortKey: {a: 1}}}"), stages[3].getDocument().toBson()); + ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", " + "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, " + "eventFilter: { a: { $lte: 4 } } } }"), + stages[1].getDocument().toBson()); + ASSERT_BSONOBJ_EQ(fromjson("{$sort: {sortKey: {a: 1}}}"), stages[2].getDocument().toBson()); } TEST_F(OptimizePipeline, MetaMatchThenCountPushedDown) { @@ -261,7 +272,7 @@ TEST_F(OptimizePipeline, SortThenMixedMatchPushedDown) { // We should push down both the $sort and parts of the $match. auto serialized = pipeline->serializeToBson(); - ASSERT_EQ(4u, serialized.size()); + ASSERT_EQ(3u, serialized.size()); auto expected = fromjson( "{$match: {$and: [" " {meta: {$eq: 'abc'}}," @@ -274,8 +285,10 @@ TEST_F(OptimizePipeline, SortThenMixedMatchPushedDown) { "]}}"); ASSERT_BSONOBJ_EQ(expected, serialized[0]); ASSERT_BSONOBJ_EQ(fromjson("{$sort: {meta: -1}}"), serialized[1]); - ASSERT_BSONOBJ_EQ(unpack, serialized[2]); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$gte: 5}}}"), serialized[3]); + ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", " + "metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, " + "eventFilter: { a: { $gte: 5 } } } }"), + serialized[2]); } TEST_F(OptimizePipeline, MetaMatchThenSortPushedDown) { @@ -331,18 +344,19 @@ TEST_F(OptimizePipeline, MixedMatchThenProjectPushedDown) { // We can push down part of the $match and use dependency analysis on the end of the pipeline. 
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner); - ASSERT_EQ(4u, stages.size()); + ASSERT_EQ(3u, stages.size()); ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [{meta: {$eq: 'abc'}}," "{$or: [ {'control.min.a': { $_internalExprLte: 4 } }," "{$expr: { $ne: [ {$type: [ \"$control.min.a\" ] }," "{$type: [ \"$control.max.a\" ] } ] } } ] } ]}}"), stages[0].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(fromjson("{$_internalUnpackBucket: { include: ['_id', 'a', 'x'], timeField: " - "'time', metaField: 'myMeta', bucketMaxSpanSeconds: 3600}}"), - stages[1].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), stages[2].getDocument().toBson()); + ASSERT_BSONOBJ_EQ( + fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"a\", \"x\" ], timeField: " + "\"time\", metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, " + "eventFilter: { a: { $lte: 4 } } } }"), + stages[1].getDocument().toBson()); ASSERT_BSONOBJ_EQ(fromjson("{$project: {_id: true, x: true}}"), - stages[3].getDocument().toBson()); + stages[2].getDocument().toBson()); } @@ -379,21 +393,21 @@ TEST_F(OptimizePipeline, ProjectThenMixedMatchPushedDown) { // We should push down part of the $match and do dependency analysis on the rest. auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner); - ASSERT_EQ(4u, stages.size()); + ASSERT_EQ(3u, stages.size()); ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [{meta: {$eq: \"abc\"}}," "{$or: [ {'control.min.a': {$_internalExprLte: 4}}," "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ] }," "{$type: [ \"$control.max.a\" ]} ]}} ]} ]}}"), stages[0].getDocument().toBson()); ASSERT_BSONOBJ_EQ( - fromjson("{$_internalUnpackBucket: { include: ['_id', 'a', 'x', 'myMeta'], timeField: " - "'time', metaField: 'myMeta', bucketMaxSpanSeconds: 3600}}"), + fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"a\", \"x\", \"myMeta\" ], " + "timeField: \"time\", metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, " + "eventFilter: { a: { $lte: 4 } } } }"), stages[1].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), stages[2].getDocument().toBson()); const UnorderedFieldsBSONObjComparator kComparator; ASSERT_EQ( kComparator.compare(fromjson("{$project: {_id: true, a: true, myMeta: true, x: true}}"), - stages[3].getDocument().toBson()), + stages[2].getDocument().toBson()), 0); } @@ -410,7 +424,7 @@ TEST_F(OptimizePipeline, ProjectWithRenameThenMixedMatchPushedDown) { // We should push down part of the $match and do dependency analysis on the end of the pipeline. 
auto stages = pipeline->writeExplainOps(ExplainOptions::Verbosity::kQueryPlanner); - ASSERT_EQ(4u, stages.size()); + ASSERT_EQ(3u, stages.size()); ASSERT_BSONOBJ_EQ( fromjson("{$match: {$and: [{$or: [ {'control.max.y': {$_internalExprGte: \"abc\"}}," "{$expr: {$ne: [ {$type: [ \"$control.min.y\" ]}," @@ -419,13 +433,13 @@ TEST_F(OptimizePipeline, ProjectWithRenameThenMixedMatchPushedDown) { "{$expr: {$ne: [ {$type: [ \"$control.min.a\" ] }," "{$type: [ \"$control.max.a\" ]} ]}} ]} ]}}"), stages[0].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(fromjson("{$_internalUnpackBucket: { include: ['_id', 'a', 'y'], timeField: " - "'time', metaField: 'myMeta', bucketMaxSpanSeconds: 3600}}"), - stages[1].getDocument().toBson()); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [{y: {$gte: 'abc'}}, {a: {$lte: 4}}]}}"), - stages[2].getDocument().toBson()); + ASSERT_BSONOBJ_EQ( + fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"a\", \"y\" ], timeField: " + "\"time\", metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, " + "eventFilter: { $and: [ { y: { $gte: \"abc\" } }, { a: { $lte: 4 } } ] } } }"), + stages[1].getDocument().toBson()); ASSERT_BSONOBJ_EQ(fromjson("{$project: {_id: true, a: true, myMeta: '$y'}}"), - stages[3].getDocument().toBson()); + stages[2].getDocument().toBson()); } TEST_F(OptimizePipeline, ComputedProjectThenMetaMatchPushedDown) { @@ -466,15 +480,15 @@ TEST_F(OptimizePipeline, ComputedProjectThenMetaMatchNotPushedDown) { // We should both push down the project and internalize the remaining project, but we can't // push down the meta match due to the (now invalid) renaming. auto serialized = pipeline->serializeToBson(); - ASSERT_EQ(3u, serialized.size()); + ASSERT_EQ(2u, serialized.size()); ASSERT_BSONOBJ_EQ(fromjson("{$addFields: {myMeta: {$sum: ['$meta.a', '$meta.b']}}}"), serialized[0]); ASSERT_BSONOBJ_EQ( - fromjson( - "{$_internalUnpackBucket: { include: ['_id', 'myMeta'], timeField: 'time', metaField: " - "'myMeta', bucketMaxSpanSeconds: 3600, computedMetaProjFields: ['myMeta']}}"), + fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"myMeta\" ], timeField: " + "\"time\", metaField: \"myMeta\", " + "bucketMaxSpanSeconds: 3600, computedMetaProjFields: [ \"myMeta\" ], " + "eventFilter: { myMeta: { $gte: \"abc\" } } } }"), serialized[1]); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {myMeta: {$gte: 'abc'}}}"), serialized[2]); } // namespace TEST_F(OptimizePipeline, ComputedProjectThenMatchNotPushedDown) { @@ -491,13 +505,13 @@ TEST_F(OptimizePipeline, ComputedProjectThenMatchNotPushedDown) { // We should push down the computed project but not the match, because it depends on the newly // computed values. 
auto serialized = pipeline->serializeToBson(); - ASSERT_EQ(3u, serialized.size()); + ASSERT_EQ(2u, serialized.size()); ASSERT_BSONOBJ_EQ(fromjson("{$addFields: {y: {$sum: ['$meta.a', '$meta.b']}}}"), serialized[0]); - ASSERT_BSONOBJ_EQ( - fromjson("{$_internalUnpackBucket: { include: ['_id', 'y'], timeField: 'time', metaField: " - "'myMeta', bucketMaxSpanSeconds: 3600, computedMetaProjFields: ['y']}}"), - serialized[1]); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {y: {$gt: 'abc'}}}"), serialized[2]); + ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"y\" ], " + "timeField: \"time\", metaField: \"myMeta\", " + "bucketMaxSpanSeconds: 3600, computedMetaProjFields: [ \"y\" ], " + "eventFilter: { y: { $gt: \"abc\" } } } }"), + serialized[1]); } TEST_F(OptimizePipeline, MetaSortThenProjectPushedDown) { @@ -857,7 +871,7 @@ TEST_F(OptimizePipeline, MatchWithGeoWithinOnMeasurementsPushedDownUsingInternal pipeline->optimizePipeline(); auto serialized = pipeline->serializeToBson(); - ASSERT_EQ(serialized.size(), 3U); + ASSERT_EQ(serialized.size(), 2U); // $match with $geoWithin on a non-metadata field is pushed down and $_internalBucketGeoWithin // is used. @@ -866,12 +880,12 @@ TEST_F(OptimizePipeline, MatchWithGeoWithinOnMeasurementsPushedDownUsingInternal "\"Polygon\" ,coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 " "] ] ]}},field: \"loc\"}}}"), serialized[0]); - ASSERT_BSONOBJ_EQ(fromjson("{$_internalUnpackBucket: {exclude: [], timeField: " - "'time', bucketMaxSpanSeconds: 3600}}"), - serialized[1]); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {loc: {$geoWithin: {$geometry: {type: \"Polygon\", " - "coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 ] ] ]}}}}}"), - serialized[2]); + ASSERT_BSONOBJ_EQ( + fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", " + "bucketMaxSpanSeconds: 3600, " + "eventFilter: { loc: { $geoWithin: { $geometry: { type: \"Polygon\", coordinates: " + "[ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 ] ] ] } } } } } }"), + serialized[1]); } TEST_F(OptimizePipeline, MatchWithGeoWithinOnMetaFieldIsPushedDown) { @@ -913,7 +927,7 @@ TEST_F(OptimizePipeline, pipeline->optimizePipeline(); auto serialized = pipeline->serializeToBson(); - ASSERT_EQ(serialized.size(), 3U); + ASSERT_EQ(serialized.size(), 2U); // $match with $geoIntersects on a non-metadata field is pushed down and // $_internalBucketGeoWithin is used. 
@@ -922,12 +936,12 @@ TEST_F(OptimizePipeline, "\"Polygon\" ,coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 " "] ] ]}},field: \"loc\"}}}"), serialized[0]); - ASSERT_BSONOBJ_EQ(fromjson("{$_internalUnpackBucket: {exclude: [], timeField: " - "'time', bucketMaxSpanSeconds: 3600}}"), - serialized[1]); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {loc: {$geoIntersects: {$geometry: {type: \"Polygon\", " - "coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 ] ] ]}}}}}"), - serialized[2]); + ASSERT_BSONOBJ_EQ( + fromjson("{ $_internalUnpackBucket: { exclude: [], timeField: \"time\", " + "bucketMaxSpanSeconds: 3600, " + "eventFilter: { loc: { $geoIntersects: { $geometry: { type: \"Polygon\", " + "coordinates: [ [ [ 0, 0 ], [ 3, 6 ], [ 6, 1 ], [ 0, 0 ] ] ] } } } } } }"), + serialized[1]); } TEST_F(OptimizePipeline, MatchWithGeoIntersectsOnMetaFieldIsPushedDown) { diff --git a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/split_match_on_meta_and_rename_test.cpp b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/split_match_on_meta_and_rename_test.cpp index ba4f31adf17..4ce5d558ac4 100644 --- a/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/split_match_on_meta_and_rename_test.cpp +++ b/src/mongo/db/pipeline/document_source_internal_unpack_bucket_test/split_match_on_meta_and_rename_test.cpp @@ -56,7 +56,7 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, OptimizeSplitsMatchAndMaps // predicate on 'control.min.a'. These two created $match stages should be added before // $_internalUnpackBucket and merged. auto serialized = pipeline->serializeToBson(); - ASSERT_EQ(3u, serialized.size()); + ASSERT_EQ(2u, serialized.size()); ASSERT_BSONOBJ_EQ(fromjson("{$match: {$and: [" " {meta: {$gte: 0}}," " {meta: {$lte: 5}}," @@ -68,8 +68,13 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, OptimizeSplitsMatchAndMaps " ]}" "]}}"), serialized[0]); - ASSERT_BSONOBJ_EQ(unpack, serialized[1]); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {a: {$lte: 4}}}"), serialized[2]); + ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { " + "exclude: [], " + "timeField: \"foo\", " + "metaField: \"myMeta\", " + "bucketMaxSpanSeconds: 3600, " + "eventFilter: { a: { $lte: 4 } } } }"), + serialized[1]); } TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, OptimizeMovesMetaMatchBeforeUnpack) { @@ -94,10 +99,6 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, auto unpack = fromjson( "{$_internalUnpackBucket: { exclude: [], timeField: 'foo', metaField: 'myMeta', " "bucketMaxSpanSeconds: 3600}}"); - auto unpackExcluded = fromjson( - "{$_internalUnpackBucket: { include: ['_id', 'data'], timeField: 'foo', metaField: " - "'myMeta', " - "bucketMaxSpanSeconds: 3600}}"); auto pipeline = Pipeline::parse(makeVector(unpack, fromjson("{$project: {data: 1}}"), fromjson("{$match: {myMeta: {$gte: 0}}}")), @@ -108,9 +109,11 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, // The $match on meta is not moved before $_internalUnpackBucket since the field is excluded. 
auto serialized = pipeline->serializeToBson(); - ASSERT_EQ(2u, serialized.size()); - ASSERT_BSONOBJ_EQ(unpackExcluded, serialized[0]); - ASSERT_BSONOBJ_EQ(fromjson("{$match: {myMeta: {$gte: 0}}}"), serialized[1]); + ASSERT_EQ(1u, serialized.size()); + ASSERT_BSONOBJ_EQ(fromjson("{ $_internalUnpackBucket: { include: [ \"_id\", \"data\" ], " + "timeField: \"foo\", metaField: \"myMeta\", bucketMaxSpanSeconds: " + "3600, eventFilter: { myMeta: { $gte: 0 } } } }"), + serialized[0]); } TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, @@ -134,7 +137,7 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, // We should fail to split the match because of the $or clause. We should still be able to // map the predicate on 'x' to a predicate on the control field. auto serialized = pipeline->serializeToBson(); - ASSERT_EQ(3u, serialized.size()); + ASSERT_EQ(2u, serialized.size()); auto expected = fromjson( "{$match: {$and: [" // Result of pushing down {x: {$lte: 1}}. @@ -154,8 +157,13 @@ TEST_F(InternalUnpackBucketSplitMatchOnMetaAndRename, " ]}" "]}}"); ASSERT_BSONOBJ_EQ(expected, serialized[0]); - ASSERT_BSONOBJ_EQ(unpack, serialized[1]); - ASSERT_BSONOBJ_EQ(match, serialized[2]); + ASSERT_BSONOBJ_EQ( + fromjson( + "{ $_internalUnpackBucket: { " + "exclude: [], timeField: \"foo\", metaField: \"myMeta\", bucketMaxSpanSeconds: 3600, " + "eventFilter: { $and: [ { x: { $lte: 1 } }, { $or: [ { \"myMeta.a\": { $gt: 1 } }, { " + "y: { $lt: 1 } } ] } ] } } }"), + serialized[1]); } } // namespace } // namespace mongo diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index e87a39f85ac..080faffe5c3 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -2414,7 +2414,7 @@ intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::createVarFromString( ExpressionFieldPath::ExpressionFieldPath(ExpressionContext* const expCtx, const string& theFieldPath, Variables::Id variable) - : Expression(expCtx), _fieldPath(theFieldPath), _variable(variable) { + : Expression(expCtx), _fieldPath(theFieldPath, true /*precomputeHashes*/), _variable(variable) { const auto varName = theFieldPath.substr(0, theFieldPath.find('.')); tassert(5943201, std::string{"Variable with $$ROOT's id is not $$CURRENT or $$ROOT as expected, " diff --git a/src/mongo/db/pipeline/field_path.cpp b/src/mongo/db/pipeline/field_path.cpp index c860e863512..bdbbbabfee0 100644 --- a/src/mongo/db/pipeline/field_path.cpp +++ b/src/mongo/db/pipeline/field_path.cpp @@ -69,10 +69,8 @@ string FieldPath::getFullyQualifiedPath(StringData prefix, StringData suffix) { return str::stream() << prefix << "." << suffix; } -FieldPath::FieldPath(std::string inputPath) - : _fieldPath(std::move(inputPath)), - _fieldPathDotPosition{string::npos}, - _fieldHash{kHashUninitialized} { +FieldPath::FieldPath(std::string inputPath, bool precomputeHashes) + : _fieldPath(std::move(inputPath)), _fieldPathDotPosition{string::npos} { uassert(40352, "FieldPath cannot be constructed with empty string", !_fieldPath.empty()); uassert(40353, "FieldPath must not end with a '.'.", _fieldPath[_fieldPath.size() - 1] != '.'); @@ -81,19 +79,21 @@ FieldPath::FieldPath(std::string inputPath) size_t startPos = 0; while (string::npos != (dotPos = _fieldPath.find('.', startPos))) { _fieldPathDotPosition.push_back(dotPos); - _fieldHash.push_back(kHashUninitialized); startPos = dotPos + 1; } _fieldPathDotPosition.push_back(_fieldPath.size()); - // Validate the path length and the fields. 
+ // Validate the path length and the fields, and precompute their hashes if requested. const auto pathLength = getPathLength(); uassert(ErrorCodes::Overflow, "FieldPath is too long", pathLength <= BSONDepth::getMaxAllowableDepth()); + _fieldHash.reserve(pathLength); for (size_t i = 0; i < pathLength; ++i) { - uassertValidFieldName(getFieldName(i)); + const auto& fieldName = getFieldName(i); + uassertValidFieldName(fieldName); + _fieldHash.push_back(precomputeHashes ? FieldNameHasher()(fieldName) : kHashUninitialized); } } diff --git a/src/mongo/db/pipeline/field_path.h b/src/mongo/db/pipeline/field_path.h index d2ee93734e7..c216bdbc0a4 100644 --- a/src/mongo/db/pipeline/field_path.h +++ b/src/mongo/db/pipeline/field_path.h @@ -69,9 +69,11 @@ public: * * Field names are validated using uassertValidFieldName(). */ - /* implicit */ FieldPath(std::string inputPath); - /* implicit */ FieldPath(StringData inputPath) : FieldPath(inputPath.toString()) {} - /* implicit */ FieldPath(const char* inputPath) : FieldPath(std::string(inputPath)) {} + /* implicit */ FieldPath(std::string inputPath, bool precomputeHashes = false); + /* implicit */ FieldPath(StringData inputPath, bool precomputeHashes = false) + : FieldPath(inputPath.toString(), precomputeHashes) {} + /* implicit */ FieldPath(const char* inputPath, bool precomputeHashes = false) + : FieldPath(std::string(inputPath), precomputeHashes) {} /** * Returns the number of path elements in the field path. @@ -117,13 +119,8 @@ public: */ HashedFieldName getFieldNameHashed(size_t i) const { dassert(i < getPathLength()); - const auto begin = _fieldPathDotPosition[i] + 1; - const auto end = _fieldPathDotPosition[i + 1]; - StringData fieldName{&_fieldPath[begin], end - begin}; - if (_fieldHash[i] == kHashUninitialized) { - _fieldHash[i] = FieldNameHasher()(fieldName); - } - return HashedFieldName{fieldName, _fieldHash[i]}; + invariant(_fieldHash[i] != kHashUninitialized); + return HashedFieldName{getFieldName(i), _fieldHash[i]}; } /** @@ -177,9 +174,9 @@ private: // lookup. std::vector<size_t> _fieldPathDotPosition; - // Contains the cached hash value for the field. Will initially be set to 'kHashUninitialized', - // and only generated when it is first retrieved via 'getFieldNameHashed'. - mutable std::vector<size_t> _fieldHash; + // Contains the hash value for the field names if it was requested when creating this path. + // Otherwise all elements are set to 'kHashUninitialized'. + std::vector<size_t> _fieldHash; static constexpr std::size_t kHashUninitialized = std::numeric_limits<std::size_t>::max(); }; diff --git a/src/mongo/db/pipeline/pipeline_d.cpp b/src/mongo/db/pipeline/pipeline_d.cpp index ac2173ae50c..ec84917b5c5 100644 --- a/src/mongo/db/pipeline/pipeline_d.cpp +++ b/src/mongo/db/pipeline/pipeline_d.cpp @@ -107,7 +107,7 @@ using write_ops::InsertCommandRequest; namespace { /** - * Extracts a prefix of 'DocumentSourceGroup' and 'DocumentSourceLookUp' stages from the given + * Finds a prefix of 'DocumentSourceGroup' and 'DocumentSourceLookUp' stages from the given * pipeline to prepare for pushdown of $group and $lookup into the inner query layer so that it * can be executed using SBE. * Group stages are extracted from the pipeline when all of the following conditions are met: @@ -121,11 +121,11 @@ namespace { * - The $lookup uses only the 'localField'/'foreignField' syntax (no pipelines). * - The foreign collection is neither sharded nor a view. 
*/ -std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleStagesForPushdown( +std::vector<std::unique_ptr<InnerPipelineStageInterface>> findSbeCompatibleStagesForPushdown( const intrusive_ptr<ExpressionContext>& expCtx, const MultipleCollectionAccessor& collections, const CanonicalQuery* cq, - Pipeline* pipeline) { + const Pipeline* pipeline) { // We will eventually use the extracted group stages to populate 'CanonicalQuery::pipeline' // which requires stages to be wrapped in an interface. std::vector<std::unique_ptr<InnerPipelineStageInterface>> stagesForPushdown; @@ -140,7 +140,7 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleSt return {}; } - auto&& sources = pipeline->getSources(); + const auto& sources = pipeline->getSources(); bool isMainCollectionSharded = false; if (const auto& mainColl = collections.getMainCollection()) { @@ -158,7 +158,7 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleSt internalQuerySlotBasedExecutionDisableLookupPushdown.load() || isMainCollectionSharded || collections.isAnySecondaryNamespaceAViewOrSharded(); - for (auto itr = sources.begin(); itr != sources.end();) { + for (auto itr = sources.begin(); itr != sources.end(); ++itr) { const bool isLastSource = itr->get() == sources.back().get(); // $group pushdown logic. @@ -170,7 +170,6 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleSt if (groupStage->sbeCompatible() && !groupStage->doingMerge()) { stagesForPushdown.push_back( std::make_unique<InnerPipelineStageImpl>(groupStage, isLastSource)); - sources.erase(itr++); continue; } break; @@ -187,7 +186,6 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleSt if (lookupStage->sbeCompatible()) { stagesForPushdown.push_back( std::make_unique<InnerPipelineStageImpl>(lookupStage, isLastSource)); - sources.erase(itr++); continue; } break; @@ -199,6 +197,21 @@ std::vector<std::unique_ptr<InnerPipelineStageInterface>> extractSbeCompatibleSt return stagesForPushdown; } +/** + * Removes the first 'stagesToRemove' stages from the pipeline. This function is meant to be paired + * with a call to findSbeCompatibleStagesForPushdown() - the caller must first get the stages to + * push down, then remove them. + */ +void trimPipelineStages(Pipeline* pipeline, size_t stagesToRemove) { + auto& sources = pipeline->getSources(); + tassert(7087104, + "stagesToRemove must be <= number of pipeline sources", + stagesToRemove <= sources.size()); + for (size_t i = 0; i < stagesToRemove; ++i) { + sources.erase(sources.begin()); + } +} + StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> attemptToGetExecutor( const intrusive_ptr<ExpressionContext>& expCtx, const MultipleCollectionAccessor& collections, @@ -296,9 +309,15 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> attemptToGetExe return getExecutorFind(expCtx->opCtx, collections, std::move(cq.getValue()), - [&](auto* canonicalQuery) { - canonicalQuery->setPipeline(extractSbeCompatibleStagesForPushdown( - expCtx, collections, canonicalQuery, pipeline)); + [&](auto* canonicalQuery, bool attachOnly) { + if (attachOnly) { + canonicalQuery->setPipeline(findSbeCompatibleStagesForPushdown( + expCtx, collections, canonicalQuery, pipeline)); + } else { + // Not attaching - we need to trim the already pushed down + // pipeline stages from the pipeline. 
+ trimPipelineStages(pipeline, canonicalQuery->pipeline().size()); + } }, permitYield, plannerOpts); diff --git a/src/mongo/db/pipeline/process_interface/common_mongod_process_interface.cpp b/src/mongo/db/pipeline/process_interface/common_mongod_process_interface.cpp index 0731249fa29..e8e109aee8f 100644 --- a/src/mongo/db/pipeline/process_interface/common_mongod_process_interface.cpp +++ b/src/mongo/db/pipeline/process_interface/common_mongod_process_interface.cpp @@ -422,9 +422,8 @@ CommonMongodProcessInterface::attachCursorSourceToPipelineForLocalRead(Pipeline* } boost::optional<AutoGetCollectionForReadCommandMaybeLockFree> autoColl; - const NamespaceStringOrUUID nsOrUUID = expCtx->uuid - ? NamespaceStringOrUUID{expCtx->ns.db().toString(), *expCtx->uuid} - : expCtx->ns; + const NamespaceStringOrUUID nsOrUUID = + expCtx->uuid ? NamespaceStringOrUUID{expCtx->ns.dbName(), *expCtx->uuid} : expCtx->ns; // Reparse 'pipeline' to discover whether there are secondary namespaces that we need to lock // when constructing our query executor. diff --git a/src/mongo/db/prepare_conflict_tracker.cpp b/src/mongo/db/prepare_conflict_tracker.cpp index e8e031a1405..ac9c1b495fc 100644 --- a/src/mongo/db/prepare_conflict_tracker.cpp +++ b/src/mongo/db/prepare_conflict_tracker.cpp @@ -52,7 +52,7 @@ void PrepareConflictTracker::endPrepareConflict(OperationContext* opCtx) { auto curTick = tickSource->getTicks(); auto curConflictDuration = - tickSource->spanTo<Microseconds>(_prepareConflictStartTime, curTick); + tickSource->ticksTo<Microseconds>(curTick - _prepareConflictStartTime); _prepareConflictDuration.store(_prepareConflictDuration.load() + curConflictDuration); _prepareConflictStartTime = 0; diff --git a/src/mongo/db/query/canonical_query.h b/src/mongo/db/query/canonical_query.h index e939c248a9d..fe0599a1b6a 100644 --- a/src/mongo/db/query/canonical_query.h +++ b/src/mongo/db/query/canonical_query.h @@ -233,6 +233,14 @@ public: return _sbeCompatible; } + void setUseCqfIfEligible(bool useCqfIfEligible) { + _useCqfIfEligible = useCqfIfEligible; + } + + bool useCqfIfEligible() const { + return _useCqfIfEligible; + } + bool isParameterized() const { return !_inputParamIdToExpressionMap.empty(); } @@ -318,6 +326,16 @@ private: // True if this query can be executed by the SBE. bool _sbeCompatible = false; + // If true, indicates that we should use CQF if this query is eligible (see the + // isEligibleForBonsai() function for eligiblitly requirements). + // If false, indicates that we shouldn't use CQF even if this query is eligible. This is used to + // prevent hybrid classic and CQF plans in the following cases: + // 1. A pipeline that is not eligible for CQF but has an eligible prefix pushed down to find. + // 2. A subpipeline pushed down to find as part of a $lookup or $graphLookup. + // The default value of false ensures that only codepaths (find command) which opt-in are able + // to use CQF. + bool _useCqfIfEligible = false; + // True if this query must produce a RecordId output in addition to the BSON objects that // constitute the result set of the query. Any generated query solution must not discard record // ids, even if the optimizer detects that they are not going to be consumed downstream. 
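The new _useCqfIfEligible flag makes the CQF (Bonsai) path strictly opt-in: isEligibleForBonsai() now returns false unless the query both passes its existing checks and has been explicitly opted in, which is what prevents hybrid classic/CQF plans for pushed-down pipeline prefixes and $lookup subpipelines. Below is a minimal sketch of that gate under simplified assumptions; CanonicalQueryStub, commandOptionsEligible, and isEligibleForBonsaiSketch are illustrative names, while setUseCqfIfEligible()/useCqfIfEligible() mirror the accessors added in this hunk.

// Minimal illustration of the opt-in gate; not MongoDB's actual executor path.
#include <iostream>

struct CanonicalQueryStub {                 // stand-in for mongo::CanonicalQuery
    bool commandOptionsEligible = true;     // result of the other eligibility checks
    bool _useCqfIfEligible = false;

    void setUseCqfIfEligible(bool useCqfIfEligible) {
        _useCqfIfEligible = useCqfIfEligible;
    }
    bool useCqfIfEligible() const {
        return _useCqfIfEligible;
    }
};

// Mirrors the early return added to isEligibleForBonsai(): both the command-level
// checks and the explicit opt-in must pass before CQF is considered.
bool isEligibleForBonsaiSketch(const CanonicalQueryStub& cq) {
    return cq.commandOptionsEligible && cq.useCqfIfEligible();
}

int main() {
    CanonicalQueryStub findQuery;
    std::cout << isEligibleForBonsaiSketch(findQuery) << "\n";  // prints 0: opt-out by default
    findQuery.setUseCqfIfEligible(true);
    std::cout << isEligibleForBonsaiSketch(findQuery) << "\n";  // prints 1
    return 0;
}

Per the comment in the header above, only codepaths that opt in (the find command) call setUseCqfIfEligible(true); everything else keeps the default and stays on the classic/SBE path.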
diff --git a/src/mongo/db/query/ce/ce_heuristic_test.cpp b/src/mongo/db/query/ce/ce_heuristic_test.cpp index d1549a88daa..1d5cec6b8fc 100644 --- a/src/mongo/db/query/ce/ce_heuristic_test.cpp +++ b/src/mongo/db/query/ce/ce_heuristic_test.cpp @@ -31,7 +31,6 @@ #include "mongo/db/query/ce/ce_test_utils.h" #include "mongo/db/query/optimizer/cascades/ce_heuristic.h" -#include "mongo/db/query/optimizer/cascades/cost_derivation.h" #include "mongo/db/query/optimizer/cascades/logical_props_derivation.h" #include "mongo/db/query/optimizer/cascades/memo.h" #include "mongo/db/query/optimizer/defs.h" diff --git a/src/mongo/db/query/ce/ce_test_utils.cpp b/src/mongo/db/query/ce/ce_test_utils.cpp index 349baadd549..32d881e308b 100644 --- a/src/mongo/db/query/ce/ce_test_utils.cpp +++ b/src/mongo/db/query/ce/ce_test_utils.cpp @@ -32,12 +32,12 @@ #include "mongo/db/query/ce/ce_test_utils.h" #include "mongo/db/pipeline/abt/utils.h" -#include "mongo/db/query/optimizer/cascades/cost_derivation.h" #include "mongo/db/query/optimizer/explain.h" #include "mongo/db/query/optimizer/metadata_factory.h" #include "mongo/db/query/optimizer/opt_phase_manager.h" #include "mongo/db/query/optimizer/rewrites/const_eval.h" #include "mongo/db/query/optimizer/utils/unit_test_pipeline_utils.h" +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" #include "mongo/db/query/sbe_stage_builder_helpers.h" #include "mongo/unittest/unittest.h" @@ -81,7 +81,7 @@ optimizer::CEType CETester::getCE(ABT& abt, std::function<bool(const ABT&)> node false /*requireRID*/, _metadata, getCETransport(), - std::make_unique<DefaultCosting>(), + makeCosting(), defaultConvertPathToInterval, ConstEval::constFold, DebugInfo::kDefaultForTests, diff --git a/src/mongo/db/query/count_command.idl b/src/mongo/db/query/count_command.idl index d511cd879fb..0ddb5887b0d 100644 --- a/src/mongo/db/query/count_command.idl +++ b/src/mongo/db/query/count_command.idl @@ -130,3 +130,8 @@ commands: description: "Indicates whether the operation is a mirrored read" type: optionalBool stability: unstable + sampleId: + description: "The unique sample id for the operation if it has been chosen for sampling." + type: uuid + optional: true + stability: unstable diff --git a/src/mongo/db/query/cqf_command_utils.cpp b/src/mongo/db/query/cqf_command_utils.cpp index 8863a22adde..98e3ec34a13 100644 --- a/src/mongo/db/query/cqf_command_utils.cpp +++ b/src/mongo/db/query/cqf_command_utils.cpp @@ -737,7 +737,7 @@ bool isEligibleForBonsai(const CanonicalQuery& cq, !request.getReadOnce() && !request.getShowRecordId() && !request.getTerm(); // Early return to avoid unnecessary work of walking the input expression. 
- if (!commandOptionsEligible) { + if (!commandOptionsEligible || !cq.useCqfIfEligible()) { return false; } diff --git a/src/mongo/db/query/cqf_get_executor.cpp b/src/mongo/db/query/cqf_get_executor.cpp index 3508553b591..cc1c42bc3af 100644 --- a/src/mongo/db/query/cqf_get_executor.cpp +++ b/src/mongo/db/query/cqf_get_executor.cpp @@ -56,9 +56,12 @@ #include "mongo/db/query/yield_policy_callbacks_impl.h" #include "mongo/logv2/log.h" #include "mongo/logv2/log_attr.h" +#include "mongo/util/fail_point.h" #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery +MONGO_FAIL_POINT_DEFINE(failConstructingBonsaiExecutor); + namespace mongo { using namespace optimizer; using cost_model::CostEstimator; @@ -638,6 +641,9 @@ std::unique_ptr<PlanExecutor, PlanExecutor::Deleter> getSBEExecutorViaCascadesOp const boost::optional<BSONObj>& indexHint, std::unique_ptr<Pipeline, PipelineDeleter> pipeline, std::unique_ptr<CanonicalQuery> canonicalQuery) { + if (MONGO_unlikely(failConstructingBonsaiExecutor.shouldFail())) { + uasserted(620340, "attempting to use CQF while it is disabled"); + } // Ensure that either pipeline or canonicalQuery is set. tassert(624070, "getSBEExecutorViaCascadesOptimizer expects exactly one of the following to be set: " diff --git a/src/mongo/db/query/distinct_command.idl b/src/mongo/db/query/distinct_command.idl index a632ef06ffa..7dbb2624125 100644 --- a/src/mongo/db/query/distinct_command.idl +++ b/src/mongo/db/query/distinct_command.idl @@ -45,7 +45,7 @@ commands: description: "The field path for which to return distinct values." type: string query: - description: "Optional query that filters the documents from which to retrieve the + description: "Optional query that filters the documents from which to retrieve the distinct values." type: object optional: true @@ -56,3 +56,7 @@ commands: mirrored: description: "Indicates whether the operation is a mirrored read" type: optionalBool + sampleId: + description: "The unique sample id for the operation if it has been chosen for sampling." + type: uuid + optional: true diff --git a/src/mongo/db/query/find_command.idl b/src/mongo/db/query/find_command.idl index 44e7b02ceb9..c7a157dc810 100644 --- a/src/mongo/db/query/find_command.idl +++ b/src/mongo/db/query/find_command.idl @@ -243,4 +243,9 @@ commands: description: "Indicates whether the operation is a mirrored read" type: optionalBool stability: unstable + sampleId: + description: "The unique sample id for the operation if it has been chosen for sampling." + type: uuid + optional: true + stability: unstable diff --git a/src/mongo/db/query/fle/encrypted_predicate.h b/src/mongo/db/query/fle/encrypted_predicate.h index 5fcac966dbf..fc78aa28014 100644 --- a/src/mongo/db/query/fle/encrypted_predicate.h +++ b/src/mongo/db/query/fle/encrypted_predicate.h @@ -40,6 +40,7 @@ #include "mongo/db/matcher/expression_leaf.h" #include "mongo/db/pipeline/expression.h" #include "mongo/db/query/fle/query_rewriter_interface.h" +#include "mongo/stdx/unordered_map.h" /** * This file contains an abstract class that describes rewrites on agg Expressions and @@ -194,33 +195,37 @@ private: * are keyed on the dynamic type for the Expression subclass. 
*/ -using ExpressionToRewriteMap = stdx::unordered_map< - std::type_index, - std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>>; +using ExpressionRewriteFunction = + std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>; +using ExpressionToRewriteMap = + stdx::unordered_map<std::type_index, std::vector<ExpressionRewriteFunction>>; extern ExpressionToRewriteMap aggPredicateRewriteMap; -using MatchTypeToRewriteMap = stdx::unordered_map< - MatchExpression::MatchType, - std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>>; +using MatchRewriteFunction = + std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>; +using MatchTypeToRewriteMap = + stdx::unordered_map<MatchExpression::MatchType, std::vector<MatchRewriteFunction>>; extern MatchTypeToRewriteMap matchPredicateRewriteMap; /** * Register an agg rewrite if a condition is true at startup time. */ -#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \ - MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className)(InitializerContext*) { \ - \ - invariant(aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()); \ - aggPredicateRewriteMap[typeid(className)] = [&](auto* rewriter, auto* expr) { \ - if (isEnabledExpr) { \ - return rewriteClass{rewriter}.rewrite(expr); \ - } else { \ - return std::unique_ptr<Expression>(nullptr); \ - } \ - }; \ - } +#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \ + MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className##_##rewriteClass) \ + (InitializerContext*) { \ + if (aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()) { \ + aggPredicateRewriteMap[typeid(className)] = std::vector<ExpressionRewriteFunction>(); \ + } \ + aggPredicateRewriteMap[typeid(className)].push_back([](auto* rewriter, auto* expr) { \ + if (isEnabledExpr) { \ + return rewriteClass{rewriter}.rewrite(expr); \ + } else { \ + return std::unique_ptr<Expression>(nullptr); \ + } \ + }); \ + }; /** * Register an agg rewrite unconditionally. @@ -239,17 +244,21 @@ extern MatchTypeToRewriteMap matchPredicateRewriteMap; * Register a MatchExpression rewrite if a condition is true at startup time. 
*/ #define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, isEnabledExpr) \ - MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType)(InitializerContext*) { \ - \ - invariant(matchPredicateRewriteMap.find(MatchExpression::matchType) == \ - matchPredicateRewriteMap.end()); \ - matchPredicateRewriteMap[MatchExpression::matchType] = [&](auto* rewriter, auto* expr) { \ - if (isEnabledExpr) { \ - return rewriteClass{rewriter}.rewrite(expr); \ - } else { \ - return std::unique_ptr<MatchExpression>(nullptr); \ - } \ - }; \ + MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType##_##rewriteClass) \ + (InitializerContext*) { \ + if (matchPredicateRewriteMap.find(MatchExpression::matchType) == \ + matchPredicateRewriteMap.end()) { \ + matchPredicateRewriteMap[MatchExpression::matchType] = \ + std::vector<MatchRewriteFunction>(); \ + } \ + matchPredicateRewriteMap[MatchExpression::matchType].push_back( \ + [](auto* rewriter, auto* expr) { \ + if (isEnabledExpr) { \ + return rewriteClass{rewriter}.rewrite(expr); \ + } else { \ + return std::unique_ptr<MatchExpression>(nullptr); \ + } \ + }); \ }; /** * Register a MatchExpression rewrite unconditionally. diff --git a/src/mongo/db/query/fle/query_rewriter.cpp b/src/mongo/db/query/fle/query_rewriter.cpp index 441f436ec00..da4fb05ecc7 100644 --- a/src/mongo/db/query/fle/query_rewriter.cpp +++ b/src/mongo/db/query/fle/query_rewriter.cpp @@ -40,12 +40,15 @@ public: : queryRewriter(queryRewriter), exprRewrites(exprRewrites){}; std::unique_ptr<Expression> postVisit(Expression* exp) { - if (auto rewrite = exprRewrites.find(typeid(*exp)); rewrite != exprRewrites.end()) { - auto expr = rewrite->second(queryRewriter, exp); - if (expr != nullptr) { - didRewrite = true; + if (auto rewriteEntry = exprRewrites.find(typeid(*exp)); + rewriteEntry != exprRewrites.end()) { + for (auto& rewrite : rewriteEntry->second) { + auto expr = rewrite(queryRewriter, exp); + if (expr != nullptr) { + didRewrite = true; + return expr; + } } - return expr; } return nullptr; } @@ -109,13 +112,17 @@ std::unique_ptr<MatchExpression> QueryRewriter::_rewrite(MatchExpression* expr) return nullptr; } default: { - if (auto rewrite = _matchRewrites.find(expr->matchType()); - rewrite != _matchRewrites.end()) { - auto rewritten = rewrite->second(this, expr); - if (rewritten != nullptr) { - _rewroteLastExpression = true; + if (auto rewriteEntry = _matchRewrites.find(expr->matchType()); + rewriteEntry != _matchRewrites.end()) { + for (auto& rewrite : rewriteEntry->second) { + auto rewritten = rewrite(this, expr); + // Only one rewrite can be applied to an expression, so return as soon as a + // rewrite returns something other than nullptr. 
+ if (rewritten != nullptr) { + _rewroteLastExpression = true; + return rewritten; + } } - return rewritten; } return nullptr; } diff --git a/src/mongo/db/query/fle/query_rewriter_test.cpp b/src/mongo/db/query/fle/query_rewriter_test.cpp index d779f6be1cc..057bae0dd2b 100644 --- a/src/mongo/db/query/fle/query_rewriter_test.cpp +++ b/src/mongo/db/query/fle/query_rewriter_test.cpp @@ -71,7 +71,7 @@ protected: if (!elt.isABSONObj()) { return false; } - return elt.Obj().firstElementFieldNameStringData() == "encrypt"_sd; + return elt.Obj().hasField("encrypt"_sd); } bool isPayload(const Value& v) const override { if (!v.isObject()) { @@ -97,9 +97,110 @@ protected: }; std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override { + auto eqMatch = dynamic_cast<ExpressionCompare*>(expr); + invariant(eqMatch); + // Only operate over equality comparisons. + if (eqMatch->getOp() != ExpressionCompare::EQ) { + return nullptr; + } + auto payload = dynamic_cast<ExpressionConstant*>(eqMatch->getOperandList()[1].get()); + // If the comparison doesn't hold a constant, then don't rewrite. + if (!payload) { + return nullptr; + } + + // If the constant is not considered a payload, then don't rewrite. + if (!isPayload(payload->getValue())) { + return nullptr; + } + auto cmp = std::make_unique<ExpressionCompare>(eqMatch->getExpressionContext(), + ExpressionCompare::GT); + cmp->addOperand(eqMatch->getOperandList()[0]); + cmp->addOperand( + ExpressionConstant::create(eqMatch->getExpressionContext(), + payload->getValue().getDocument().getField("encrypt"))); + return cmp; + } + + std::unique_ptr<MatchExpression> rewriteToRuntimeComparison( + MatchExpression* expr) const override { + return nullptr; + } + + std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override { return nullptr; } +private: + // This method is not used in mock implementations of the EncryptedPredicate since isPayload(), + // which normally calls encryptedBinDataType(), is overridden to look for plain objects rather + // than BinData. Since this method is pure virtual on the superclass and needs to be + // implemented, it is set to kPlaceholder (0). + EncryptedBinDataType encryptedBinDataType() const override { + return EncryptedBinDataType::kPlaceholder; + } +}; + +// A second mock rewrite which replaces documents with the key "foo" into $lt operations. We need +// two different rewrites that are registered on the same operator to verify that all rewrites are +// iterated through. +class OtherMockPredicateRewriter : public fle::EncryptedPredicate { +public: + OtherMockPredicateRewriter(const fle::QueryRewriterInterface* rewriter) + : EncryptedPredicate(rewriter) {} + +protected: + bool isPayload(const BSONElement& elt) const override { + if (!elt.isABSONObj()) { + return false; + } + return elt.Obj().hasField("foo"_sd); + } + bool isPayload(const Value& v) const override { + if (!v.isObject()) { + return false; + } + return !v.getDocument().getField("foo").missing(); + } + + std::vector<PrfBlock> generateTags(fle::BSONValue payload) const override { + return {}; + }; + + // Encrypted values will be rewritten from $eq to $lt. This is an arbitrary decision just to + // make sure that the rewrite works properly. 
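Editor's note: the mock agg rewrite in the test above illustrates the guard pattern every predicate rewrite follows: confirm the node shape (an $eq comparison), confirm the right-hand operand is a constant, confirm that constant actually looks like a payload, and decline with nullptr otherwise so another registered rewrite can try. Sketched here with hypothetical stand-in node types (Expr, CmpExpr, ConstExpr), not the real ExpressionCompare API:

#include <memory>

// Hypothetical, simplified expression nodes for illustration only.
struct Expr { virtual ~Expr() = default; };
enum class CmpOp { EQ, GT, LT };
struct ConstExpr : Expr { bool looksLikePayload = false; };
struct CmpExpr : Expr {
    CmpOp op = CmpOp::EQ;
    Expr* lhs = nullptr;  // non-owning in this sketch
    Expr* rhs = nullptr;
};

std::unique_ptr<Expr> rewriteEqToGt(Expr* expr) {
    auto* cmp = dynamic_cast<CmpExpr*>(expr);
    if (!cmp || cmp->op != CmpOp::EQ) {
        return nullptr;  // wrong node shape: decline
    }
    auto* rhs = dynamic_cast<ConstExpr*>(cmp->rhs);
    if (!rhs || !rhs->looksLikePayload) {
        return nullptr;  // not a constant payload: decline, let another rewrite try
    }
    // Build the replacement node. The real code keeps the field-path operand and builds
    // a new constant from the payload's inner value; reusing rhs here is a placeholder.
    auto out = std::make_unique<CmpExpr>();
    out->op = CmpOp::GT;
    out->lhs = cmp->lhs;
    out->rhs = cmp->rhs;
    return out;
}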
+ std::unique_ptr<MatchExpression> rewriteToTagDisjunction(MatchExpression* expr) const override { + invariant(expr->matchType() == MatchExpression::EQ); + auto eqMatch = static_cast<EqualityMatchExpression*>(expr); + if (!isPayload(eqMatch->getData())) { + return nullptr; + } + return std::make_unique<LTMatchExpression>(eqMatch->path(), + eqMatch->getData().Obj().firstElement()); + }; + + std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override { + auto eqMatch = dynamic_cast<ExpressionCompare*>(expr); + invariant(eqMatch); + if (eqMatch->getOp() != ExpressionCompare::EQ) { + return nullptr; + } + auto payload = dynamic_cast<ExpressionConstant*>(eqMatch->getOperandList()[1].get()); + if (!payload) { + return nullptr; + } + + if (!isPayload(payload->getValue())) { + return nullptr; + } + auto cmp = std::make_unique<ExpressionCompare>(eqMatch->getExpressionContext(), + ExpressionCompare::LT); + cmp->addOperand(eqMatch->getOperandList()[0]); + cmp->addOperand(ExpressionConstant::create( + eqMatch->getExpressionContext(), payload->getValue().getDocument().getField("foo"))); + return cmp; + } + std::unique_ptr<MatchExpression> rewriteToRuntimeComparison( MatchExpression* expr) const override { return nullptr; @@ -111,7 +212,7 @@ protected: private: EncryptedBinDataType encryptedBinDataType() const override { - return EncryptedBinDataType::kPlaceholder; // return the 0 type. this isn't used anywhere. + return EncryptedBinDataType::kPlaceholder; } }; @@ -119,8 +220,17 @@ void setMockRewriteMaps(fle::MatchTypeToRewriteMap& match, fle::ExpressionToRewriteMap& agg, fle::TagMap& tags, std::set<StringData>& encryptedFields) { - match[MatchExpression::EQ] = [&](auto* rewriter, auto* expr) { - return MockPredicateRewriter{rewriter}.rewrite(expr); + match[MatchExpression::EQ] = { + [&](auto* rewriter, auto* expr) { return MockPredicateRewriter{rewriter}.rewrite(expr); }, + [&](auto* rewriter, auto* expr) { + return OtherMockPredicateRewriter{rewriter}.rewrite(expr); + }, + }; + agg[typeid(ExpressionCompare)] = { + [&](auto* rewriter, auto* expr) { return MockPredicateRewriter{rewriter}.rewrite(expr); }, + [&](auto* rewriter, auto* expr) { + return OtherMockPredicateRewriter{rewriter}.rewrite(expr); + }, }; } @@ -137,9 +247,17 @@ public: return res ? res.value() : obj; } + BSONObj rewriteAggExpressionForTest(const BSONObj& obj) { + auto expr = Expression::parseExpression(&_expCtx, obj, _expCtx.variablesParseState); + auto result = rewriteExpression(expr.get()); + return result ? 
result->serialize(false).getDocument().toBson() + : expr->serialize(false).getDocument().toBson(); + } + private: fle::TagMap _tags; std::set<StringData> _encryptedFields; + ExpressionContextForTest _expCtx; }; class FLEServerRewriteTest : public unittest::Test { @@ -167,22 +285,53 @@ protected: ASSERT_MATCH_EXPRESSION_REWRITE(input, expected); \ } +#define ASSERT_AGG_EXPRESSION_REWRITE(input, expected) \ + auto actual = _mock->rewriteAggExpressionForTest(fromjson(input)); \ + ASSERT_BSONOBJ_EQ(actual, fromjson(expected)); + +#define TEST_FLE_REWRITE_AGG(name, input, expected) \ + TEST_F(FLEServerRewriteTest, name##_AggExpression) { \ + ASSERT_AGG_EXPRESSION_REWRITE(input, expected); \ + } + TEST_FLE_REWRITE_MATCH(TopLevel_DottedPath, "{'user.ssn': {$eq: {encrypt: 2}}}", "{'user.ssn': {$gt: 2}}"); +TEST_FLE_REWRITE_AGG(TopLevel_DottedPath, + "{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}", + "{$gt: ['$user.ssn', {$const: 2}]}"); + TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_BothEncrypted, "{$and: [{ssn: {encrypt: 2}}, {age: {encrypt: 4}}]}", "{$and: [{ssn: {$gt: 2}}, {age: {$gt: 4}}]}"); +TEST_FLE_REWRITE_AGG( + TopLevel_Conjunction_BothEncrypted, + "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {encrypt: " + "4}}]}]}", + "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$gt: ['$age', {$const: 4}]}]}"); + TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncrypted, "{$and: [{ssn: {encrypt: 2}}, {age: {plain: 4}}]}", "{$and: [{ssn: {$gt: 2}}, {age: {$eq: {plain: 4}}}]}"); +TEST_FLE_REWRITE_AGG( + TopLevel_Conjunction_PartlyEncrypted, + "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {plain: 4}}]}]}", + "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$eq: ['$age', {$const: {plain: 4}}]}]}"); + TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncryptedWithUnregisteredOperator, "{$and: [{ssn: {encrypt: 2}}, {age: {$lt: {encrypt: 4}}}]}", "{$and: [{ssn: {$gt: 2}}, {age: {$lt: {encrypt: 4}}}]}"); +TEST_FLE_REWRITE_AGG( + TopLevel_Conjunction_PartlyEncryptedWithUnregisteredOperator, + "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$lt: ['$age', {$const: {encrypt: " + "4}}]}]}", + "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$lt: ['$age', {$const: {encrypt: " + "4}}]}]}"); + TEST_FLE_REWRITE_MATCH(TopLevel_Encrypted_Nested_Unecrypted, "{$and: [{ssn: {encrypt: 2}}, {user: {region: 'US'}}]}", "{$and: [{ssn: {$gt: 2}}, {user: {$eq: {region: 'US'}}}]}"); @@ -191,6 +340,10 @@ TEST_FLE_REWRITE_MATCH(TopLevel_Not, "{ssn: {$not: {$eq: {encrypt: 5}}}}", "{ssn: {$not: {$gt: 5}}}"); +TEST_FLE_REWRITE_AGG(TopLevel_Not, + "{$not: [{$eq: ['$ssn', {$const: {encrypt: 2}}]}]}", + "{$not: [{$gt: ['$ssn', {$const: 2}]}]}") + TEST_FLE_REWRITE_MATCH(TopLevel_Neq, "{ssn: {$ne: {encrypt: 5}}}", "{ssn: {$not: {$gt: 5}}}}"); TEST_FLE_REWRITE_MATCH( @@ -198,6 +351,12 @@ TEST_FLE_REWRITE_MATCH( "{$and: [{$and: [{ssn: {encrypt: 2}}, {other: 'field'}]}, {otherSsn: {encrypt: 3}}]}", "{$and: [{$and: [{ssn: {$gt: 2}}, {other: {$eq: 'field'}}]}, {otherSsn: {$gt: 3}}]}"); +TEST_FLE_REWRITE_AGG(NestedConjunction, + "{$and: [{$and: [{$eq: ['$ssn', {$const: {encrypt: 2}}]},{$eq: ['$other', " + "'field']}]},{$eq: ['$age',{$const: {encrypt: 4}}]}]}", + "{$and: [{$and: [{$gt: ['$ssn', {$const: 2}]},{$eq: ['$other', " + "{$const: 'field'}]}]},{$gt: ['$age',{$const: 4}]}]}"); + TEST_FLE_REWRITE_MATCH(TopLevel_Nor, "{$nor: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}", "{$nor: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}"); @@ -205,5 +364,30 @@ 
TEST_FLE_REWRITE_MATCH(TopLevel_Nor, TEST_FLE_REWRITE_MATCH(TopLevel_Or, "{$or: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}", "{$or: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}"); + +TEST_FLE_REWRITE_AGG( + TopLevel_Or, + "{$or: [{$eq: ['$ssn', {$const: {encrypt: 2}}]}, {$eq: ['$ssn', {$const: {encrypt: 4}}]}]}", + "{$or: [{$gt: ['$ssn', {$const: 2}]}, {$gt: ['$ssn', {$const: 4}]}]}") + + +// Test that the rewriter will work from any rewrite registered to an expression. The test rewriter +// has two rewrites registered on $eq. + +TEST_FLE_REWRITE_MATCH(OtherRewrite_Basic, "{'ssn': {$eq: {foo: 2}}}", "{'ssn': {$lt: 2}}"); + +TEST_FLE_REWRITE_AGG(OtherRewrite_Basic, + "{$eq: ['$user.ssn', {$const: {foo: 2}}]}", + "{$lt: ['$user.ssn', {$const: 2}]}"); + +TEST_FLE_REWRITE_MATCH(OtherRewrite_Conjunction_BothEncrypted, + "{$and: [{ssn: {encrypt: 2}}, {age: {foo: 4}}]}", + "{$and: [{ssn: {$gt: 2}}, {age: {$lt: 4}}]}"); + +TEST_FLE_REWRITE_AGG( + OtherRewrite_Conjunction_BothEncrypted, + "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {foo: " + "4}}]}]}", + "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$lt: ['$age', {$const: 4}]}]}"); } // namespace } // namespace mongo diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp index 8c083b07dc1..dcd71777782 100644 --- a/src/mongo/db/query/fle/range_predicate.cpp +++ b/src/mongo/db/query/fle/range_predicate.cpp @@ -55,6 +55,77 @@ REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween, RangePredicate, gFeatureFlagFLE2Range); +REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionCompare, + RangePredicate, + gFeatureFlagFLE2Range); + +namespace { +// Validate the range operator passed in and return the fieldpath and payload for the rewrite. If +// the passed-in expression is a comparison with $eq, $ne, or $cmp, none of which represent a range +// predicate, then return null to the caller so that the rewrite can return null. +std::pair<boost::intrusive_ptr<Expression>, Value> validateRangeOp(Expression* expr) { + auto children = [&]() { + if (auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr)) { + return betweenExpr->getChildren(); + } else { + auto cmpExpr = dynamic_cast<ExpressionCompare*>(expr); + tassert(6720901, + "Range rewrite should only be called with $between or comparison operator.", + cmpExpr); + switch (cmpExpr->getOp()) { + case ExpressionCompare::GT: + case ExpressionCompare::GTE: + case ExpressionCompare::LT: + case ExpressionCompare::LTE: + return cmpExpr->getChildren(); + + case ExpressionCompare::EQ: + case ExpressionCompare::NE: + case ExpressionCompare::CMP: + return std::vector<boost::intrusive_ptr<Expression>>(); + } + } + return std::vector<boost::intrusive_ptr<Expression>>(); + }(); + if (children.empty()) { + return {nullptr, Value()}; + } + // Both ExpressionBetween and ExpressionCompare have a fixed arity of 2. 
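Editor's note: the range_predicate.cpp change that follows replaces validateBetween() with validateRangeOp(), which accepts either $between or one of the ordered comparison operators and declines $eq/$ne/$cmp (which share the ExpressionCompare node type but are not range predicates) by returning an empty result instead of asserting. A tiny sketch of that classification, with a hypothetical operator tag rather than the real ExpressionCompare::CmpOp:

// Hypothetical simplified operator tag for illustration.
enum class RangeOp { BETWEEN, GT, GTE, LT, LTE, EQ, NE, CMP };

// True only for operators that can encode an encrypted range predicate; non-range
// comparisons are declined so the caller returns nullptr rather than tripping a tassert.
bool isRewritableRangeOp(RangeOp op) {
    switch (op) {
        case RangeOp::BETWEEN:
        case RangeOp::GT:
        case RangeOp::GTE:
        case RangeOp::LT:
        case RangeOp::LTE:
            return true;
        case RangeOp::EQ:
        case RangeOp::NE:
        case RangeOp::CMP:
            return false;
    }
    return false;
}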
+ auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get()); + uassert(6720903, "first argument should be a fieldpath", fieldpath); + auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get()); + uassert(6720904, "second argument should be a constant", secondArg); + auto payload = secondArg->getValue(); + return {children[0], payload}; +} +} // namespace + +std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( + StringData path, ParsedFindRangePayload payload) const { + auto* expCtx = _rewriter->getExpressionContext(); + return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + payload); +} + +std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( + boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const { + tassert(7030501, + "$internalFleBetween can only be generated from a non-stub payload.", + !payload.isStub()); + auto cm = payload.maxCounter; + ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken); + std::vector<ConstDataRange> edcTokens; + std::transform(std::make_move_iterator(payload.edges.value().begin()), + std::make_move_iterator(payload.edges.value().end()), + std::back_inserter(edcTokens), + [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); }); + + auto* expCtx = _rewriter->getExpressionContext(); + return std::make_unique<ExpressionInternalFLEBetween>( + expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens)); +} + std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const { auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload); std::vector<PrfBlock> tags; @@ -99,56 +170,23 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction( return makeTagDisjunction(toBSONArray(generateTags(payload))); } -std::pair<boost::intrusive_ptr<Expression>, Value> validateBetween(Expression* expr) { - auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr); - tassert(6720901, "Range rewrite should only be called with $between operator.", betweenExpr); - auto children = betweenExpr->getChildren(); - uassert(6720902, "$between should have two children.", children.size() == 2); - - auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get()); - uassert(6720903, "first argument should be a fieldpath", fieldpath); - auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get()); - uassert(6720904, "second argument should be a constant", secondArg); - auto payload = secondArg->getValue(); - return {children[0], payload}; -} - std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const { - auto [_, payload] = validateBetween(expr); + auto [fieldpath, payload] = validateRangeOp(expr); + if (!fieldpath) { + return nullptr; + } if (!isPayload(payload)) { return nullptr; } + if (isStub(std::ref(payload))) { + return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true)); + } + auto tags = toValues(generateTags(std::ref(payload))); return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags)); } -std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( - StringData path, ParsedFindRangePayload payload) const { - auto* expCtx = _rewriter->getExpressionContext(); - return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString( - expCtx, path.toString(), expCtx->variablesParseState), - payload); -} - 
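Editor's note: the rewriteToTagDisjunction() change above also short-circuits stub payloads: a stub carries no edges and constrains nothing, so the predicate is folded to a constant true expression instead of an empty tag disjunction. The shape of that logic, with minimal hypothetical stand-ins (Node, Payload) rather than the real Expression/ParsedFindRangePayload types:

#include <memory>
#include <vector>

// Hypothetical, minimal stand-ins for this sketch only.
struct Node { virtual ~Node() = default; };
struct ConstTrue : Node {};
struct TagDisjunction : Node { std::vector<int> tags; };

struct Payload {
    bool valid = true;      // corresponds to isPayload()
    bool stub = false;      // corresponds to ParsedFindRangePayload::isStub()
    std::vector<int> tags;  // corresponds to generateTags(payload)
};

std::unique_ptr<Node> rewriteRangePayload(const Payload& p) {
    if (!p.valid) {
        return nullptr;  // not an encrypted payload: leave the predicate alone
    }
    if (p.stub) {
        // Stub payloads match trivially; fold to constant true so the rest of the
        // plan is unaffected.
        return std::make_unique<ConstTrue>();
    }
    auto out = std::make_unique<TagDisjunction>();
    out->tags = p.tags;
    return out;
}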
-std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( - boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const { - tassert(7030501, - "$internalFleBetween can only be generated from a non-stub payload.", - !payload.isStub()); - auto cm = payload.maxCounter; - ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken); - std::vector<ConstDataRange> edcTokens; - std::transform(std::make_move_iterator(payload.edges.value().begin()), - std::make_move_iterator(payload.edges.value().end()), - std::back_inserter(edcTokens), - [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); }); - - auto* expCtx = _rewriter->getExpressionContext(); - return std::make_unique<ExpressionInternalFLEBetween>( - expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens)); -} - std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison( MatchExpression* expr) const { BSONElement ffp; @@ -179,11 +217,17 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison( } std::unique_ptr<Expression> RangePredicate::rewriteToRuntimeComparison(Expression* expr) const { - auto [fieldpath, ffp] = validateBetween(expr); + auto [fieldpath, ffp] = validateRangeOp(expr); + if (!fieldpath) { + return nullptr; + } if (!isPayload(ffp)) { return nullptr; } auto payload = parseFindPayload<ParsedFindRangePayload>(ffp); + if (payload.isStub()) { + return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true)); + } return fleBetweenFromPayload(fieldpath, payload); } } // namespace mongo::fle diff --git a/src/mongo/db/query/fle/range_predicate.h b/src/mongo/db/query/fle/range_predicate.h index bc2a47fa173..5fbd7484ad0 100644 --- a/src/mongo/db/query/fle/range_predicate.h +++ b/src/mongo/db/query/fle/range_predicate.h @@ -55,6 +55,12 @@ protected: return parsedPayload.isStub(); } + virtual bool isStub(Value elt) const { + auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(elt); + return parsedPayload.isStub(); + } + + private: EncryptedBinDataType encryptedBinDataType() const override { return EncryptedBinDataType::kFLE2FindRangePayload; diff --git a/src/mongo/db/query/fle/range_predicate_test.cpp b/src/mongo/db/query/fle/range_predicate_test.cpp index f81723da5eb..61110c19041 100644 --- a/src/mongo/db/query/fle/range_predicate_test.cpp +++ b/src/mongo/db/query/fle/range_predicate_test.cpp @@ -68,7 +68,11 @@ protected: return isStubPayload; } - std::vector<PrfBlock> generateTags(BSONValue payload) const { + bool isStub(Value elt) const override { + return isStubPayload; + } + + std::vector<PrfBlock> generateTags(BSONValue payload) const override { return stdx::visit( OverloadedVisitor{[&](BSONElement p) { if (p.isABSONObj()) { @@ -126,8 +130,6 @@ TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_NoStub) { TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_Stub) { RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - std::vector<PrfBlock> allTags = {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}}; - auto expCtx = make_intrusive<ExpressionContextForTest>(); std::vector<StringData> operators = {"$between", "$gt", "$gte", "$lte", "$lt"}; @@ -159,29 +161,98 @@ TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_Stub) { } } +TEST_F(RangePredicateRewriteTest, AggRangeRewrite_Stub) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); + + { + auto input = fromjson(str::stream() << "{$between: [\"$age\", {$literal: [1, 2, 3]}]}"); + 
auto inputExpr = + ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = ExpressionConstant::create(&_expCtx, Value(true)); + + _predicate.isStubPayload = true; + auto actual = _predicate.rewrite(inputExpr.get()); + ASSERT(actual); + ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), + expected->serialize(false).getDocument().toBson()); + } + + auto ops = {"$gt", "$lt", "$gte", "$lte"}; + for (auto& op : ops) { + auto input = fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}"); + auto inputExpr = + ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = ExpressionConstant::create(&_expCtx, Value(true)); + + _predicate.isStubPayload = true; + auto actual = _predicate.rewrite(inputExpr.get()); + ASSERT(actual); + ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), + expected->serialize(false).getDocument().toBson()); + } +} + TEST_F(RangePredicateRewriteTest, AggRangeRewrite) { - auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})"); - auto inputExpr = - ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + { + auto op = "$between"; + auto input = fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}"); + auto inputExpr = + ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}})); - auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}})); + auto actual = _predicate.rewrite(inputExpr.get()); - auto actual = _predicate.rewrite(inputExpr.get()); + ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), + expected->serialize(false).getDocument().toBson()); + } + { + auto ops = {"$gt", "$lt", "$gte", "$lte"}; + for (auto& op : ops) { + auto input = + fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}"); + auto inputExpr = + ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}})); - ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), - expected->serialize(false).getDocument().toBson()); + auto actual = _predicate.rewrite(inputExpr.get()); + + ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), + expected->serialize(false).getDocument().toBson()); + } + } } TEST_F(RangePredicateRewriteTest, AggRangeRewriteNoOp) { - auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})"); - auto inputExpr = - ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + { + auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})"); + auto inputExpr = + ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState); - auto expected = inputExpr; + auto expected = inputExpr; - _predicate.payloadValid = false; - auto actual = _predicate.rewrite(inputExpr.get()); - ASSERT(actual == nullptr); + _predicate.payloadValid = false; + auto actual = _predicate.rewrite(inputExpr.get()); + ASSERT(actual == nullptr); + } + { + auto ops = {"$gt", "$lt", "$gte", "$lte"}; + for (auto& op : ops) { + auto input = + fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}"); + auto inputExpr = + ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = inputExpr; + + _predicate.payloadValid = false; + auto 
actual = _predicate.rewrite(inputExpr.get()); + ASSERT(actual == nullptr); + } + } } BSONObj generateFFP(StringData path, int lb, int ub, int min, int max) { @@ -219,6 +290,17 @@ std::unique_ptr<Expression> generateBetweenWithFFP(ExpressionContext* expCtx, return std::make_unique<ExpressionBetween>(expCtx, std::move(children)); } +std::unique_ptr<Expression> generateBetweenWithFFP( + ExpressionContext* expCtx, ExpressionCompare::CmpOp op, StringData path, int lb, int ub) { + auto ffp = Value(generateFFP(path, lb, ub, 0, 255).firstElement()); + auto ffpExpr = make_intrusive<ExpressionConstant>(expCtx, ffp); + auto fieldpath = ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState); + std::vector<boost::intrusive_ptr<Expression>> children = {std::move(fieldpath), + std::move(ffpExpr)}; + return std::make_unique<ExpressionCompare>(expCtx, op, std::move(children)); +} + TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) { _mock.setForceEncryptedCollScanForTest(); auto expected = fromjson(R"({ @@ -273,9 +355,6 @@ TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) { TEST_F(RangePredicateRewriteTest, CollScanRewriteAgg) { _mock.setForceEncryptedCollScanForTest(); - auto input = generateBetweenWithFFP(&_expCtx, "age", 23, 35); - auto result = _predicate.rewrite(input.get()); - ASSERT(result); auto expected = fromjson(R"({ "$_internalFleBetween": { "field": "$age", @@ -309,7 +388,40 @@ TEST_F(RangePredicateRewriteTest, CollScanRewriteAgg) { } } })"); - ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected); + { + auto input = generateBetweenWithFFP(&_expCtx, "age", 23, 35); + auto result = _predicate.rewrite(input.get()); + ASSERT(result); + ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected); + } + { + auto ops = {ExpressionCompare::GT, + ExpressionCompare::GTE, + ExpressionCompare::LT, + ExpressionCompare::LTE}; + for (auto& op : ops) { + auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35); + auto result = _predicate.rewrite(input.get()); + ASSERT(result); + ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected); + } + } +} + + +TEST_F(RangePredicateRewriteTest, UnsupportedComparisonOps) { + auto ops = {ExpressionCompare::CMP, ExpressionCompare::EQ, ExpressionCompare::NE}; + for (auto& op : ops) { + auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35); + auto result = _predicate.rewrite(input.get()); + ASSERT(result == nullptr); + } + _mock.setForceEncryptedCollScanForTest(); + for (auto& op : ops) { + auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35); + auto result = _predicate.rewrite(input.get()); + ASSERT(result == nullptr); + } } }; // namespace diff --git a/src/mongo/db/query/get_executor.cpp b/src/mongo/db/query/get_executor.cpp index 0824723b4d6..9e5d55a9d97 100644 --- a/src/mongo/db/query/get_executor.cpp +++ b/src/mongo/db/query/get_executor.cpp @@ -530,7 +530,9 @@ private: * - A vector of PlanStages, representing the roots of the constructed execution trees (in the * case when the query has multiple solutions, we may construct an execution tree for each * solution and pick the best plan after multi-planning). Elements of this vector can never be - * null. The size of this vector must always match the size of 'querySolutions' vector. + * null. The size of this vector must always be empty or match the size of 'querySolutions' + * vector. 
It will be empty in circumstances where we only construct query solutions and delay + * building execution trees, which is any time we are not using a cached plan. * - A root node of the extension plan. The plan can be combined with a solution to create a * larger plan after the winning solution is found. Can be null, meaning "no extension". * - An optional decisionWorks value, which is populated when a solution was reconstructed from @@ -546,9 +548,11 @@ public: using PlanStageVector = std::vector<std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData>>; - void emplace(std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData> root, - std::unique_ptr<QuerySolution> solution) { - _roots.push_back(std::move(root)); + void emplace(std::unique_ptr<QuerySolution> solution) { + // Only allow solutions to be added, execution trees will be generated later. + tassert(7087100, + "expected execution trees to be generated after query solutions", + _roots.empty()); _solutions.push_back(std::move(solution)); } @@ -560,11 +564,11 @@ public: std::string getPlanSummary() const { // We can report plan summary only if this result contains a single solution. - invariant(_roots.size() == 1); - invariant(_solutions.size() == 1); - invariant(_roots[0].first); + tassert(7087101, "expected exactly one solution", _solutions.size() == 1); + tassert(7087102, "expected at most one execution tree", _roots.size() <= 1); + // We only need the query solution to build the explain summary. auto explainer = plan_explainer_factory::make( - _roots[0].first.get(), &_roots[0].second, _solutions[0].get()); + nullptr /* root */, nullptr /* data */, _solutions[0].get()); return explainer->getPlanSummary(); } @@ -572,6 +576,14 @@ public: return std::make_pair(std::move(_roots), std::move(_solutions)); } + const QuerySolutionVector& solutions() const { + return _solutions; + } + + const PlanStageVector& roots() const { + return _roots; + } + boost::optional<size_t> decisionWorks() const { return _decisionWorks; } @@ -621,8 +633,14 @@ private: * normal stage builder process. * * We have a QuerySolutionNode tree (or multiple query solution trees), but must execute some * custom logic in order to build the final execution tree. + * + * In most cases, the helper bypasses the final step of building the execution tree and returns only + * the query solution(s) if the 'DeferExecutionTreeGeneration' flag is true. */ -template <typename KeyType, typename PlanStageType, typename ResultType> +template <typename KeyType, + typename PlanStageType, + typename ResultType, + bool DeferExecutionTreeGeneration> class PrepareExecutionHelper { public: PrepareExecutionHelper(OperationContext* opCtx, @@ -650,11 +668,8 @@ public: auto solution = std::make_unique<QuerySolution>(); solution->setRoot(std::make_unique<EofNode>()); - - auto root = buildExecutableTree(*solution); - auto result = makeResult(); - result->emplace(std::move(root), std::move(solution)); + addSolutionToResult(result.get(), std::move(solution)); return std::move(result); } @@ -679,12 +694,6 @@ public: } auto planCacheKey = buildPlanCacheKey(); - // Fill in some opDebug information, unless it has already been filled by an outer pipeline. 
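Editor's note: the get_executor.cpp hunks above change the prepare-result container so it can hold query solutions without their execution trees; the trees may be built later, once the engine choice is final, and the explain summary is produced from the solution alone. A rough, self-contained sketch of such a container, with ExecTree and QuerySolution as hypothetical stand-ins:

#include <cassert>
#include <memory>
#include <utility>
#include <vector>

struct QuerySolution {};  // stand-in for the real QuerySolution
struct ExecTree {};       // stand-in for the built execution tree (PlanStage + data)

// Planning result that may hold only solutions; execution trees are attached later,
// or never, if the caller builds them itself after choosing an engine.
class PrepareResult {
public:
    void emplaceSolution(std::unique_ptr<QuerySolution> solution) {
        // Mirrors the invariant: roots must be empty or stay in lockstep with solutions.
        assert(_roots.empty());
        _solutions.push_back(std::move(solution));
    }

    void emplaceBuilt(std::unique_ptr<ExecTree> root, std::unique_ptr<QuerySolution> solution) {
        _roots.push_back(std::move(root));
        _solutions.push_back(std::move(solution));
    }

    const std::vector<std::unique_ptr<QuerySolution>>& solutions() const { return _solutions; }
    std::vector<std::unique_ptr<ExecTree>>& roots() { return _roots; }

private:
    std::vector<std::unique_ptr<ExecTree>> _roots;  // empty when generation is deferred
    std::vector<std::unique_ptr<QuerySolution>> _solutions;
};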
- OpDebug& opDebug = CurOp::get(_opCtx)->debug(); - if (!opDebug.queryHash) { - opDebug.queryHash = planCacheKey.queryHash(); - } - if (auto result = buildCachedPlan(planCacheKey)) { return {std::move(result)}; } @@ -715,8 +724,7 @@ public: for (size_t i = 0; i < solutions.size(); ++i) { if (turnIxscanIntoCount(solutions[i].get())) { auto result = makeResult(); - auto root = buildExecutableTree(*solutions[i]); - result->emplace(std::move(root), std::move(solutions[i])); + addSolutionToResult(result.get(), std::move(solutions[i])); LOGV2_DEBUG(20925, 2, @@ -731,9 +739,8 @@ public: if (1 == solutions.size()) { // Only one possible plan. Build the stages from the solution. auto result = makeResult(); - auto root = buildExecutableTree(*solutions[0]); solutions[0]->indexFilterApplied = _plannerParams.indexFiltersApplied; - result->emplace(std::move(root), std::move(solutions[0])); + addSolutionToResult(result.get(), std::move(solutions[0])); LOGV2_DEBUG(20926, 2, @@ -748,6 +755,8 @@ public: } protected: + static constexpr bool ShouldDeferExecutionTreeGeneration = DeferExecutionTreeGeneration; + /** * Creates a result instance to be returned to the caller holding the result of the * prepare() call. @@ -757,6 +766,19 @@ protected: } /** + * Adds the query solution to the result object, additionally building the corresponding + * execution tree if 'DeferExecutionTreeGeneration' is turned on. + */ + void addSolutionToResult(ResultType* result, std::unique_ptr<QuerySolution> solution) { + if constexpr (!DeferExecutionTreeGeneration) { + auto root = buildExecutableTree(*solution); + result->emplace(std::move(root), std::move(solution)); + } else { + result->emplace(std::move(solution)); + } + } + + /** * Fills out planner parameters if not already filled. */ void initializePlannerParamsIfNeeded() { @@ -820,7 +842,8 @@ protected: class ClassicPrepareExecutionHelper final : public PrepareExecutionHelper<PlanCacheKey, std::unique_ptr<PlanStage>, - ClassicPrepareExecutionResult> { + ClassicPrepareExecutionResult, + false /* DeferExecutionTreeGeneration */> { public: ClassicPrepareExecutionHelper(OperationContext* opCtx, const CollectionPtr& collection, @@ -923,6 +946,12 @@ protected: const PlanCacheKey& planCacheKey) final { initializePlannerParamsIfNeeded(); + // Fill in some opDebug information, unless it has already been filled by an outer pipeline. + OpDebug& opDebug = CurOp::get(_opCtx)->debug(); + if (!opDebug.queryHash) { + opDebug.queryHash = planCacheKey.queryHash(); + } + // Before consulting the plan cache, check if we should short-circuit and construct a // find-by-_id plan. 
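Editor's note: addSolutionToResult() above selects between the two behaviors at compile time via the new DeferExecutionTreeGeneration template parameter. Continuing the stand-in types from the previous sketch, the dispatch is essentially an if constexpr:

// Compile-time switch: when DeferTrees is true the helper records only the solution
// (the SBE path); otherwise it builds the executable tree immediately (the classic path).
template <bool DeferTrees, typename BuildTreeFn>
void addSolutionToResult(PrepareResult* result,
                         std::unique_ptr<QuerySolution> solution,
                         BuildTreeFn buildTree) {
    if constexpr (DeferTrees) {
        result->emplaceSolution(std::move(solution));
    } else {
        auto root = buildTree(*solution);  // expected to return std::unique_ptr<ExecTree>
        result->emplaceBuilt(std::move(root), std::move(solution));
    }
}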
std::unique_ptr<ClassicPrepareExecutionResult> result = buildIdHackPlan(); @@ -1017,7 +1046,8 @@ class SlotBasedPrepareExecutionHelper final : public PrepareExecutionHelper< sbe::PlanCacheKey, std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData>, - SlotBasedPrepareExecutionResult> { + SlotBasedPrepareExecutionResult, + true /* DeferExecutionTreeGeneration */> { public: using PrepareExecutionHelper::PrepareExecutionHelper; @@ -1038,6 +1068,9 @@ public: std::pair<std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData> buildExecutableTree( const QuerySolution& solution) const final { + tassert(7087105, + "template indicates execution tree generation deferred", + !ShouldDeferExecutionTreeGeneration); return stage_builder::buildSlotBasedExecutableTree( _opCtx, _collections, *_cq, solution, _yieldPolicy); } @@ -1053,11 +1086,6 @@ protected: if (!feature_flags::gFeatureFlagSbeFull.isEnabledAndIgnoreFCV()) { return buildCachedPlanFromClassicCache(); } else { - OpDebug& opDebug = CurOp::get(_opCtx)->debug(); - if (!opDebug.planCacheKey) { - opDebug.planCacheKey = planCacheKey.planCacheKeyHash(); - } - auto&& planCache = sbe::getPlanCache(_opCtx); auto cacheEntry = planCache.getCacheEntryIfActive(planCacheKey); if (!cacheEntry) { @@ -1089,10 +1117,6 @@ protected: std::unique_ptr<SlotBasedPrepareExecutionResult> buildCachedPlanFromClassicCache() { const auto& mainColl = getMainCollection(); auto planCacheKey = plan_cache_key_factory::make<PlanCacheKey>(*_cq, mainColl); - OpDebug& opDebug = CurOp::get(_opCtx)->debug(); - if (!opDebug.planCacheKey) { - opDebug.planCacheKey = planCacheKey.planCacheKeyHash(); - } // Try to look up a cached solution for the query. if (auto cs = CollectionQueryInfo::get(mainColl).getPlanCache()->getCacheEntryIfActive( planCacheKey)) { @@ -1109,9 +1133,7 @@ protected: } auto result = makeResult(); - auto&& execTree = buildExecutableTree(*querySolution); - - result->emplace(std::move(execTree), std::move(querySolution)); + addSolutionToResult(result.get(), std::move(querySolution)); result->setDecisionWorks(cs->decisionWorks); result->setRecoveredFromPlanCache(true); return result; @@ -1134,8 +1156,7 @@ protected: for (size_t ix = 0; ix < solutions.size(); ++ix) { solutions[ix]->indexFilterApplied = _plannerParams.indexFiltersApplied; - auto execTree = buildExecutableTree(*solutions[ix]); - result->emplace(std::move(execTree), std::move(solutions[ix])); + addSolutionToResult(result.get(), std::move(solutions[ix])); } return result; } @@ -1256,29 +1277,40 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getSlotBasedExe OperationContext* opCtx, const MultipleCollectionAccessor& collections, std::unique_ptr<CanonicalQuery> cq, - PlanYieldPolicy::YieldPolicy requestedYieldPolicy, - const QueryPlannerParams& plannerParams) { - // Mark that this query uses the SBE engine, unless this has already been set. + std::unique_ptr<PlanYieldPolicySBE> yieldPolicy, + const QueryPlannerParams& plannerParams, + std::unique_ptr<SlotBasedPrepareExecutionResult> planningResult) { + // Now that we know what executor we are going to use, fill in some opDebug information, unless + // it has already been filled by an outer pipeline. 
OpDebug& opDebug = CurOp::get(opCtx)->debug(); if (!opDebug.classicEngineUsed) { opDebug.classicEngineUsed = false; } - - const auto mainColl = &collections.getMainCollection(); + if (collections.getMainCollection()) { + auto planCacheKey = plan_cache_key_factory::make(*cq, collections); + if (!opDebug.queryHash) { + opDebug.queryHash = planCacheKey.queryHash(); + } + if (!opDebug.planCacheKey && shouldCacheQuery(*cq)) { + opDebug.planCacheKey = planCacheKey.planCacheKeyHash(); + } + } // Analyze the provided query and build the list of candidate plans for it. auto nss = cq->nss(); - auto yieldPolicy = makeSbeYieldPolicy(opCtx, requestedYieldPolicy, mainColl, nss); - SlotBasedPrepareExecutionHelper helper{ - opCtx, collections, cq.get(), yieldPolicy.get(), plannerParams.options}; - auto planningResultWithStatus = helper.prepare(); - if (!planningResultWithStatus.isOK()) { - return planningResultWithStatus.getStatus(); - } - auto&& planningResult = planningResultWithStatus.getValue(); auto&& [roots, solutions] = planningResult->extractResultData(); + invariant(roots.empty() || roots.size() == solutions.size()); + if (roots.empty()) { + // We might have execution trees already if we pulled the plan from the cache. If not, we + // need to generate one for each solution. + for (const auto& solution : solutions) { + roots.emplace_back(stage_builder::buildSlotBasedExecutableTree( + opCtx, collections, *cq, *solution, yieldPolicy.get())); + } + } + // When query requires sub-planning, we may not get any executable plans. const auto planStageData = roots.empty() ? boost::none @@ -1318,7 +1350,8 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getSlotBasedExe *cq, std::move(solutions[0]), fillOutSecondaryCollectionsInformation(opCtx, collections, cq.get())); - roots[0] = helper.buildExecutableTree(*(solutions[0])); + roots[0] = stage_builder::buildSlotBasedExecutableTree( + opCtx, collections, *cq, *(solutions[0]), yieldPolicy.get()); } plan_cache_util::updatePlanCache(opCtx, collections, *cq, *solutions[0], *root, data); @@ -1338,13 +1371,99 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getSlotBasedExe std::move(yieldPolicy), planningResult->isRecoveredFromPlanCache()); } + +/** + * Checks if the result of query planning is SBE compatible. + */ +bool isPlanSbeCompatible(const SlotBasedPrepareExecutionResult& planningResult) { + const auto& solutions = planningResult.solutions(); + if (solutions.empty()) { + // Query needs subplanning (plans are generated later, we don't have access yet). + invariant(planningResult.needsSubplanning()); + } + + // Check that the query solution is SBE compatible. + + return true; +} + +/** + * Attempts to create a slot-based executor for the query, if the query plan is eligible for SBE + * execution. This function has three possible return values: + * + * 1. A plan executor. This is in the case where the query is SBE eligible and encounters no errors + * in executor creation. This result is to be expected in the majority of cases. + * 2. A non-OK status. This is when errors are encountered during executor creation. + * 3. The canonical query. This is to return ownership of the 'canonicalQuery' argument in the case + * where the query plan is not eligible for SBE execution but it is not an error case. 
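Editor's note: with tree generation deferred, getSlotBasedExecutor() above now materializes one execution tree per solution only when the planning result arrived without trees (plans recovered from the cache already carry them). A sketch of that step, continuing the stand-in types from the sketches above; buildTree is a hypothetical callback in place of stage_builder::buildSlotBasedExecutableTree:

#include <cassert>

template <typename BuildTreeFn>
void materializeTrees(PrepareResult& result, BuildTreeFn buildTree) {
    auto& roots = result.roots();
    const auto& solutions = result.solutions();
    assert(roots.empty() || roots.size() == solutions.size());
    if (!roots.empty()) {
        return;  // e.g. recovered from the plan cache: trees were built during prepare()
    }
    for (const auto& solution : solutions) {
        roots.push_back(buildTree(*solution));  // build now that SBE is definitely chosen
    }
}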
+ */ +StatusWith<stdx::variant<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>, + std::unique_ptr<CanonicalQuery>>> +attemptToGetSlotBasedExecutor( + OperationContext* opCtx, + const MultipleCollectionAccessor& collections, + std::unique_ptr<CanonicalQuery> canonicalQuery, + std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages, + PlanYieldPolicy::YieldPolicy yieldPolicy, + const QueryPlannerParams& plannerParams) { + if (extractAndAttachPipelineStages) { + // Push down applicable pipeline stages and attach to the query, but don't remove from + // the high-level pipeline object until we know for sure we will execute with SBE. + extractAndAttachPipelineStages(canonicalQuery.get(), true /* attachOnly */); + } + + // Use SBE if we find any $group/$lookup stages eligible for execution in SBE or if SBE + // is fully enabled. Otherwise, fallback to the classic engine. + if (canonicalQuery->pipeline().empty() && + !feature_flags::gFeatureFlagSbeFull.isEnabledAndIgnoreFCV()) { + canonicalQuery->setSbeCompatible(false); + } else { + // Construct the SBE query solution - this will be our final decision stage to determine + // whether SBE should be used. + auto sbeYieldPolicy = makeSbeYieldPolicy( + opCtx, yieldPolicy, &collections.getMainCollection(), canonicalQuery->nss()); + + SlotBasedPrepareExecutionHelper helper{ + opCtx, collections, canonicalQuery.get(), sbeYieldPolicy.get(), plannerParams.options}; + auto planningResultWithStatus = helper.prepare(); + if (planningResultWithStatus.isOK() && + isPlanSbeCompatible(*planningResultWithStatus.getValue())) { + if (extractAndAttachPipelineStages) { + // We know now that we will use SBE, so we need to remove the pushed-down stages + // from the original pipeline object. + extractAndAttachPipelineStages(canonicalQuery.get(), false /* attachOnly */); + } + auto statusWithExecutor = + getSlotBasedExecutor(opCtx, + collections, + std::move(canonicalQuery), + std::move(sbeYieldPolicy), + plannerParams, + std::move(planningResultWithStatus.getValue())); + if (statusWithExecutor.isOK()) { + return std::move(statusWithExecutor.getValue()); + } else { + return statusWithExecutor.getStatus(); + } + } + + // Query plan was not SBE compatible - reset any fields that may have been modified, and + // fall back to classic engine. + canonicalQuery->setSbeCompatible(false); + canonicalQuery->setPipeline({}); + } + + // Return ownership of the canonical query to the caller. + return std::move(canonicalQuery); +} + } // namespace StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor( OperationContext* opCtx, const MultipleCollectionAccessor& collections, std::unique_ptr<CanonicalQuery> canonicalQuery, - std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages, + std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages, PlanYieldPolicy::YieldPolicy yieldPolicy, const QueryPlannerParams& plannerParams) { invariant(canonicalQuery); @@ -1358,17 +1477,27 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor( // Use SBE if 'canonicalQuery' is SBE compatible. if (!canonicalQuery->getForceClassicEngine() && canonicalQuery->isSbeCompatible()) { - if (extractAndAttachPipelineStages) { - extractAndAttachPipelineStages(canonicalQuery.get()); - } - // Use SBE if we find any $group/$lookup stages eligible for execution in SBE or if SBE - // is fully enabled. Otherwise, fallback to the classic engine. 
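Editor's note: attemptToGetSlotBasedExecutor() above has three outcomes (executor, error status, or the canonical query handed back for the classic fallback), expressed as a StatusWith over a variant. A minimal illustration of the variant-based "executor or reclaim the query" pattern using std::variant and stub types; the StatusWith error channel is omitted here:

#include <memory>
#include <variant>

struct PlanExecutorStub {};    // stand-in for the SBE plan executor
struct CanonicalQueryStub {};  // stand-in for CanonicalQuery

using SbeAttemptResult =
    std::variant<std::unique_ptr<PlanExecutorStub>, std::unique_ptr<CanonicalQueryStub>>;

SbeAttemptResult attemptSbe(std::unique_ptr<CanonicalQueryStub> query, bool sbeEligible) {
    if (!sbeEligible) {
        // Not an error: hand ownership of the query back to the caller.
        return std::move(query);
    }
    return std::make_unique<PlanExecutorStub>();  // would be the fully built SBE executor
}

// Caller side: either run with the executor or reclaim the query and fall back to classic.
bool tryRun(std::unique_ptr<CanonicalQueryStub> query, bool sbeEligible) {
    auto result = attemptSbe(std::move(query), sbeEligible);
    if (std::holds_alternative<std::unique_ptr<PlanExecutorStub>>(result)) {
        auto exec = std::get<std::unique_ptr<PlanExecutorStub>>(std::move(result));
        return exec != nullptr;  // SBE path
    }
    query = std::move(std::get<std::unique_ptr<CanonicalQueryStub>>(result));
    return false;  // classic path continues with the reclaimed query
}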
- if (canonicalQuery->pipeline().empty() && - !feature_flags::gFeatureFlagSbeFull.isEnabledAndIgnoreFCV()) { - canonicalQuery->setSbeCompatible(false); + auto statusWithExecutor = attemptToGetSlotBasedExecutor(opCtx, + collections, + std::move(canonicalQuery), + extractAndAttachPipelineStages, + yieldPolicy, + plannerParams); + if (!statusWithExecutor.isOK()) { + return statusWithExecutor.getStatus(); + } + auto& maybeExecutor = statusWithExecutor.getValue(); + if (stdx::holds_alternative<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>>( + maybeExecutor)) { + return std::move( + stdx::get<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>>(maybeExecutor)); } else { - return getSlotBasedExecutor( - opCtx, collections, std::move(canonicalQuery), yieldPolicy, plannerParams); + // The query is not eligible for SBE execution - reclaim the canonical query and fall + // back to classic. + tassert(7087103, + "return value must contain canonical query if not executor", + stdx::holds_alternative<std::unique_ptr<CanonicalQuery>>(maybeExecutor)); + canonicalQuery = std::move(stdx::get<std::unique_ptr<CanonicalQuery>>(maybeExecutor)); } } @@ -1380,7 +1509,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor( OperationContext* opCtx, const CollectionPtr* collection, std::unique_ptr<CanonicalQuery> canonicalQuery, - std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages, + std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages, PlanYieldPolicy::YieldPolicy yieldPolicy, size_t plannerOptions) { MultipleCollectionAccessor multi{collection}; @@ -1400,7 +1529,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorFind OperationContext* opCtx, const MultipleCollectionAccessor& collections, std::unique_ptr<CanonicalQuery> canonicalQuery, - std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages, + std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages, bool permitYield, QueryPlannerParams plannerParams) { auto yieldPolicy = (permitYield && !opCtx->inMultiDocumentTransaction()) @@ -1423,7 +1552,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorFind OperationContext* opCtx, const CollectionPtr* coll, std::unique_ptr<CanonicalQuery> canonicalQuery, - std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages, + std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages, bool permitYield, size_t plannerOptions) { MultipleCollectionAccessor multi{*coll}; @@ -1612,6 +1741,10 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorDele // This is the regular path for when we have a CanonicalQuery. std::unique_ptr<CanonicalQuery> cq(parsedDelete->releaseParsedQuery()); + uassert(ErrorCodes::InternalErrorNotSupported, + "delete command is not eligible for bonsai", + !isEligibleForBonsai(*cq, opCtx, collection)); + // Transfer the explain verbosity level into the expression context. cq->getExpCtx()->explain = verbosity; @@ -1804,6 +1937,10 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorUpda // This is the regular path for when we have a CanonicalQuery. 
std::unique_ptr<CanonicalQuery> cq(parsedUpdate->releaseParsedQuery()); + uassert(ErrorCodes::InternalErrorNotSupported, + "update command is not eligible for bonsai", + !isEligibleForBonsai(*cq, opCtx, collection)); + std::unique_ptr<projection_ast::Projection> projection; if (!request->getProj().isEmpty()) { invariant(request->shouldReturnAnyDocs()); @@ -2129,6 +2266,10 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorCoun const auto skip = request.getSkip().value_or(0); const auto limit = request.getLimit().value_or(0); + uassert(ErrorCodes::InternalErrorNotSupported, + "count command is not eligible for bonsai", + !isEligibleForBonsai(*cq, opCtx, collection)); + if (!collection) { // Treat collections that do not exist as empty collections. Note that the explain reporting // machinery always assumes that the root stage for a count operation is a CountStage, so in @@ -2609,6 +2750,11 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorDist ? PlanYieldPolicy::YieldPolicy::INTERRUPT_ONLY : PlanYieldPolicy::YieldPolicy::YIELD_AUTO; + // Assert that not eligible for bonsai + uassert(ErrorCodes::InternalErrorNotSupported, + "distinct command is not eligible for bonsai", + !isEligibleForBonsai(*parsedDistinct->getQuery(), opCtx, collection)); + if (!collection) { // Treat collections that do not exist as empty collections. return plan_executor_factory::make(parsedDistinct->releaseQuery(), diff --git a/src/mongo/db/query/get_executor.h b/src/mongo/db/query/get_executor.h index 75ab5475f2f..a1e24952b76 100644 --- a/src/mongo/db/query/get_executor.h +++ b/src/mongo/db/query/get_executor.h @@ -155,7 +155,9 @@ bool shouldWaitForOplogVisibility(OperationContext* opCtx, * If the caller provides a 'extractAndAttachPipelineStages' function and the query is eligible for * pushdown into the find layer this function will be invoked to extract pipeline stages and * attach them to the provided 'CanonicalQuery'. This function should capture the Pipeline that - * stages should be extracted from. + * stages should be extracted from. If the boolean 'attachOnly' argument is true, it will only find + * and attach the applicable stages to the query. If it is false, it will remove the extracted + * stages from the pipeline. 
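Editor's note: the get_executor.h changes below document the new two-phase 'extractAndAttachPipelineStages' callback: with attachOnly=true the pushed-down stages are only attached to the query (safe to abandon if SBE is rejected); a second call with attachOnly=false commits by removing them from the pipeline. A toy, self-contained illustration of that protocol; QueryStub, PipelineStub, and makeCallback are hypothetical names:

#include <functional>
#include <iostream>

struct QueryStub { int attachedStages = 0; };
struct PipelineStub { int stages = 3; };

using ExtractAndAttachFn = std::function<void(QueryStub*, bool /*attachOnly*/)>;

ExtractAndAttachFn makeCallback(PipelineStub* pipeline) {
    return [pipeline](QueryStub* query, bool attachOnly) {
        query->attachedStages = pipeline->stages;  // phase 1: attach to the query
        if (!attachOnly) {
            pipeline->stages = 0;  // phase 2: commit - strip the pushed-down stages
        }
    };
}

int main() {
    PipelineStub pipeline;
    QueryStub query;
    auto cb = makeCallback(&pipeline);
    cb(&query, /*attachOnly=*/true);   // speculative attach while deciding on SBE
    cb(&query, /*attachOnly=*/false);  // SBE chosen: now remove stages from the pipeline
    std::cout << pipeline.stages << " stages left in the pipeline\n";  // prints 0
    return 0;
}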
* * Note that the first overload takes a 'MultipleCollectionAccessor' and can construct a * PlanExecutor over multiple collections, while the second overload takes a single 'CollectionPtr' @@ -165,7 +167,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor( OperationContext* opCtx, const MultipleCollectionAccessor& collections, std::unique_ptr<CanonicalQuery> canonicalQuery, - std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages, + std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages, PlanYieldPolicy::YieldPolicy yieldPolicy, const QueryPlannerParams& plannerOptions); @@ -173,7 +175,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor( OperationContext* opCtx, const CollectionPtr* collection, std::unique_ptr<CanonicalQuery> canonicalQuery, - std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages, + std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages, PlanYieldPolicy::YieldPolicy yieldPolicy, size_t plannerOptions = 0); @@ -190,7 +192,9 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutor( * If the caller provides a 'extractAndAttachPipelineStages' function and the query is eligible for * pushdown into the find layer this function will be invoked to extract pipeline stages and * attach them to the provided 'CanonicalQuery'. This function should capture the Pipeline that - * stages should be extracted from. + * stages should be extracted from. If the boolean 'attachOnly' argument is true, it will only find + * and attach the applicable stages to the query. If it is false, it will remove the extracted + * stages from the pipeline. * * Note that the first overload takes a 'MultipleCollectionAccessor' and can construct a * PlanExecutor over multiple collections, while the second overload takes a single @@ -200,7 +204,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorFind OperationContext* opCtx, const MultipleCollectionAccessor& collections, std::unique_ptr<CanonicalQuery> canonicalQuery, - std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages, + std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages, bool permitYield = false, QueryPlannerParams plannerOptions = QueryPlannerParams{}); @@ -208,7 +212,7 @@ StatusWith<std::unique_ptr<PlanExecutor, PlanExecutor::Deleter>> getExecutorFind OperationContext* opCtx, const CollectionPtr* collection, std::unique_ptr<CanonicalQuery> canonicalQuery, - std::function<void(CanonicalQuery*)> extractAndAttachPipelineStages, + std::function<void(CanonicalQuery*, bool)> extractAndAttachPipelineStages, bool permitYield = false, size_t plannerOptions = QueryPlannerParams::DEFAULT); diff --git a/src/mongo/db/query/optimizer/SConscript b/src/mongo/db/query/optimizer/SConscript index a128b8eecc1..70ce6442e3a 100644 --- a/src/mongo/db/query/optimizer/SConscript +++ b/src/mongo/db/query/optimizer/SConscript @@ -91,17 +91,6 @@ env.Library( ], ) -# Default costing module. -env.Library( - target="optimizer_default_costing", - source=[ - "cascades/cost_derivation.cpp", - ], - LIBDEPS=[ - "optimizer_memo", - ], -) - # Main optimizer target. env.Library( target="optimizer", @@ -112,7 +101,6 @@ env.Library( LIBDEPS=[ "optimizer_cascades", "optimizer_default_ce", - "optimizer_default_costing", "optimizer_rewrites", ], ) @@ -126,6 +114,7 @@ env.Library( LIBDEPS=[ # We do not depend on the entire pipeline target. 
"$BUILD_DIR/mongo/db/pipeline/abt_utils", + "$BUILD_DIR/mongo/db/query/cost_model/query_cost_model", "$BUILD_DIR/mongo/db/query/optimizer/optimizer", "$BUILD_DIR/mongo/unittest/unittest", ], diff --git a/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp b/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp deleted file mode 100644 index f58380b50b3..00000000000 --- a/src/mongo/db/query/optimizer/cascades/cost_derivation.cpp +++ /dev/null @@ -1,446 +0,0 @@ -/** - * Copyright (C) 2022-present MongoDB, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the Server Side Public License, version 1, - * as published by MongoDB, Inc. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * Server Side Public License for more details. - * - * You should have received a copy of the Server Side Public License - * along with this program. If not, see - * <http://www.mongodb.com/licensing/server-side-public-license>. - * - * As a special exception, the copyright holders give permission to link the - * code of portions of this program with the OpenSSL library under certain - * conditions as described in each individual source file and distribute - * linked combinations including the program with the OpenSSL library. You - * must comply with the Server Side Public License in all respects for - * all of the code used other than as permitted herein. If you modify file(s) - * with this exception, you may extend this exception to your version of the - * file(s), but you are not obligated to do so. If you do not wish to do so, - * delete this exception statement from your version. If you delete this - * exception statement from all source files in the program, then also delete - * it in the license file. - */ - -#include "mongo/db/query/optimizer/cascades/cost_derivation.h" -#include "mongo/db/query/optimizer/defs.h" - -namespace mongo::optimizer::cascades { - -using namespace properties; - -struct CostAndCEInternal { - CostAndCEInternal(double cost, CEType ce) : _cost(cost), _ce(ce) { - uassert(8423334, "Invalid cost.", !std::isnan(cost) && cost >= 0.0); - uassert(8423332, "Invalid cardinality", std::isfinite(ce) && ce >= 0.0); - } - double _cost; - CEType _ce; -}; - -class CostDerivation { - // These cost should reflect estimated aggregated execution time in milliseconds. - static constexpr double ms = 1.0e-3; - - // Startup cost of an operator. This is the minimal cost of an operator since it is - // present even if it doesn't process any input. - // TODO: calibrate the cost individually for each operator - static constexpr double kStartupCost = 0.000001; - - // TODO: collection scan should depend on the width of the doc. - // TODO: the actual measured cost is (0.4 * ms), however we increase it here because currently - // it is not possible to estimate the cost of a collection scan vs a full index scan. - static constexpr double kScanIncrementalCost = 0.6 * ms; - - // TODO: cost(N fields) ~ (0.55 + 0.025 * N) - static constexpr double kIndexScanIncrementalCost = 0.5 * ms; - - // TODO: cost(N fields) ~ 0.7 + 0.19 * N - static constexpr double kSeekCost = 2.0 * ms; - - // TODO: take the expression into account. - // cost(N conditions) = 0.2 + N * ??? 
- static constexpr double kFilterIncrementalCost = 0.2 * ms; - // TODO: the cost of projection depends on number of fields: cost(N fields) ~ 0.1 + 0.2 * N - static constexpr double kEvalIncrementalCost = 2.0 * ms; - - // TODO: cost(N fields) ~ 0.04 + 0.03*(N^2) - static constexpr double kGroupByIncrementalCost = 0.07 * ms; - static constexpr double kUnwindIncrementalCost = 0.03 * ms; // TODO: not yet calibrated - // TODO: not yet calibrated, should be at least as expensive as a filter - static constexpr double kBinaryJoinIncrementalCost = 0.2 * ms; - static constexpr double kHashJoinIncrementalCost = 0.05 * ms; // TODO: not yet calibrated - static constexpr double kMergeJoinIncrementalCost = 0.02 * ms; // TODO: not yet calibrated - - static constexpr double kUniqueIncrementalCost = 0.7 * ms; - - // TODO: implement collation cost that depends on number and size of sorted fields - // Based on a mix of int and str(64) fields: - // 1 sort field: sort_cost(N) = 1.0/10 * N * log(N) - // 5 sort fields: sort_cost(N) = 2.5/10 * N * log(N) - // 10 sort fields: sort_cost(N) = 3.0/10 * N * log(N) - // field_cost_coeff(F) ~ 0.75 + 0.2 * F - static constexpr double kCollationIncrementalCost = 2.5 * ms; // 5 fields avg - static constexpr double kCollationWithLimitIncrementalCost = - 1.0 * ms; // TODO: not yet calibrated - - static constexpr double kUnionIncrementalCost = 0.02 * ms; - - static constexpr double kExchangeIncrementalCost = 0.1 * ms; // TODO: not yet calibrated - -public: - CostAndCEInternal operator()(const ABT& /*n*/, const PhysicalScanNode& /*node*/) { - // Default estimate for scan. - const double collectionScanCost = - kStartupCost + kScanIncrementalCost * _cardinalityEstimate; - return {collectionScanCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const CoScanNode& /*node*/) { - // Assumed to be free. - return {kStartupCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const IndexScanNode& node) { - const double indexScanCost = - kStartupCost + kIndexScanIncrementalCost * _cardinalityEstimate; - return {indexScanCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const SeekNode& /*node*/) { - // SeekNode should deliver one result via cardinality estimate override. - // TODO: consider using node.getProjectionMap()._fieldProjections.size() to make the cost - // dependent on the size of the projection - const double seekCost = kStartupCost + kSeekCost * _cardinalityEstimate; - return {seekCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const MemoLogicalDelegatorNode& node) { - const LogicalProps& childLogicalProps = _memo.getLogicalProps(node.getGroupId()); - // Notice that unlike all physical nodes, this logical node takes it cardinality directly - // from the memo group logical property, ignoring _cardinalityEstimate. - CEType baseCE = getPropertyConst<CardinalityEstimate>(childLogicalProps).getEstimate(); - - if (hasProperty<IndexingRequirement>(_physProps)) { - const auto& indexingReq = getPropertyConst<IndexingRequirement>(_physProps); - if (indexingReq.getIndexReqTarget() == IndexReqTarget::Seek) { - // If we are performing a seek, normalize against the scan group cardinality. 
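Editor's note: the cost_derivation.cpp file removed here (apparently superseded by the query_cost_model library added in the SConscript hunk above) priced every operator as a fixed startup cost plus an incremental cost per estimated row, with sorting adding an n*log(n) term. A compact illustration of that formula shape; the constants are illustrative, not the calibrated values, and the real model also switches to a cheaper collation constant when a limit applies:

#include <algorithm>
#include <cmath>

struct CostAndCE {
    double cost;  // estimated execution time in the model's millisecond-based unit
    double ce;    // cardinality estimate flowing out of the operator
};

constexpr double kMs = 1.0e-3;
constexpr double kStartupCost = 1.0e-6;
constexpr double kScanIncrementalCost = 0.6 * kMs;
constexpr double kCollationIncrementalCost = 2.5 * kMs;

// Collection scan: pay startup plus a per-document increment.
CostAndCE costCollectionScan(double ce) {
    return {kStartupCost + kScanIncrementalCost * ce, ce};
}

// Collation (sort): n*log(n) on the child's cardinality; a limit bounds the log factor,
// and sorting a single document costs only the startup term.
CostAndCE costCollation(const CostAndCE& child, double outputCe, double limit) {
    const double logFactor = std::min(child.ce, limit);
    const double sortTerm =
        logFactor <= 1.0 ? 0.0 : kCollationIncrementalCost * child.ce * std::log2(logFactor);
    return {kStartupCost + child.cost + sortTerm, outputCe};
}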
- const GroupIdType scanGroupId = - getPropertyConst<IndexingAvailability>(childLogicalProps).getScanGroupId(); - if (scanGroupId == node.getGroupId()) { - baseCE = 1.0; - } else { - const CEType scanGroupCE = - getPropertyConst<CardinalityEstimate>(_memo.getLogicalProps(scanGroupId)) - .getEstimate(); - if (scanGroupCE > 0.0) { - baseCE /= scanGroupCE; - } - } - } - } - - return {0.0, getAdjustedCE(baseCE, _physProps)}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const MemoPhysicalDelegatorNode& /*node*/) { - uasserted(6624040, "Should not be costing physical delegator nodes."); - } - - CostAndCEInternal operator()(const ABT& /*n*/, const FilterNode& node) { - CostAndCEInternal childResult = deriveChild(node.getChild(), 0); - double filterCost = childResult._cost; - if (!isTrivialExpr<EvalFilter>(node.getFilter())) { - // Non-trivial filter. - filterCost += kStartupCost + kFilterIncrementalCost * childResult._ce; - } - return {filterCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const EvaluationNode& node) { - CostAndCEInternal childResult = deriveChild(node.getChild(), 0); - double evalCost = childResult._cost; - if (!isTrivialExpr<EvalPath>(node.getProjection())) { - // Non-trivial projection. - evalCost += kStartupCost + kEvalIncrementalCost * _cardinalityEstimate; - } - return {evalCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const BinaryJoinNode& node) { - CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0); - CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1); - const double joinCost = kStartupCost + - kBinaryJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) + - leftChildResult._cost + rightChildResult._cost; - return {joinCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const HashJoinNode& node) { - CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0); - CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1); - - // TODO: distinguish build side and probe side. - const double hashJoinCost = kStartupCost + - kHashJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) + - leftChildResult._cost + rightChildResult._cost; - return {hashJoinCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const MergeJoinNode& node) { - CostAndCEInternal leftChildResult = deriveChild(node.getLeftChild(), 0); - CostAndCEInternal rightChildResult = deriveChild(node.getRightChild(), 1); - - const double mergeJoinCost = kStartupCost + - kMergeJoinIncrementalCost * (leftChildResult._ce + rightChildResult._ce) + - leftChildResult._cost + rightChildResult._cost; - - return {mergeJoinCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const UnionNode& node) { - const ABTVector& children = node.nodes(); - // UnionNode with one child is optimized away before lowering, therefore - // its cost is the cost of its child. - if (children.size() == 1) { - CostAndCEInternal childResult = deriveChild(children[0], 0); - return {childResult._cost, _cardinalityEstimate}; - } - - double totalCost = kStartupCost; - // The cost is the sum of the costs of its children and the cost to union each child. - for (size_t childIdx = 0; childIdx < children.size(); childIdx++) { - CostAndCEInternal childResult = deriveChild(children[childIdx], childIdx); - const double childCost = - childResult._cost + (childIdx > 0 ? 
kUnionIncrementalCost * childResult._ce : 0); - totalCost += childCost; - } - return {totalCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const GroupByNode& node) { - CostAndCEInternal childResult = deriveChild(node.getChild(), 0); - double groupByCost = kStartupCost; - - // TODO: for now pretend global group by is free. - if (node.getType() == GroupNodeType::Global) { - groupByCost += childResult._cost; - } else { - // TODO: consider RepetitionEstimate since this is a stateful operation. - groupByCost += kGroupByIncrementalCost * childResult._ce + childResult._cost; - } - return {groupByCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const UnwindNode& node) { - CostAndCEInternal childResult = deriveChild(node.getChild(), 0); - // Unwind probably depends mostly on its output size. - const double unwindCost = kUnwindIncrementalCost * _cardinalityEstimate + childResult._cost; - return {unwindCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const UniqueNode& node) { - CostAndCEInternal childResult = deriveChild(node.getChild(), 0); - const double uniqueCost = - kStartupCost + kUniqueIncrementalCost * childResult._ce + childResult._cost; - return {uniqueCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const CollationNode& node) { - CostAndCEInternal childResult = deriveChild(node.getChild(), 0); - // TODO: consider RepetitionEstimate since this is a stateful operation. - - double logFactor = childResult._ce; - double incrConst = kCollationIncrementalCost; - if (hasProperty<LimitSkipRequirement>(_physProps)) { - if (auto limit = getPropertyConst<LimitSkipRequirement>(_physProps).getAbsoluteLimit(); - limit < logFactor) { - logFactor = limit; - incrConst = kCollationWithLimitIncrementalCost; - } - } - - // Notice that log2(x) < 0 for any x < 1, and log2(1) = 0. Generally it makes sense that - // there is no cost to sort 1 document, so the only cost left is the startup cost. - const double sortCost = kStartupCost + childResult._cost + - ((logFactor <= 1.0) - ? 0.0 - // TODO: The cost formula below is based on 1 field, mix of int and str. Instead we - // have to take into account the number and size of sorted fields. - : incrConst * childResult._ce * std::log2(logFactor)); - return {sortCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const LimitSkipNode& node) { - // Assumed to be free. - CostAndCEInternal childResult = deriveChild(node.getChild(), 0); - const double limitCost = kStartupCost + childResult._cost; - return {limitCost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const ExchangeNode& node) { - CostAndCEInternal childResult = deriveChild(node.getChild(), 0); - double localCost = kStartupCost + kExchangeIncrementalCost * _cardinalityEstimate; - - switch (node.getProperty().getDistributionAndProjections()._type) { - case DistributionType::Replicated: - localCost *= 2.0; - break; - - case DistributionType::HashPartitioning: - case DistributionType::RangePartitioning: - localCost *= 1.1; - break; - - default: - break; - } - - return {localCost + childResult._cost, _cardinalityEstimate}; - } - - CostAndCEInternal operator()(const ABT& /*n*/, const RootNode& node) { - return deriveChild(node.getChild(), 0); - } - - /** - * Other ABT types. - */ - template <typename T, typename... Ts> - CostAndCEInternal operator()(const ABT& /*n*/, const T& /*node*/, Ts&&...) 
{ - static_assert(!canBePhysicalNode<T>(), "Physical node must implement its cost derivation."); - return {0.0, 0.0}; - } - - static CostAndCEInternal derive(const Metadata& metadata, - const Memo& memo, - const PhysProps& physProps, - const ABT::reference_type physNodeRef, - const ChildPropsType& childProps, - const NodeCEMap& nodeCEMap) { - CostAndCEInternal result = - deriveInternal(metadata, memo, physProps, physNodeRef, childProps, nodeCEMap); - - switch (getPropertyConst<DistributionRequirement>(physProps) - .getDistributionAndProjections() - ._type) { - case DistributionType::Centralized: - case DistributionType::Replicated: - break; - - case DistributionType::RoundRobin: - case DistributionType::HashPartitioning: - case DistributionType::RangePartitioning: - case DistributionType::UnknownPartitioning: - result._cost /= metadata._numberOfPartitions; - break; - - default: - MONGO_UNREACHABLE; - } - - return result; - } - -private: - CostDerivation(const Metadata& metadata, - const Memo& memo, - const CEType ce, - const PhysProps& physProps, - const ChildPropsType& childProps, - const NodeCEMap& nodeCEMap) - : _metadata(metadata), - _memo(memo), - _physProps(physProps), - _cardinalityEstimate(getAdjustedCE(ce, _physProps)), - _childProps(childProps), - _nodeCEMap(nodeCEMap) {} - - template <class T> - static bool isTrivialExpr(const ABT& n) { - if (n.is<Constant>() || n.is<Variable>()) { - return true; - } - if (const auto* ptr = n.cast<T>(); ptr != nullptr && - ptr->getPath().template is<PathIdentity>() && isTrivialExpr<T>(ptr->getInput())) { - return true; - } - return false; - } - - static CostAndCEInternal deriveInternal(const Metadata& metadata, - const Memo& memo, - const PhysProps& physProps, - const ABT::reference_type physNodeRef, - const ChildPropsType& childProps, - const NodeCEMap& nodeCEMap) { - auto it = nodeCEMap.find(physNodeRef.cast<Node>()); - bool found = (it != nodeCEMap.cend()); - uassert(8423330, - "Only MemoLogicalDelegatorNode can be missing from nodeCEMap.", - found || physNodeRef.is<MemoLogicalDelegatorNode>()); - const CEType ce = (found ? it->second : 0.0); - - CostDerivation instance(metadata, memo, ce, physProps, childProps, nodeCEMap); - CostAndCEInternal costCEestimates = physNodeRef.visit(instance); - return costCEestimates; - } - - CostAndCEInternal deriveChild(const ABT& child, const size_t childIndex) { - PhysProps physProps = _childProps.empty() ? _physProps : _childProps.at(childIndex).second; - return deriveInternal(_metadata, _memo, physProps, child.ref(), {}, _nodeCEMap); - } - - static CEType getAdjustedCE(CEType baseCE, const PhysProps& physProps) { - CEType result = baseCE; - - // First: correct for un-enforced limit. - if (hasProperty<LimitSkipRequirement>(physProps)) { - const auto limit = getPropertyConst<LimitSkipRequirement>(physProps).getAbsoluteLimit(); - if (result > limit) { - result = limit; - } - } - - // Second: correct for enforced limit. - if (hasProperty<LimitEstimate>(physProps)) { - const auto limit = getPropertyConst<LimitEstimate>(physProps).getEstimate(); - if (result > limit) { - result = limit; - } - } - - // Third: correct for repetition. - if (hasProperty<RepetitionEstimate>(physProps)) { - result *= getPropertyConst<RepetitionEstimate>(physProps).getEstimate(); - } - - return result; - } - - // We don't own this. 
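getAdjustedCE above corrects a base cardinality estimate before it feeds the cost formulas: clamp to any limit requirement, then scale by the repetition estimate. A compact sketch of that adjustment under simplified, hypothetical types (std::optional stands in for the property lookups):

    #include <algorithm>
    #include <cstdio>
    #include <optional>

    struct AdjustmentProps {
        std::optional<double> absoluteLimit;    // limit requirement or estimate, if any
        std::optional<double> repetitionCount;  // how many times the subtree re-executes
    };

    double adjustedCE(double baseCE, const AdjustmentProps& props) {
        double result = baseCE;
        if (props.absoluteLimit) {
            // A limit caps how many documents can actually flow through.
            result = std::min(result, *props.absoluteLimit);
        }
        if (props.repetitionCount) {
            // Repeated execution multiplies the effective cardinality.
            result *= *props.repetitionCount;
        }
        return result;
    }

    int main() {
        // 10,000 estimated rows, limit 100, subtree repeated 5 times -> 500.
        std::printf("%f\n", adjustedCE(10000.0, {100.0, 5.0}));
        return 0;
    }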
- const Metadata& _metadata; - const Memo& _memo; - const PhysProps& _physProps; - const CEType _cardinalityEstimate; - const ChildPropsType& _childProps; - const NodeCEMap& _nodeCEMap; -}; - -CostAndCE DefaultCosting::deriveCost(const Metadata& metadata, - const Memo& memo, - const PhysProps& physProps, - const ABT::reference_type physNodeRef, - const ChildPropsType& childProps, - const NodeCEMap& nodeCEMap) const { - const CostAndCEInternal result = - CostDerivation::derive(metadata, memo, physProps, physNodeRef, childProps, nodeCEMap); - return {CostType::fromDouble(result._cost), result._ce}; -} - -} // namespace mongo::optimizer::cascades diff --git a/src/mongo/db/query/optimizer/cascades/implementers.cpp b/src/mongo/db/query/optimizer/cascades/implementers.cpp index 6039ac3cb43..835045a2611 100644 --- a/src/mongo/db/query/optimizer/cascades/implementers.cpp +++ b/src/mongo/db/query/optimizer/cascades/implementers.cpp @@ -189,8 +189,7 @@ public: if (node.getArraySize() == 0) { nodeCEMap.emplace(physNode.cast<Node>(), 0.0); - physNode = - make<LimitSkipNode>(properties::LimitSkipRequirement{0, 0}, std::move(physNode)); + physNode = make<LimitSkipNode>(LimitSkipRequirement{0, 0}, std::move(physNode)); nodeCEMap.emplace(physNode.cast<Node>(), 0.0); for (const ProjectionName& boundProjName : node.binder().names()) { @@ -208,8 +207,7 @@ public: } else { nodeCEMap.emplace(physNode.cast<Node>(), 1.0); - physNode = - make<LimitSkipNode>(properties::LimitSkipRequirement{1, 0}, std::move(physNode)); + physNode = make<LimitSkipNode>(LimitSkipRequirement{1, 0}, std::move(physNode)); nodeCEMap.emplace(physNode.cast<Node>(), 1.0); const ProjectionName valueScanProj = _prefixId.getNextId("valueScan"); @@ -710,7 +708,21 @@ public: return; } const bool isIndex = indexReqTarget == IndexReqTarget::Index; - if (isIndex && (!node.hasLeftIntervals() || !node.hasRightIntervals())) { + + const GroupIdType leftGroupId = + node.getLeftChild().cast<MemoLogicalDelegatorNode>()->getGroupId(); + const GroupIdType rightGroupId = + node.getRightChild().cast<MemoLogicalDelegatorNode>()->getGroupId(); + + const LogicalProps& leftLogicalProps = _memo.getLogicalProps(leftGroupId); + const LogicalProps& rightLogicalProps = _memo.getLogicalProps(rightGroupId); + + const bool hasProperIntervalLeft = + getPropertyConst<IndexingAvailability>(leftLogicalProps).hasProperInterval(); + const bool hasProperIntervalRight = + getPropertyConst<IndexingAvailability>(rightLogicalProps).hasProperInterval(); + + if (isIndex && (!hasProperIntervalLeft || !hasProperIntervalRight)) { // We need to have proper intervals on both sides. 
return; } @@ -729,14 +741,6 @@ public: } } - const GroupIdType leftGroupId = - node.getLeftChild().cast<MemoLogicalDelegatorNode>()->getGroupId(); - const GroupIdType rightGroupId = - node.getRightChild().cast<MemoLogicalDelegatorNode>()->getGroupId(); - - const LogicalProps& leftLogicalProps = _memo.getLogicalProps(leftGroupId); - const LogicalProps& rightLogicalProps = _memo.getLogicalProps(rightGroupId); - const CEType intersectedCE = getPropertyConst<CardinalityEstimate>(_logicalProps).getEstimate(); const CEType leftCE = getPropertyConst<CardinalityEstimate>(leftLogicalProps).getEstimate(); @@ -1660,9 +1664,9 @@ void addImplementers(const Metadata& metadata, const QueryHints& hints, const RIDProjectionsMap& ridProjections, PrefixId& prefixId, - const properties::PhysProps& physProps, + const PhysProps& physProps, PhysQueueAndImplPos& queue, - const properties::LogicalProps& logicalProps, + const LogicalProps& logicalProps, const OrderPreservingABTSet& logicalNodes, const PathToIntervalFn& pathToInterval) { ImplementationVisitor visitor(metadata, diff --git a/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp index c2ef5529e79..29d4ccc22ec 100644 --- a/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp +++ b/src/mongo/db/query/optimizer/cascades/logical_props_derivation.cpp @@ -222,8 +222,7 @@ public: indexingAvailability.setEqPredsOnly(computeEqPredsOnly(node.getReqMap())); } - auto& satisfiedPartialIndexes = - getProperty<IndexingAvailability>(result).getSatisfiedPartialIndexes(); + auto& satisfiedPartialIndexes = indexingAvailability.getSatisfiedPartialIndexes(); for (const auto& [indexDefName, indexDef] : scanDef.getIndexDefs()) { if (!indexDef.getPartialReqMap().empty()) { auto intersection = node.getReqMap(); @@ -237,6 +236,8 @@ public: } } + indexingAvailability.setHasProperInterval(hasProperIntervals(node.getReqMap())); + return maybeUpdateNodePropsMap(node, std::move(result)); } diff --git a/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp b/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp index 421d03f73a1..0c9a923b2ed 100644 --- a/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp +++ b/src/mongo/db/query/optimizer/cascades/logical_rewriter.cpp @@ -941,8 +941,6 @@ struct SplitRequirementsResult { PartialSchemaRequirements _rightReqs; bool _hasFieldCoverage = true; - bool _hasLeftIntervals = false; - bool _hasRightIntervals = false; }; /** @@ -981,7 +979,6 @@ static SplitRequirementsResult splitRequirements( size_t index = 0; for (const auto& [key, req] : reqMap) { - const bool fullyOpenInterval = isFullyOpen.at(index); if (((1ull << index) & mask) != 0) { bool addedToLeft = false; @@ -1005,7 +1002,7 @@ static SplitRequirementsResult splitRequirements( // We cannot return index values if our interval can possibly contain Null. Instead, // we remove the output binding for the left side, and return the value from the // right (seek) side. 
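The new setHasProperInterval call records, once per memo group, whether the requirements map contains at least one interval that is not fully open. A rough sketch of that check with hypothetical types (the real code works on PartialSchemaRequirements and interval expressions, not the structs below):

    #include <limits>
    #include <vector>

    struct IntervalSketch {
        double low = -std::numeric_limits<double>::infinity();
        double high = std::numeric_limits<double>::infinity();

        bool isFullyOpen() const {
            return low == -std::numeric_limits<double>::infinity() &&
                high == std::numeric_limits<double>::infinity();
        }
    };

    bool hasProperIntervalSketch(const std::vector<IntervalSketch>& reqs) {
        for (const auto& interval : reqs) {
            if (!interval.isFullyOpen()) {
                return true;  // at least one requirement can actually constrain an index
            }
        }
        return false;
    }

    int main() {
        // [5, +inf) is a proper interval; a fully open (-inf, +inf) bound is not.
        std::vector<IntervalSketch> reqs{{5.0, std::numeric_limits<double>::infinity()}};
        return hasProperIntervalSketch(reqs) ? 0 : 1;
    }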
- if (!fullyOpenInterval) { + if (!isFullyOpen.at(index)) { addRequirement( leftReqs, key, boost::none /*boundProjectionName*/, req.getIntervals()); addedToLeft = true; @@ -1018,9 +1015,6 @@ static SplitRequirementsResult splitRequirements( } if (addedToLeft) { - if (!fullyOpenInterval) { - result._hasLeftIntervals = true; - } if (indexFieldPrefixMapForScanDef) { if (auto pathPtr = key._path.cast<PathGet>(); pathPtr != nullptr && indexFieldPrefixMapForScanDef->count(pathPtr->name()) == 0) { @@ -1033,9 +1027,6 @@ static SplitRequirementsResult splitRequirements( } } else if (isIndex || !req.getIsPerfOnly()) { addRequirement(rightReqs, key, req.getBoundProjectionName(), req.getIntervals()); - if (!fullyOpenInterval) { - result._hasRightIntervals = true; - } } index++; } @@ -1141,10 +1132,13 @@ struct ExploreConvert<SargableNode> { continue; } - if (isIndex && (!splitResult._hasLeftIntervals || !splitResult._hasRightIntervals)) { - // Reject. Must have at least one proper interval on either side. + // Reject. Must have at least one proper interval on either side. + if (isIndex && + (!hasProperIntervals(splitResult._leftReqs) || + !hasProperIntervals(splitResult._rightReqs))) { continue; } + if (!splitResult._hasFieldCoverage) { // Reject rewrite. No suitable indexes. continue; @@ -1196,11 +1190,8 @@ struct ExploreConvert<SargableNode> { isIndex ? IndexReqTarget::Index : IndexReqTarget::Seek, scanDelegator); - ABT newRoot = make<RIDIntersectNode>(scanProjectionName, - splitResult._hasLeftIntervals, - splitResult._hasRightIntervals, - std::move(leftChild), - std::move(rightChild)); + ABT newRoot = make<RIDIntersectNode>( + scanProjectionName, std::move(leftChild), std::move(rightChild)); const auto& result = ctx.addNode(newRoot, false /*substitute*/); for (const MemoLogicalNodeId nodeId : result.second) { @@ -1290,14 +1281,27 @@ void reorderAgainstRIDIntersectNode(ABT::reference_type aboveNode, } const RIDIntersectNode& node = *belowNode.cast<RIDIntersectNode>(); - if (node.hasLeftIntervals() && hasLeftRef) { + const GroupIdType groupIdLeft = + node.getLeftChild().cast<MemoLogicalDelegatorNode>()->getGroupId(); + const bool hasProperIntervalLeft = + properties::getPropertyConst<properties::IndexingAvailability>( + ctx.getMemo().getLogicalProps(groupIdLeft)) + .hasProperInterval(); + if (hasProperIntervalLeft && hasLeftRef) { defaultReorder<AboveNode, RIDIntersectNode, DefaultChildAccessor, LeftChildAccessor, false /*substitute*/>(aboveNode, belowNode, ctx); } - if (node.hasRightIntervals() && hasRightRef) { + + const GroupIdType groupIdRight = + node.getRightChild().cast<MemoLogicalDelegatorNode>()->getGroupId(); + const bool hasProperIntervalRight = + properties::getPropertyConst<properties::IndexingAvailability>( + ctx.getMemo().getLogicalProps(groupIdRight)) + .hasProperInterval(); + if (hasProperIntervalRight && hasRightRef) { defaultReorder<AboveNode, RIDIntersectNode, DefaultChildAccessor, diff --git a/src/mongo/db/query/optimizer/cascades/memo_defs.h b/src/mongo/db/query/optimizer/cascades/memo_defs.h index b667df2816b..f0e476ab819 100644 --- a/src/mongo/db/query/optimizer/cascades/memo_defs.h +++ b/src/mongo/db/query/optimizer/cascades/memo_defs.h @@ -59,6 +59,10 @@ public: OrderPreservingABTSet(const OrderPreservingABTSet&) = delete; OrderPreservingABTSet(OrderPreservingABTSet&&) = default; + OrderPreservingABTSet& operator=(const OrderPreservingABTSet&) = delete; + OrderPreservingABTSet& operator=(OrderPreservingABTSet&&) = default; + + ABT::reference_type at(size_t index) const; 
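With the flag stored in IndexingAvailability, consumers such as reorderAgainstRIDIntersectNode fetch it from the child group's logical properties instead of reading members cached on the node. A simplified illustration of that lookup pattern, with a plain map standing in for the memo (all names here are hypothetical):

    #include <unordered_map>

    struct GroupPropsSketch {
        bool hasProperInterval = false;
    };

    using GroupId = int;

    // Hypothetical stand-in for Memo::getLogicalProps() + getPropertyConst<>().
    bool childHasProperInterval(const std::unordered_map<GroupId, GroupPropsSketch>& memo,
                                GroupId childGroup) {
        return memo.at(childGroup).hasProperInterval;
    }

    int main() {
        std::unordered_map<GroupId, GroupPropsSketch> memo{{0, {true}}, {1, {false}}};
        // An index-target RID intersection would be rejected here: only one side
        // has a proper (non-fully-open) interval.
        const bool bothSidesUsable =
            childHasProperInterval(memo, 0) && childHasProperInterval(memo, 1);
        return bothSidesUsable ? 0 : 1;
    }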
std::pair<size_t, bool> emplace_back(ABT node); std::pair<size_t, bool> find(ABT::reference_type node) const; diff --git a/src/mongo/db/query/optimizer/explain.cpp b/src/mongo/db/query/optimizer/explain.cpp index ba5aecdbd73..5e9df5df3bb 100644 --- a/src/mongo/db/query/optimizer/explain.cpp +++ b/src/mongo/db/query/optimizer/explain.cpp @@ -496,7 +496,7 @@ private: void addValue(sbe::value::TypeTags tag, sbe::value::Value val, const bool append = false) { if (!_initialized) { _initialized = true; - _canAppend = !_nextFieldName.empty(); + _canAppend = _nextFieldName.has_value(); if (_canAppend) { std::tie(_tag, _val) = sbe::value::makeNewObject(); } else { @@ -512,7 +512,7 @@ private: } if (append) { - uassert(6624073, "Field name is not empty", _nextFieldName.empty()); + uassert(6624073, "Field name is not set", !_nextFieldName.has_value()); uassert(6624349, "Other printer does not contain Object", tag == sbe::value::TypeTags::Object); @@ -523,19 +523,19 @@ private: addField(obj->field(i), fieldTag, fieldVal); } } else { - addField(_nextFieldName, tag, val); - _nextFieldName.clear(); + tassert(6751700, "Missing field name to serialize", _nextFieldName); + addField(*_nextFieldName, tag, val); + _nextFieldName = boost::none; } } void addField(const std::string& fieldName, sbe::value::TypeTags tag, sbe::value::Value val) { - uassert(6624074, "Field name is empty", !fieldName.empty()); uassert(6624075, "Duplicate field name", _fieldNameSet.insert(fieldName).second); sbe::value::getObjectView(_val)->push_back(fieldName, tag, val); } void reset() { - _nextFieldName.clear(); + _nextFieldName = boost::none; _initialized = false; _canAppend = false; _tag = sbe::value::TypeTags::Nothing; @@ -543,7 +543,8 @@ private: _fieldNameSet.clear(); } - std::string _nextFieldName; + // Cannot assume empty means non-existent, so use optional<>. + boost::optional<std::string> _nextFieldName; bool _initialized; bool _canAppend; sbe::value::TypeTags _tag; @@ -1276,8 +1277,6 @@ public: printer.separator(" [") .fieldName("scanProjectionName", ExplainVersion::V3) .print(node.getScanProjectionName()); - printBooleanFlag(printer, "hasLeftIntervals", node.hasLeftIntervals()); - printBooleanFlag(printer, "hasRightIntervals", node.hasRightIntervals()); printer.separator("]"); nodeCEPropsPrint(printer, n, node); @@ -1766,6 +1765,7 @@ public: .fieldName("scanDefName") .print(prop.getScanDefName()); printBooleanFlag(printer, "eqPredsOnly", prop.getEqPredsOnly()); + printBooleanFlag(printer, "hasProperInterval", prop.hasProperInterval()); printer.separator("]"); if (!prop.getSatisfiedPartialIndexes().empty()) { @@ -2263,7 +2263,7 @@ public: ExplainPrinter printer("PathGet"); printer.separator(" [") .fieldName("path", ExplainVersion::V3) - .print(path.name()) + .print(path.name().empty() ? 
"<empty>" : path.name()) .separator("]") .setChildCount(1) .fieldName("input", ExplainVersion::V3) @@ -2542,8 +2542,12 @@ static void printBSONstr(PrinterType& printer, } } -std::string ExplainGenerator::printBSON(const sbe::value::TypeTags tag, - const sbe::value::Value val) { +std::string ExplainGenerator::explainBSONStr(const ABT& node, + bool displayProperties, + const cascades::MemoExplainInterface* memoInterface, + const NodeToGroupPropsMap& nodeMap) { + const auto [tag, val] = explainBSON(node, displayProperties, memoInterface, nodeMap); + sbe::value::ValueGuard vg(tag, val); ExplainPrinterImpl<ExplainVersion::V2> printer; printBSONstr(printer, tag, val); return printer.str(); diff --git a/src/mongo/db/query/optimizer/explain.h b/src/mongo/db/query/optimizer/explain.h index ad62dd54126..19e9221f8dd 100644 --- a/src/mongo/db/query/optimizer/explain.h +++ b/src/mongo/db/query/optimizer/explain.h @@ -94,7 +94,10 @@ public: const cascades::MemoExplainInterface* memoInterface = nullptr, const NodeToGroupPropsMap& nodeMap = {}); - static std::string printBSON(sbe::value::TypeTags tag, sbe::value::Value val); + static std::string explainBSONStr(const ABT& node, + bool displayProperties = false, + const cascades::MemoExplainInterface* memoInterface = nullptr, + const NodeToGroupPropsMap& nodeMap = {}); static std::string explainLogicalProps(const std::string& description, const properties::LogicalProps& props); diff --git a/src/mongo/db/query/optimizer/interval_intersection_test.cpp b/src/mongo/db/query/optimizer/interval_intersection_test.cpp index 119a1b98fea..cc8127d49b7 100644 --- a/src/mongo/db/query/optimizer/interval_intersection_test.cpp +++ b/src/mongo/db/query/optimizer/interval_intersection_test.cpp @@ -56,12 +56,12 @@ std::string optimizedQueryPlan(const std::string& query, ABT translated = translatePipeline(metadata, "[{$match: " + query + "}]", scanDefName, prefixId); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase, - OptPhase::MemoExplorationPhase, - OptPhase::MemoImplementationPhase}, - prefixId, - metadata, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase, + OptPhase::MemoExplorationPhase, + OptPhase::MemoImplementationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); ABT optimized = translated; phaseManager.getHints()._disableScan = true; diff --git a/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp b/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp index 2434e811de1..03980eb5db4 100644 --- a/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp +++ b/src/mongo/db/query/optimizer/logical_rewriter_optimizer_test.cpp @@ -74,10 +74,10 @@ TEST(LogicalRewriter, RootNodeMerge) { " Source []\n", rootNode); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); ABT rewritten = std::move(rootNode); phaseManager.optimize(rewritten); @@ -288,10 +288,10 @@ TEST(LogicalRewriter, FilterProjectRewrite) { " Source []\n", rootNode); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + 
DebugInfo::kDefaultForTests); ABT latest = std::move(rootNode); phaseManager.optimize(latest); @@ -399,10 +399,10 @@ TEST(LogicalRewriter, FilterProjectComplexRewrite) { " Source []\n", rootNode); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); ABT latest = std::move(rootNode); phaseManager.optimize(latest); @@ -477,10 +477,10 @@ TEST(LogicalRewriter, FilterProjectGroupRewrite) { ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"c"}}, std::move(filterANode)); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); ABT latest = std::move(rootNode); phaseManager.optimize(latest); @@ -547,10 +547,10 @@ TEST(LogicalRewriter, FilterProjectUnwindRewrite) { ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}}, std::move(filterBNode)); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); ABT latest = std::move(rootNode); phaseManager.optimize(latest); @@ -618,10 +618,10 @@ TEST(LogicalRewriter, FilterProjectExchangeRewrite) { ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}}, std::move(filterANode)); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); ABT latest = std::move(rootNode); phaseManager.optimize(latest); @@ -690,10 +690,10 @@ TEST(LogicalRewriter, UnwindCollationRewrite) { ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"a", "b"}}, std::move(unwindNode)); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); ABT latest = std::move(rootNode); phaseManager.optimize(latest); @@ -802,11 +802,11 @@ TEST(LogicalRewriter, FilterUnionReorderSingleProjection) { " Source []\n", latest); - OptPhaseManager phaseManager( - {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase}, - prefixId, - {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = + makePhaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase}, + prefixId, + {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); phaseManager.optimize(latest); ASSERT_EXPLAIN_V2( @@ -966,11 +966,11 @@ TEST(LogicalRewriter, MultipleFilterUnionReorder) { " Source []\n", latest); - OptPhaseManager 
phaseManager( - {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase}, - prefixId, - {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = + makePhaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase}, + prefixId, + {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); phaseManager.optimize(latest); ASSERT_EXPLAIN_V2( @@ -1070,12 +1070,12 @@ TEST(LogicalRewriter, FilterUnionUnionPushdown) { ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"ptest"}}, std::move(filter)); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase}, - prefixId, - {{{"test1", createScanDef({}, {})}, - {"test2", createScanDef({}, {})}, - {"test3", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test1", createScanDef({}, {})}, + {"test2", createScanDef({}, {})}, + {"test3", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); ABT latest = std::move(rootNode); ASSERT_EXPLAIN_V2( @@ -1216,10 +1216,11 @@ TEST(LogicalRewriter, UnionPreservesCommonLogicalProps) { // Run the reordering rewrite such that the scan produces a hash partition. PrefixId prefixId; - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase}, - prefixId, - metadata, - DebugInfo::kDefaultForTests); + auto phaseManager = + makePhaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase}, + prefixId, + metadata, + DebugInfo::kDefaultForTests); ABT optimized = rootNode; phaseManager.optimize(optimized); @@ -1432,10 +1433,11 @@ TEST(LogicalRewriter, SargableCE) { PrefixId prefixId; ABT rootNode = sargableCETestSetup(); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = + makePhaseManager({OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo::kDefaultForTests); ABT latest = std::move(rootNode); phaseManager.optimize(latest); @@ -1474,7 +1476,8 @@ TEST(LogicalRewriter, SargableCE) { " | | projections: \n" " | | ptest\n" " | | indexingAvailability: \n" - " | | [groupId: 0, scanProjection: ptest, scanDefName: test, eqPredsOnly]\n" + " | | [groupId: 0, scanProjection: ptest, scanDefName: test, eqPredsOnly, " + "hasProperInterval]\n" " | | collectionAvailability: \n" " | | test\n" " | | distributionAvailability: \n" @@ -1508,7 +1511,8 @@ TEST(LogicalRewriter, SargableCE) { " | | projections: \n" " | | ptest\n" " | | indexingAvailability: \n" - " | | [groupId: 0, scanProjection: ptest, scanDefName: test, eqPredsOnly]\n" + " | | [groupId: 0, scanProjection: ptest, scanDefName: test, eqPredsOnly, " + "hasProperInterval]\n" " | | collectionAvailability: \n" " | | test\n" " | | distributionAvailability: \n" @@ -1540,10 +1544,10 @@ TEST(LogicalRewriter, RemoveNoopFilter) { ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"ptest"}}, std::move(filterANode)); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + 
DebugInfo::kDefaultForTests); ABT latest = std::move(rootNode); phaseManager.optimize(latest); diff --git a/src/mongo/db/query/optimizer/node.cpp b/src/mongo/db/query/optimizer/node.cpp index d9b0086fb1e..1eed39980b1 100644 --- a/src/mongo/db/query/optimizer/node.cpp +++ b/src/mongo/db/query/optimizer/node.cpp @@ -290,15 +290,9 @@ bool EvaluationNode::operator==(const EvaluationNode& other) const { getChild() == other.getChild(); } -RIDIntersectNode::RIDIntersectNode(ProjectionName scanProjectionName, - const bool hasLeftIntervals, - const bool hasRightIntervals, - ABT leftChild, - ABT rightChild) +RIDIntersectNode::RIDIntersectNode(ProjectionName scanProjectionName, ABT leftChild, ABT rightChild) : Base(std::move(leftChild), std::move(rightChild)), - _scanProjectionName(std::move(scanProjectionName)), - _hasLeftIntervals(hasLeftIntervals), - _hasRightIntervals(hasRightIntervals) { + _scanProjectionName(std::move(scanProjectionName)) { assertNodeSort(getLeftChild()); assertNodeSort(getRightChild()); } @@ -321,23 +315,13 @@ ABT& RIDIntersectNode::getRightChild() { bool RIDIntersectNode::operator==(const RIDIntersectNode& other) const { return _scanProjectionName == other._scanProjectionName && - _hasLeftIntervals == other._hasLeftIntervals && - _hasRightIntervals == other._hasRightIntervals && getLeftChild() == other.getLeftChild() && - getRightChild() == other.getRightChild(); + getLeftChild() == other.getLeftChild() && getRightChild() == other.getRightChild(); } const ProjectionName& RIDIntersectNode::getScanProjectionName() const { return _scanProjectionName; } -bool RIDIntersectNode::hasLeftIntervals() const { - return _hasLeftIntervals; -} - -bool RIDIntersectNode::hasRightIntervals() const { - return _hasRightIntervals; -} - static ProjectionNameVector createSargableBindings(const PartialSchemaRequirements& reqMap) { ProjectionNameVector result; for (const auto& entry : reqMap) { diff --git a/src/mongo/db/query/optimizer/node.h b/src/mongo/db/query/optimizer/node.h index 155f1d69f29..880d2fd5480 100644 --- a/src/mongo/db/query/optimizer/node.h +++ b/src/mongo/db/query/optimizer/node.h @@ -389,11 +389,7 @@ class RIDIntersectNode final : public Operator<2>, public ExclusivelyLogicalNode using Base = Operator<2>; public: - RIDIntersectNode(ProjectionName scanProjectionName, - bool hasLeftIntervals, - bool hasRightIntervals, - ABT leftChild, - ABT rightChild); + RIDIntersectNode(ProjectionName scanProjectionName, ABT leftChild, ABT rightChild); bool operator==(const RIDIntersectNode& other) const; @@ -405,15 +401,8 @@ public: const ProjectionName& getScanProjectionName() const; - bool hasLeftIntervals() const; - bool hasRightIntervals() const; - private: const ProjectionName _scanProjectionName; - - // If true left and right children have at least one proper interval (not fully open). 
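Dropping the cached booleans leaves RIDIntersectNode with only intrinsic state, which also simplifies its constructor and operator==. A small sketch of the resulting shape (hypothetical class, children omitted):

    #include <string>
    #include <utility>

    // Only intrinsic state is stored; facts that can be derived from the memo
    // (such as whether a child has a proper interval) are no longer cached here,
    // so they cannot drift out of sync with the logical properties.
    class RidIntersectSketch {
    public:
        explicit RidIntersectSketch(std::string scanProjectionName)
            : _scanProjectionName(std::move(scanProjectionName)) {}

        bool operator==(const RidIntersectSketch& other) const {
            // The real node compares its children as well.
            return _scanProjectionName == other._scanProjectionName;
        }

    private:
        std::string _scanProjectionName;
    };

    int main() {
        return RidIntersectSketch{"root"} == RidIntersectSketch{"root"} ? 0 : 1;
    }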
- const bool _hasLeftIntervals; - const bool _hasRightIntervals; }; /** diff --git a/src/mongo/db/query/optimizer/opt_phase_manager.cpp b/src/mongo/db/query/optimizer/opt_phase_manager.cpp index 4c4fd0deed3..a00e30aa499 100644 --- a/src/mongo/db/query/optimizer/opt_phase_manager.cpp +++ b/src/mongo/db/query/optimizer/opt_phase_manager.cpp @@ -30,7 +30,6 @@ #include "mongo/db/query/optimizer/opt_phase_manager.h" #include "mongo/db/query/optimizer/cascades/ce_heuristic.h" -#include "mongo/db/query/optimizer/cascades/cost_derivation.h" #include "mongo/db/query/optimizer/cascades/logical_props_derivation.h" #include "mongo/db/query/optimizer/rewrites/const_eval.h" #include "mongo/db/query/optimizer/rewrites/path.h" @@ -49,22 +48,6 @@ OptPhaseManager::PhaseSet OptPhaseManager::_allRewrites = {OptPhase::ConstEvalPr OptPhaseManager::OptPhaseManager(OptPhaseManager::PhaseSet phaseSet, PrefixId& prefixId, - Metadata metadata, - DebugInfo debugInfo, - QueryHints queryHints) - : OptPhaseManager(std::move(phaseSet), - prefixId, - false /*requireRID*/, - std::move(metadata), - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, - std::move(debugInfo), - std::move(queryHints)) {} - -OptPhaseManager::OptPhaseManager(OptPhaseManager::PhaseSet phaseSet, - PrefixId& prefixId, const bool requireRID, Metadata metadata, std::unique_ptr<CEInterface> ceDerivation, diff --git a/src/mongo/db/query/optimizer/opt_phase_manager.h b/src/mongo/db/query/optimizer/opt_phase_manager.h index 0a94ffb9761..81a849a3461 100644 --- a/src/mongo/db/query/optimizer/opt_phase_manager.h +++ b/src/mongo/db/query/optimizer/opt_phase_manager.h @@ -77,11 +77,6 @@ public: OptPhaseManager(PhaseSet phaseSet, PrefixId& prefixId, - Metadata metadata, - DebugInfo debugInfo, - QueryHints queryHints = {}); - OptPhaseManager(PhaseSet phaseSet, - PrefixId& prefixId, bool requireRID, Metadata metadata, std::unique_ptr<CEInterface> ceDerivation, diff --git a/src/mongo/db/query/optimizer/optimizer_failure_test.cpp b/src/mongo/db/query/optimizer/optimizer_failure_test.cpp index bf4715548aa..6ab998dcf13 100644 --- a/src/mongo/db/query/optimizer/optimizer_failure_test.cpp +++ b/src/mongo/db/query/optimizer/optimizer_failure_test.cpp @@ -28,7 +28,6 @@ */ #include "mongo/db/query/optimizer/cascades/ce_heuristic.h" -#include "mongo/db/query/optimizer/cascades/cost_derivation.h" #include "mongo/db/query/optimizer/metadata_factory.h" #include "mongo/db/query/optimizer/node.h" #include "mongo/db/query/optimizer/opt_phase_manager.h" @@ -53,11 +52,11 @@ DEATH_TEST_REGEX(Optimizer, HitIterationLimitInrunStructuralPhases, "Tripwire as ABT evalNode = make<EvaluationNode>("evalProj1", Constant::int64(5), std::move(scanNode)); - OptPhaseManager phaseManager( - {OptPhase::PathFuse, OptPhase::ConstEvalPre}, - prefixId, - {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}}, - DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0)); + auto phaseManager = + makePhaseManager({OptPhase::PathFuse, OptPhase::ConstEvalPre}, + prefixId, + {{{"test1", createScanDef({}, {})}, {"test2", createScanDef({}, {})}}}, + DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0)); ASSERT_THROWS_CODE(phaseManager.optimize(evalNode), DBException, 6808700); } @@ -81,10 +80,10 @@ DEATH_TEST_REGEX(Optimizer, ABT rootNode = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(filterNode)); - OptPhaseManager phaseManager({OptPhase::MemoSubstitutionPhase}, - prefixId, - {{{"test", 
createScanDef({}, {})}}}, - DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0)); + auto phaseManager = makePhaseManager({OptPhase::MemoSubstitutionPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0)); ASSERT_THROWS_CODE(phaseManager.optimize(rootNode), DBException, 6808702); } @@ -108,10 +107,10 @@ DEATH_TEST_REGEX(Optimizer, ABT rootNode = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(filterNode)); - OptPhaseManager phaseManager({OptPhase::MemoExplorationPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0)); + auto phaseManager = makePhaseManager({OptPhase::MemoExplorationPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0)); ASSERT_THROWS_CODE(phaseManager.optimize(rootNode), DBException, 6808702); } @@ -133,10 +132,10 @@ DEATH_TEST_REGEX(Optimizer, BadGroupID, "Tripwire assertion.*6808704") { ABT rootNode = make<RootNode>(properties::ProjectionRequirement{{}}, std::move(filterNode)); - OptPhaseManager phaseManager({OptPhase::MemoImplementationPhase}, - prefixId, - {{{"test", createScanDef({}, {})}}}, - DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0)); + auto phaseManager = makePhaseManager({OptPhase::MemoImplementationPhase}, + prefixId, + {{{"test", createScanDef({}, {})}}}, + DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, 0)); ASSERT_THROWS_CODE(phaseManager.optimize(rootNode), DBException, 6808704); } @@ -161,13 +160,13 @@ DEATH_TEST_REGEX(Optimizer, EnvHasFreeVariables, "Tripwire assertion.*6808711") ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"p3"}}, std::move(filter2Node)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, {{{"test", createScanDef({}, {})}}}, - {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, DebugInfo::kIterationLimitForTests)); ASSERT_THROWS_CODE(phaseManager.optimize(rootNode), DBException, 6808711); } @@ -206,12 +205,11 @@ DEATH_TEST_REGEX(Optimizer, FailedToRetrieveRID, "Tripwire assertion.*6808705") ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManagerRequireRID( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - true /*requireRID*/, {{{"c1", createScanDef( {}, @@ -224,11 +222,7 @@ DEATH_TEST_REGEX(Optimizer, FailedToRetrieveRID, "Tripwire assertion.*6808705") ConstEval::constFold, {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("b"))})}}, 5 /*numberOfPartitions*/}, - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, - {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); + DebugInfo(true, DebugInfo::kDefaultDebugLevelForTests, DebugInfo::kIterationLimitForTests)); ASSERT_THROWS_CODE(phaseManager.optimize(rootNode), DBException, 6808705); } diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp index d52b84470a1..423a6a399af 100644 --- a/src/mongo/db/query/optimizer/optimizer_test.cpp +++ 
b/src/mongo/db/query/optimizer/optimizer_test.cpp @@ -29,18 +29,45 @@ #include "mongo/db/query/optimizer/explain.h" #include "mongo/db/query/optimizer/node.h" -#include "mongo/db/query/optimizer/opt_phase_manager.h" #include "mongo/db/query/optimizer/reference_tracker.h" #include "mongo/db/query/optimizer/rewrites/const_eval.h" #include "mongo/db/query/optimizer/syntax/syntax.h" -#include "mongo/db/query/optimizer/syntax/syntax_fwd_declare.h" #include "mongo/db/query/optimizer/utils/unit_test_utils.h" #include "mongo/db/query/optimizer/utils/utils.h" #include "mongo/unittest/unittest.h" + namespace mongo::optimizer { namespace { +TEST(Optimizer, AutoUpdateExplain) { + ABT tree = make<BinaryOp>(Operations::Add, + Constant::int64(1), + make<Variable>("very very very very very very very very very very " + "very very long variable name with \"quotes\"")); + + /** + * To exercise the auto-updating behavior: + * 1. Change the flag "kAutoUpdateOnFailure" to "true". + * 2. Induce a failure: change something in the expected output. + * 3. Recompile and run the test binary as normal. + * 4. Observe after the run the test file is updated with the correct output. + */ + ASSERT_EXPLAIN_V2_AUTO( // NOLINT (test auto-update) + "BinaryOp [Add]\n" + "| Variable [very very very very very very very very very very very very long variable " + "name with \"quotes\"]\n" + "Const [1]\n", + tree); + + // Test for short constant. It should not be inlined. The nolint comment on the string constant + // itself is auto-generated. + ABT tree1 = make<Variable>("short name"); + ASSERT_EXPLAIN_V2_AUTO( // NOLINT (test auto-update) + "Variable [short name]\n", // NOLINT (test auto-update) + tree1); +} + Constant* constEval(ABT& tree) { auto env = VariableEnvironment::build(tree); ConstEval evaluator{env}; @@ -746,13 +773,14 @@ TEST(Explain, ExplainV2Compact) { TEST(Explain, ExplainBsonForConstant) { ABT cNode = Constant::int64(3); - auto [tag, val] = ExplainGenerator::explainBSON(cNode); - sbe::value::ValueGuard vg(tag, val); - ASSERT_EQ( - "{\n nodeType: \"Const\", \n" + + ASSERT_EXPLAIN_BSON( + "{\n" + " nodeType: \"Const\", \n" " tag: \"NumberInt64\", \n" - " value: 3\n}\n", - ExplainGenerator::printBSON(tag, val)); + " value: 3\n" + "}\n", + cNode); } } // namespace diff --git a/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp index 1079e45acfa..789bcdeea81 100644 --- a/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp +++ b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp @@ -30,7 +30,6 @@ #include "mongo/db/pipeline/abt/utils.h" #include "mongo/db/query/optimizer/cascades/ce_heuristic.h" #include "mongo/db/query/optimizer/cascades/ce_hinted.h" -#include "mongo/db/query/optimizer/cascades/cost_derivation.h" #include "mongo/db/query/optimizer/cascades/rewriter_rules.h" #include "mongo/db/query/optimizer/explain.h" #include "mongo/db/query/optimizer/metadata_factory.h" @@ -66,7 +65,7 @@ TEST(PhysRewriter, PhysicalRewriterBasic) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"p2"}}, std::move(filter2Node)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -121,7 +120,7 @@ TEST(PhysRewriter, PhysicalRewriterBasic) { "| | p1\n" "| | p2\n" "| | indexingAvailability: \n" - "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n" + "| | [groupId: 0, 
scanProjection: p1, scanDefName: test, hasProperInterval]\n" "| | collectionAvailability: \n" "| | test\n" "| | distributionAvailability: \n" @@ -147,7 +146,7 @@ TEST(PhysRewriter, PhysicalRewriterBasic) { "| | p1\n" "| | p2\n" "| | indexingAvailability: \n" - "| | [groupId: 0, scanProjection: p1, scanDefName: test]\n" + "| | [groupId: 0, scanProjection: p1, scanDefName: test, hasProperInterval]\n" "| | collectionAvailability: \n" "| | test\n" "| | distributionAvailability: \n" @@ -271,7 +270,7 @@ TEST(PhysRewriter, GroupBy) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"c"}}, std::move(filterANode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -341,7 +340,7 @@ TEST(PhysRewriter, GroupBy1) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pb"}}, std::move(groupByNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -413,7 +412,7 @@ TEST(PhysRewriter, Unwind) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"a", "b"}}, std::move(filterBNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -491,7 +490,7 @@ TEST(PhysRewriter, DuplicateFilter) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -552,7 +551,7 @@ TEST(PhysRewriter, FilterCollation) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pb"}}, std::move(limitSkipNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -608,7 +607,7 @@ TEST(PhysRewriter, EvalCollation) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -663,7 +662,7 @@ TEST(PhysRewriter, FilterEvalCollation) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -719,7 +718,7 @@ TEST(PhysRewriter, FilterIndexing) { { PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase}, prefixId, {{{"c1", @@ -739,7 +738,7 @@ TEST(PhysRewriter, FilterIndexing) { "| | root\n" "| RefBlock: \n" "| Variable [root]\n" - "RIDIntersect [root, hasLeftIntervals]\n" + "RIDIntersect [root]\n" "| Scan [c1]\n" "| BindBlock:\n" "| [root]\n" @@ -762,7 +761,7 @@ TEST(PhysRewriter, FilterIndexing) { { PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, 
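The repeated change across these tests swaps direct OptPhaseManager construction for a makePhaseManager helper that supplies the default CE, costing and constant-folding arguments. A hypothetical sketch of that kind of factory (simplified stand-in types, not the real unit-test utility):

    #include <memory>
    #include <utility>

    struct CEModule {};    // stand-ins for the optimizer's CE / costing interfaces
    struct CostModule {};
    struct Metadata {};
    struct DebugInfo {};

    struct PhaseManagerSketch {
        bool requireRID;
        Metadata metadata;
        std::unique_ptr<CEModule> ce;
        std::unique_ptr<CostModule> costing;
        DebugInfo debugInfo;
    };

    PhaseManagerSketch makePhaseManagerSketch(Metadata metadata, DebugInfo debugInfo) {
        // Defaults are chosen once here instead of being repeated in every test.
        return {false /*requireRID*/,
                std::move(metadata),
                std::make_unique<CEModule>(),
                std::make_unique<CostModule>(),
                std::move(debugInfo)};
    }

    int main() {
        auto pm = makePhaseManagerSketch(Metadata{}, DebugInfo{});
        return pm.requireRID ? 1 : 0;
    }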
@@ -804,7 +803,7 @@ TEST(PhysRewriter, FilterIndexing) { { PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -868,7 +867,7 @@ TEST(PhysRewriter, FilterIndexing1) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"p1"}}, std::move(filterNode)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -933,7 +932,7 @@ TEST(PhysRewriter, FilterIndexing2) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -1013,7 +1012,7 @@ TEST(PhysRewriter, FilterIndexing2NonSarg) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -1131,7 +1130,7 @@ TEST(PhysRewriter, FilterIndexing3) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -1186,7 +1185,7 @@ TEST(PhysRewriter, FilterIndexing3MultiKey) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -1280,7 +1279,7 @@ TEST(PhysRewriter, FilterIndexing4) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterDNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -1388,7 +1387,7 @@ TEST(PhysRewriter, FilterIndexing5) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pb"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -1478,7 +1477,7 @@ TEST(PhysRewriter, FilterIndexing6) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pb"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -1541,7 +1540,7 @@ TEST(PhysRewriter, FilterIndexingStress) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(result)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -1641,7 +1640,7 @@ TEST(PhysRewriter, FilterIndexingVariable) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); - OptPhaseManager 
phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -1737,7 +1736,7 @@ TEST(PhysRewriter, FilterIndexingMaxKey) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -1805,7 +1804,7 @@ TEST(PhysRewriter, SargableProjectionRenames) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(evalNode2)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase}, prefixId, {{{"c1", createScanDef({}, {})}}}, @@ -1867,7 +1866,7 @@ TEST(PhysRewriter, SargableAcquireProjection) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(evalNode)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase}, prefixId, {{{"c1", createScanDef({}, {})}}}, @@ -1933,17 +1932,13 @@ TEST(PhysRewriter, FilterReorder) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(result)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef({}, {})}}}, std::make_unique<HintedCE>(std::move(hints)), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = std::move(rootNode); @@ -2028,21 +2023,17 @@ TEST(PhysRewriter, CoveredScan) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef( {}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending, false /*isMultiKey*/)}})}}}, std::make_unique<HintedCE>(std::move(hints)), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = std::move(rootNode); @@ -2100,7 +2091,7 @@ TEST(PhysRewriter, EvalIndexing) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode)); { - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -2133,7 +2124,7 @@ TEST(PhysRewriter, EvalIndexing) { { // Index and collation node have incompatible ops. 
- OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -2192,7 +2183,7 @@ TEST(PhysRewriter, EvalIndexing1) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -2262,7 +2253,7 @@ TEST(PhysRewriter, EvalIndexing2) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa2"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::ConstEvalPre, OptPhase::PathFuse, OptPhase::MemoSubstitutionPhase, @@ -2347,12 +2338,11 @@ TEST(PhysRewriter, MultiKeyIndex) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef( {}, @@ -2360,9 +2350,6 @@ TEST(PhysRewriter, MultiKeyIndex) { {"index2", makeIndexDefinition("b", CollationOp::Descending, false /*isMultiKey*/)}})}}}, std::make_unique<HintedCE>(std::move(hints)), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); { @@ -2571,7 +2558,7 @@ TEST(PhysRewriter, CompoundIndex1) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterDNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -2656,7 +2643,7 @@ TEST(PhysRewriter, CompoundIndex2) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -2743,7 +2730,7 @@ TEST(PhysRewriter, CompoundIndex3) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -2830,12 +2817,11 @@ TEST(PhysRewriter, CompoundIndex4Negative) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterBNode)); // Create the following indexes: {a:1, c:1, {name: 'index1'}}, and {b:1, d:1, {name: 'index2'}} - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef( {}, @@ -2848,9 +2834,6 @@ TEST(PhysRewriter, CompoundIndex4Negative) { {makeNonMultikeyIndexPath("d"), CollationOp::Ascending}}, false /*isMultiKey*/}}})}}}, std::make_unique<HintedCE>(std::move(hints)), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -2905,7 +2888,7 @@ TEST(PhysRewriter, 
IndexBoundsIntersect) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -2980,7 +2963,7 @@ TEST(PhysRewriter, IndexBoundsIntersect1) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -3048,7 +3031,7 @@ TEST(PhysRewriter, IndexBoundsIntersect2) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -3122,7 +3105,7 @@ TEST(PhysRewriter, IndexBoundsIntersect3) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -3201,7 +3184,7 @@ TEST(PhysRewriter, IndexResidualReq) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -3229,7 +3212,7 @@ TEST(PhysRewriter, IndexResidualReq) { "| | pa\n" "| | root\n" "| | indexingAvailability: \n" - "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n" + "| | [groupId: 0, scanProjection: root, scanDefName: c1, hasProperInterval]\n" "| | collectionAvailability: \n" "| | c1\n" "| | distributionAvailability: \n" @@ -3257,7 +3240,7 @@ TEST(PhysRewriter, IndexResidualReq) { "| | pa\n" "| | root\n" "| | indexingAvailability: \n" - "| | [groupId: 0, scanProjection: root, scanDefName: c1]\n" + "| | [groupId: 0, scanProjection: root, scanDefName: c1, hasProperInterval]\n" "| | collectionAvailability: \n" "| | c1\n" "| | distributionAvailability: \n" @@ -3321,7 +3304,7 @@ TEST(PhysRewriter, IndexResidualReq1) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(collationNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -3404,7 +3387,7 @@ TEST(PhysRewriter, IndexResidualReq2) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterBNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -3483,18 +3466,13 @@ TEST(PhysRewriter, ElemMatchIndex) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef({}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending)}})}}}, - 
std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - defaultConvertPathToInterval, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -3567,22 +3545,17 @@ TEST(PhysRewriter, ElemMatchIndex1) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef({}, {{"index1", makeCompositeIndexDefinition( {{"b", CollationOp::Ascending, true /*isMultiKey*/}, {"a", CollationOp::Ascending, true /*isMultiKey*/}})}})}}}, - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - defaultConvertPathToInterval, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -3650,21 +3623,16 @@ TEST(PhysRewriter, ElemMatchIndexNoArrays) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef( {}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending, false /*multiKey*/)}})}}}, - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - defaultConvertPathToInterval, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -3720,22 +3688,17 @@ TEST(PhysRewriter, ObjectElemMatchResidual) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef({}, {{"index1", makeCompositeIndexDefinition( {{"b", CollationOp::Ascending, true /*isMultiKey*/}, {"a", CollationOp::Ascending, true /*isMultiKey*/}})}})}}}, - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - defaultConvertPathToInterval, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -3840,12 +3803,11 @@ TEST(PhysRewriter, ObjectElemMatchBounds) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef( {}, @@ -3855,10 +3817,6 @@ TEST(PhysRewriter, ObjectElemMatchBounds) { true /*isMultiKey*/}}} )}}}, - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - defaultConvertPathToInterval, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -3928,21 +3886,16 @@ TEST(PhysRewriter, NestedElemMatch) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( 
{OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"coll1", createScanDef( {}, {{"index1", makeIndexDefinition("a", CollationOp::Ascending, true /*isMultiKey*/)}})}}}, - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - defaultConvertPathToInterval, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -4049,12 +4002,11 @@ TEST(PhysRewriter, PathObj) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); PrefixId prefixId; - OptPhaseManager phaseManager{ + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef({}, {{"index1", @@ -4062,13 +4014,8 @@ TEST(PhysRewriter, PathObj) { {{"a", CollationOp::Ascending, false /*isMultiKey*/}, {"b", CollationOp::Ascending, true /*isMultiKey*/}})}})}}}, std::make_unique<HintedCE>(std::move(hints)), - std::make_unique<DefaultCosting>(), - // The path-to-interval callback is important for this test. - // We want to confirm PathObj becomes an interval. - defaultConvertPathToInterval, - ConstEval::constFold, DebugInfo{true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}, - {} /*hints*/}; + {} /*hints*/); ABT optimized = rootNode; phaseManager.optimize(optimized); @@ -4139,7 +4086,7 @@ TEST(PhysRewriter, ArrayConstantIndex) { ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -4244,7 +4191,7 @@ TEST(PhysRewriter, ArrayConstantNoIndex) { ABT rootNode = make<RootNode>(properties::ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode2)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -4306,7 +4253,7 @@ TEST(PhysRewriter, ParallelScan) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -4368,7 +4315,7 @@ TEST(PhysRewriter, HashPartitioning) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -4452,12 +4399,11 @@ TEST(PhysRewriter, IndexPartitioning) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef( {}, @@ -4471,9 +4417,6 @@ TEST(PhysRewriter, IndexPartitioning) { {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("b"))})}}, 5 /*numberOfPartitions*/}, std::make_unique<HintedCE>(std::move(hints)), - 
std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -4573,12 +4516,11 @@ TEST(PhysRewriter, IndexPartitioning1) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef( {}, @@ -4598,9 +4540,6 @@ TEST(PhysRewriter, IndexPartitioning1) { {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("c"))})}}, 5 /*numberOfPartitions*/}, std::make_unique<HintedCE>(std::move(hints)), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -4653,7 +4592,7 @@ TEST(PhysRewriter, LocalGlobalAgg) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa", "pc"}}, std::move(groupByNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -4731,7 +4670,7 @@ TEST(PhysRewriter, LocalGlobalAgg1) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pc"}}, std::move(groupByNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -4786,7 +4725,7 @@ TEST(PhysRewriter, LocalLimitSkip) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(limitSkipNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -4920,7 +4859,7 @@ TEST(PhysRewriter, CollationLimit) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(limitSkipNode)); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -5061,7 +5000,7 @@ TEST(PhysRewriter, PartialIndex1) { ASSERT_TRUE(conversionResult); ASSERT_FALSE(conversionResult->_retainPredicate); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -5142,7 +5081,7 @@ TEST(PhysRewriter, PartialIndex2) { ASSERT_TRUE(conversionResult); ASSERT_FALSE(conversionResult->_retainPredicate); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -5222,7 +5161,7 @@ TEST(PhysRewriter, PartialIndexReject) { ASSERT_TRUE(conversionResult); ASSERT_FALSE(conversionResult->_retainPredicate); - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -5287,17 +5226,12 @@ TEST(PhysRewriter, RequireRID) { ABT rootNode = make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); - OptPhaseManager phaseManager( + auto phaseManager = 
makePhaseManagerRequireRID( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - true /*requireRID*/, {{{"c1", createScanDef({}, {})}}}, - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -5342,17 +5276,12 @@ TEST(PhysRewriter, RequireRID1) { std::move(filterNode)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManagerRequireRID( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - true /*requireRID*/, {{{"c1", createScanDef({}, {})}}}, - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -5410,7 +5339,7 @@ TEST(PhysRewriter, UnionRewrite) { std::move(unionNode)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -5479,7 +5408,7 @@ TEST(PhysRewriter, JoinRewrite) { std::move(joinNode)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -5554,7 +5483,7 @@ TEST(PhysRewriter, JoinRewrite1) { std::move(joinNode)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -5609,7 +5538,7 @@ TEST(PhysRewriter, RootInterval) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"root"}}, std::move(filterNode)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -5663,7 +5592,7 @@ TEST(PhysRewriter, EqMemberSargable) { { PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase}, prefixId, {{{"c1", @@ -5711,7 +5640,7 @@ TEST(PhysRewriter, EqMemberSargable) { { PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -5811,7 +5740,7 @@ TEST(PhysRewriter, IndexSubfieldCovered) { std::move(filterNode3)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, @@ -5895,12 +5824,11 @@ TEST(PhysRewriter, PerfOnlyPreds1) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode2)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef( {}, @@ -5909,9 +5837,6 @@ TEST(PhysRewriter, PerfOnlyPreds1) { {"a", CollationOp::Ascending, false /*isMultiKey*/}}, false /*isMultiKey*/)}})}}}, std::make_unique<HintedCE>(std::move(hints)), - std::make_unique<DefaultCosting>(), - {} 
/*pathToInterval*/, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; @@ -5988,12 +5913,11 @@ TEST(PhysRewriter, PerfOnlyPreds2) { make<RootNode>(ProjectionRequirement{ProjectionNameVector{"pa"}}, std::move(filterNode2)); PrefixId prefixId; - OptPhaseManager phaseManager( + auto phaseManager = makePhaseManager( {OptPhase::MemoSubstitutionPhase, OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, - false /*requireRID*/, {{{"c1", createScanDef( {}, @@ -6001,15 +5925,12 @@ TEST(PhysRewriter, PerfOnlyPreds2) { {"index2", makeIndexDefinition("b", CollationOp::Ascending, false /*isMultiKey*/)}})}}}, std::make_unique<HintedCE>(std::move(hints)), - std::make_unique<DefaultCosting>(), - {} /*pathToInterval*/, - ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; phaseManager.getHints()._disableYieldingTolerantPlans = false; phaseManager.optimize(optimized); - ASSERT_BETWEEN(10, 15, phaseManager.getMemo().getStats()._physPlanExplorationCount); + ASSERT_BETWEEN(10, 17, phaseManager.getMemo().getStats()._physPlanExplorationCount); // Demonstrate an intersection plan, with predicates repeated on the Seek side. ASSERT_EXPLAIN_V2Compact( @@ -6042,16 +5963,16 @@ TEST(PhysRewriter, PerfOnlyPreds2) { "| Variable [rid_0]\n" "MergeJoin []\n" "| | | Condition\n" - "| | | rid_0 = rid_3\n" + "| | | rid_0 = rid_5\n" "| | Collation\n" "| | Ascending\n" "| Union []\n" "| | BindBlock:\n" - "| | [rid_3]\n" + "| | [rid_5]\n" "| | Source []\n" "| Evaluation []\n" "| | BindBlock:\n" - "| | [rid_3]\n" + "| | [rid_5]\n" "| | Variable [rid_0]\n" "| IndexScan [{'<rid>': rid_0}, scanDefName: c1, indexDefName: index2, interval: {[Const " "[2], Const [2]]}]\n" diff --git a/src/mongo/db/query/optimizer/props.cpp b/src/mongo/db/query/optimizer/props.cpp index 854882292be..c0079a78ed2 100644 --- a/src/mongo/db/query/optimizer/props.cpp +++ b/src/mongo/db/query/optimizer/props.cpp @@ -282,17 +282,20 @@ IndexingAvailability::IndexingAvailability(GroupIdType scanGroupId, ProjectionName scanProjection, std::string scanDefName, const bool eqPredsOnly, + const bool hasProperInterval, opt::unordered_set<std::string> satisfiedPartialIndexes) : _scanGroupId(scanGroupId), _scanProjection(std::move(scanProjection)), _scanDefName(std::move(scanDefName)), _eqPredsOnly(eqPredsOnly), - _satisfiedPartialIndexes(std::move(satisfiedPartialIndexes)) {} + _satisfiedPartialIndexes(std::move(satisfiedPartialIndexes)), + _hasProperInterval(hasProperInterval) {} bool IndexingAvailability::operator==(const IndexingAvailability& other) const { return _scanGroupId == other._scanGroupId && _scanProjection == other._scanProjection && _scanDefName == other._scanDefName && _eqPredsOnly == other._eqPredsOnly && - _satisfiedPartialIndexes == other._satisfiedPartialIndexes; + _satisfiedPartialIndexes == other._satisfiedPartialIndexes && + _hasProperInterval == other._hasProperInterval; } GroupIdType IndexingAvailability::getScanGroupId() const { @@ -327,6 +330,14 @@ void IndexingAvailability::setEqPredsOnly(const bool value) { _eqPredsOnly = value; } +bool IndexingAvailability::hasProperInterval() const { + return _hasProperInterval; +} + +void IndexingAvailability::setHasProperInterval(const bool hasProperInterval) { + _hasProperInterval = hasProperInterval; +} + CollectionAvailability::CollectionAvailability(opt::unordered_set<std::string> scanDefSet) : _scanDefSet(std::move(scanDefSet)) 
{} diff --git a/src/mongo/db/query/optimizer/props.h b/src/mongo/db/query/optimizer/props.h index adfb0f82847..e7ac16f227d 100644 --- a/src/mongo/db/query/optimizer/props.h +++ b/src/mongo/db/query/optimizer/props.h @@ -399,6 +399,7 @@ public: ProjectionName scanProjection, std::string scanDefName, bool eqPredsOnly, + bool hasProperInterval, opt::unordered_set<std::string> satisfiedPartialIndexes); bool operator==(const IndexingAvailability& other) const; @@ -415,6 +416,9 @@ public: bool getEqPredsOnly() const; void setEqPredsOnly(bool value); + bool hasProperInterval() const; + void setHasProperInterval(bool hasProperInterval); + private: GroupIdType _scanGroupId; const ProjectionName _scanProjection; @@ -427,6 +431,9 @@ private: // Set of indexes with partial indexes whose partial filters are satisfied for the current // group. opt::unordered_set<std::string> _satisfiedPartialIndexes; + + // True if there is at least one proper interval in a sargable node in this group. + bool _hasProperInterval; }; diff --git a/src/mongo/db/query/optimizer/utils/abt_hash.cpp b/src/mongo/db/query/optimizer/utils/abt_hash.cpp index 90585939882..b9f8d5a2517 100644 --- a/src/mongo/db/query/optimizer/utils/abt_hash.cpp +++ b/src/mongo/db/query/optimizer/utils/abt_hash.cpp @@ -195,8 +195,6 @@ public: size_t rightChildResult) { // Specifically always including children. return computeHashSeq<45>(std::hash<ProjectionName>()(node.getScanProjectionName()), - std::hash<bool>()(node.hasLeftIntervals()), - std::hash<bool>()(node.hasRightIntervals()), leftChildResult, rightChildResult); } diff --git a/src/mongo/db/query/optimizer/utils/unit_test_pipeline_utils.cpp b/src/mongo/db/query/optimizer/utils/unit_test_pipeline_utils.cpp index 446d975c16d..8c807453232 100644 --- a/src/mongo/db/query/optimizer/utils/unit_test_pipeline_utils.cpp +++ b/src/mongo/db/query/optimizer/utils/unit_test_pipeline_utils.cpp @@ -32,9 +32,9 @@ #include "mongo/db/pipeline/abt/document_source_visitor.h" #include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/optimizer/cascades/ce_heuristic.h" -#include "mongo/db/query/optimizer/cascades/cost_derivation.h" #include "mongo/db/query/optimizer/explain.h" #include "mongo/db/query/optimizer/rewrites/const_eval.h" +#include "mongo/db/query/optimizer/utils/unit_test_utils.h" #include "mongo/unittest/temp_dir.h" @@ -212,15 +212,7 @@ ABT optimizeABT(ABT abt, bool phaseManagerDisableScan) { PrefixId prefixId; - OptPhaseManager phaseManager(phaseSet, - prefixId, - false, - metadata, - std::make_unique<HeuristicCE>(), - std::make_unique<DefaultCosting>(), - pathToInterval, - ConstEval::constFold, - DebugInfo::kDefaultForTests); + auto phaseManager = makePhaseManager(phaseSet, prefixId, metadata, DebugInfo::kDefaultForTests); if (phaseManagerDisableScan) { phaseManager.getHints()._disableScan = true; } diff --git a/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp b/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp index 5ad7c6615c5..538a4b1f036 100644 --- a/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp +++ b/src/mongo/db/query/optimizer/utils/unit_test_utils.cpp @@ -29,30 +29,172 @@ #include "mongo/db/query/optimizer/utils/unit_test_utils.h" +#include <fstream> + +#include "mongo/db/pipeline/abt/utils.h" +#include "mongo/db/query/cost_model/cost_estimator.h" +#include "mongo/db/query/cost_model/cost_model_manager.h" +#include "mongo/db/query/optimizer/cascades/ce_heuristic.h" #include "mongo/db/query/optimizer/explain.h" #include 
"mongo/db/query/optimizer/metadata.h" #include "mongo/db/query/optimizer/node.h" -#include "mongo/unittest/unittest.h" +#include "mongo/db/query/optimizer/rewrites/const_eval.h" +#include "mongo/util/str_escape.h" namespace mongo::optimizer { static constexpr bool kDebugAsserts = false; +// DO NOT COMMIT WITH "TRUE". +static constexpr bool kAutoUpdateOnFailure = false; +static constexpr const char* kTempFileSuffix = ".tmp.txt"; + +// Map from file name to a list of updates. We keep track of how many lines are added or deleted at +// a particular line of a source file. +using LineDeltaVector = std::vector<std::pair<uint64_t, int64_t>>; +std::map<std::string, LineDeltaVector> gLineDeltaMap; + + void maybePrintABT(const ABT& abt) { // Always print using the supported versions to make sure we don't crash. const std::string strV1 = ExplainGenerator::explain(abt); const std::string strV2 = ExplainGenerator::explainV2(abt); const std::string strV2Compact = ExplainGenerator::explainV2Compact(abt); - auto [tag, val] = ExplainGenerator::explainBSON(abt); - sbe::value::ValueGuard vg(tag, val); + const std::string strBSON = ExplainGenerator::explainBSONStr(abt); if constexpr (kDebugAsserts) { std::cout << "V1: " << strV1 << "\n"; std::cout << "V2: " << strV2 << "\n"; std::cout << "V2Compact: " << strV2Compact << "\n"; - std::cout << "BSON: " << ExplainGenerator::printBSON(tag, val) << "\n"; + std::cout << "BSON: " << strBSON << "\n"; + } +} + +static std::vector<std::string> formatStr(const std::string& str) { + std::vector<std::string> replacementLines; + std::istringstream lineInput(str); + + // Account for maximum line length after linting. We need to indent, add quotes, etc. + static constexpr size_t kEscapedLength = 88; + + std::string line; + while (std::getline(lineInput, line)) { + // Read the string line by line and format it to match the test file's expected format. We + // have an initial indentation, followed by quotes and the escaped string itself. + + std::string escaped = mongo::str::escapeForJSON(line); + for (;;) { + // If the line is estimated to exceed the maximum length allowed by the linter, break it + // up and make sure to insert '\n' only at the end of the last segment. + const bool breakupLine = escaped.size() > kEscapedLength; + + std::ostringstream os; + os << " \"" << escaped.substr(0, kEscapedLength); + if (!breakupLine) { + os << "\\n"; + } + os << "\"\n"; + replacementLines.push_back(os.str()); + + if (breakupLine) { + escaped = escaped.substr(kEscapedLength); + } else { + break; + } + } + } + + if (!replacementLines.empty() && !replacementLines.back().empty()) { + // Account for the fact that we need an extra comma after the string constant in the macro. + auto& lastLine = replacementLines.back(); + lastLine.insert(lastLine.size() - 1, ","); + + if (replacementLines.size() == 1) { + // For single lines, add a 'nolint' comment to prevent the linter from inlining the + // single line with the macro itself. 
+ lastLine.insert(lastLine.size() - 1, " // NOLINT (test auto-update)"); + } + } + + return replacementLines; +} + +bool handleAutoUpdate(const std::string& expected, + const std::string& actual, + const std::string& fileName, + const size_t lineNumber) { + if (expected == actual) { + return true; + } + if constexpr (!kAutoUpdateOnFailure) { + std::cout << "Auto-updating is disabled.\n"; + return false; + } + + const auto expectedFormatted = formatStr(expected); + const auto actualFormatted = formatStr(actual); + + std::cout << "Updating expected result in file '" << fileName << "', line: " << lineNumber + << ".\n"; + std::cout << "Replacement:\n"; + for (const auto& line : actualFormatted) { + std::cout << line; + } + + // Compute the total number of lines added or removed before the current macro line. + auto& lineDeltas = gLineDeltaMap.emplace(fileName, LineDeltaVector{}).first->second; + int64_t totalDelta = 0; + for (const auto& [line, delta] : lineDeltas) { + if (line < lineNumber) { + totalDelta += delta; + } } + + const size_t replacementEndLine = lineNumber + totalDelta; + // Treat an empty string as needing one line. Adjust for line delta. + const size_t replacementStartLine = + replacementEndLine - (expectedFormatted.empty() ? 1 : expectedFormatted.size()); + + const std::string tempFileName = fileName + kTempFileSuffix; + std::string line; + size_t lineIndex = 0; + + try { + std::ifstream in; + in.open(fileName); + std::ofstream out; + out.open(tempFileName); + + // Generate a new test file, updated with the replacement string. + while (std::getline(in, line)) { + lineIndex++; + + if (lineIndex < replacementStartLine || lineIndex >= replacementEndLine) { + out << line << "\n"; + } else if (lineIndex == replacementStartLine) { + for (const auto& line1 : actualFormatted) { + out << line1; + } + } + } + + out.close(); + in.close(); + + std::rename(tempFileName.c_str(), fileName.c_str()); + } catch (const std::exception& ex) { + // Print and re-throw exception. + std::cout << "Caught an exception while manipulating files: " << ex.what(); + throw ex; + } + + // Add the current delta. + lineDeltas.emplace_back( + lineNumber, static_cast<int64_t>(actualFormatted.size()) - expectedFormatted.size()); + + // Do not assert in order to allow multiple tests to be updated. 
+ return true; } ABT makeIndexPath(FieldPathType fieldPath, bool isMultiKey) { @@ -96,4 +238,60 @@ IndexDefinition makeCompositeIndexDefinition(std::vector<TestIndexField> indexFi return IndexDefinition{std::move(idxCollSpec), isMultiKey}; } +std::unique_ptr<CostingInterface> makeCosting() { + return std::make_unique<cost_model::CostEstimator>( + cost_model::CostModelManager().getDefaultCoefficients()); +} + +OptPhaseManager makePhaseManager(OptPhaseManager::PhaseSet phaseSet, + PrefixId& prefixId, + Metadata metadata, + DebugInfo debugInfo, + QueryHints queryHints) { + return OptPhaseManager{std::move(phaseSet), + prefixId, + false /*requireRID*/, + std::move(metadata), + std::make_unique<HeuristicCE>(), + makeCosting(), + defaultConvertPathToInterval, + ConstEval::constFold, + std::move(debugInfo), + std::move(queryHints)}; +} + +OptPhaseManager makePhaseManager(OptPhaseManager::PhaseSet phaseSet, + PrefixId& prefixId, + Metadata metadata, + std::unique_ptr<CEInterface> ceDerivation, + DebugInfo debugInfo, + QueryHints queryHints) { + return OptPhaseManager{std::move(phaseSet), + prefixId, + false /*requireRID*/, + std::move(metadata), + std::move(ceDerivation), + makeCosting(), + defaultConvertPathToInterval, + ConstEval::constFold, + std::move(debugInfo), + std::move(queryHints)}; +} + +OptPhaseManager makePhaseManagerRequireRID(OptPhaseManager::PhaseSet phaseSet, + PrefixId& prefixId, + Metadata metadata, + DebugInfo debugInfo, + QueryHints queryHints) { + return OptPhaseManager{std::move(phaseSet), + prefixId, + true /*requireRID*/, + std::move(metadata), + std::make_unique<HeuristicCE>(), + makeCosting(), + defaultConvertPathToInterval, + ConstEval::constFold, + std::move(debugInfo), + std::move(queryHints)}; +} } // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/unit_test_utils.h b/src/mongo/db/query/optimizer/utils/unit_test_utils.h index b0ad4704b84..e358dc13495 100644 --- a/src/mongo/db/query/optimizer/utils/unit_test_utils.h +++ b/src/mongo/db/query/optimizer/utils/unit_test_utils.h @@ -31,6 +31,7 @@ #include "mongo/db/bson/dotted_path_support.h" #include "mongo/db/query/optimizer/defs.h" +#include "mongo/db/query/optimizer/opt_phase_manager.h" #include "mongo/db/query/optimizer/utils/utils.h" @@ -38,6 +39,11 @@ namespace mongo::optimizer { void maybePrintABT(const ABT& abt); +bool handleAutoUpdate(const std::string& expected, + const std::string& actual, + const std::string& fileName, + size_t lineNumber); + #define ASSERT_EXPLAIN(expected, abt) \ maybePrintABT(abt); \ ASSERT_EQ(expected, ExplainGenerator::explain(abt)) @@ -46,13 +52,36 @@ void maybePrintABT(const ABT& abt); maybePrintABT(abt); \ ASSERT_EQ(expected, ExplainGenerator::explainV2(abt)) +/** + * Auto update result back in the source file if the assert fails. + * The expected result must be a multi-line string in the following form: + * + * ASSERT_EXPLAIN_V2_AUTO( // NOLINT + * "BinaryOp [Add]\n" + * "| Const [2]\n" + * "Const [1]\n", + * tree); + * + * Limitations: + * 1. There should not be any comments or other formatting inside the multi-line string + * constant other than 'NOLINT'. If we have a single-line constant, the auto-updating will + * generate a 'NOLINT' at the end of the line. + * 2. The expression which we are explaining ('tree' in the example above) must fit on a single + * line. The macro should be indented by 4 spaces. + * + * TODO: SERVER-71004: Extend the usability of the auto-update macro. 
+ */ +#define ASSERT_EXPLAIN_V2_AUTO(expected, abt) \ + maybePrintABT(abt); \ + ASSERT(handleAutoUpdate(expected, ExplainGenerator::explainV2(abt), __FILE__, __LINE__)) + #define ASSERT_EXPLAIN_V2Compact(expected, abt) \ maybePrintABT(abt); \ ASSERT_EQ(expected, ExplainGenerator::explainV2Compact(abt)) #define ASSERT_EXPLAIN_BSON(expected, abt) \ maybePrintABT(abt); \ - ASSERT_EQ(expected, ExplainGenerator::explainBSON(abt)) + ASSERT_EQ(expected, ExplainGenerator::explainBSONStr(abt)) #define ASSERT_EXPLAIN_PROPS_V2(expected, phaseManager) \ ASSERT_EQ(expected, \ @@ -93,4 +122,37 @@ IndexDefinition makeIndexDefinition(FieldNameType fieldName, IndexDefinition makeCompositeIndexDefinition(std::vector<TestIndexField> indexFields, bool isMultiKey = true); +/** + * A convenience factory function to create costing. + */ +std::unique_ptr<CostingInterface> makeCosting(); + +/** + * A convenience factory function to create OptPhaseManager for unit tests. + */ +OptPhaseManager makePhaseManager(OptPhaseManager::PhaseSet phaseSet, + PrefixId& prefixId, + Metadata metadata, + DebugInfo debugInfo, + QueryHints queryHints = {}); + +/** + * A convenience factory function to create OptPhaseManager for unit tests with CE hints. + */ +OptPhaseManager makePhaseManager(OptPhaseManager::PhaseSet phaseSet, + PrefixId& prefixId, + Metadata metadata, + std::unique_ptr<CEInterface> ceDerivation, + DebugInfo debugInfo, + QueryHints queryHints = {}); + +/** + * A convenience factory function to create OptPhaseManager for unit tests which requires RID. + */ +OptPhaseManager makePhaseManagerRequireRID(OptPhaseManager::PhaseSet phaseSet, + PrefixId& prefixId, + Metadata metadata, + DebugInfo debugInfo, + QueryHints queryHints = {}); + } // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/utils.cpp b/src/mongo/db/query/optimizer/utils/utils.cpp index 6c2dac94c96..a676e28179b 100644 --- a/src/mongo/db/query/optimizer/utils/utils.cpp +++ b/src/mongo/db/query/optimizer/utils/utils.cpp @@ -102,6 +102,7 @@ properties::LogicalProps createInitialScanProps(const ProjectionName& projection projectionName, scanDefName, true /*eqPredsOnly*/, + false /*hasProperInterval*/, {} /*satisfiedPartialIndexes*/), properties::CollectionAvailability({scanDefName}), properties::DistributionAvailability(std::move(distributions))); @@ -1928,4 +1929,14 @@ bool pathEndsInTraverse(const optimizer::ABT& path) { return optimizer::algebra::transport<false>(path, t); } +bool hasProperIntervals(const PartialSchemaRequirements& reqMap) { + // Compute if this node has any proper (not fully open) intervals. 
+ for (const auto& [key, req] : reqMap) { + if (!isIntervalReqFullyOpenDNF(req.getIntervals())) { + return true; + } + } + return false; +} + } // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/utils/utils.h b/src/mongo/db/query/optimizer/utils/utils.h index e4a2b2367a6..48589995a14 100644 --- a/src/mongo/db/query/optimizer/utils/utils.h +++ b/src/mongo/db/query/optimizer/utils/utils.h @@ -347,4 +347,5 @@ ABT lowerIntervals(PrefixId& prefixId, */ bool pathEndsInTraverse(const optimizer::ABT& path); +bool hasProperIntervals(const PartialSchemaRequirements& reqMap); } // namespace mongo::optimizer diff --git a/src/mongo/db/query/parsed_distinct.cpp b/src/mongo/db/query/parsed_distinct.cpp index 21781211771..03f2e3bf9e9 100644 --- a/src/mongo/db/query/parsed_distinct.cpp +++ b/src/mongo/db/query/parsed_distinct.cpp @@ -322,7 +322,8 @@ StatusWith<ParsedDistinct> ParsedDistinct::parse(OperationContext* opCtx, return ParsedDistinct(std::move(cq.getValue()), parsedDistinct.getKey().toString(), - parsedDistinct.getMirrored().value_or(false)); + parsedDistinct.getMirrored().value_or(false), + parsedDistinct.getSampleId()); } } // namespace mongo diff --git a/src/mongo/db/query/parsed_distinct.h b/src/mongo/db/query/parsed_distinct.h index b5cabef4680..49809d98440 100644 --- a/src/mongo/db/query/parsed_distinct.h +++ b/src/mongo/db/query/parsed_distinct.h @@ -55,8 +55,12 @@ public: ParsedDistinct(std::unique_ptr<CanonicalQuery> query, const std::string key, - const bool mirrored = false) - : _query(std::move(query)), _key(std::move(key)), _mirrored(std::move(mirrored)) {} + const bool mirrored = false, + const boost::optional<UUID> sampleId = boost::none) + : _query(std::move(query)), + _key(std::move(key)), + _mirrored(std::move(mirrored)), + _sampleId(std::move(sampleId)) {} const CanonicalQuery* getQuery() const { return _query.get(); @@ -74,6 +78,10 @@ public: return _key; } + boost::optional<UUID> getSampleId() const { + return _sampleId; + } + bool isMirrored() const { return _mirrored; } @@ -102,6 +110,9 @@ private: // Indicates that this was a mirrored operation. bool _mirrored = false; + + // The unique sample id for this operation if it has been chosen for sampling. + boost::optional<UUID> _sampleId; }; } // namespace mongo diff --git a/src/mongo/db/query/plan_cache_key_factory.cpp b/src/mongo/db/query/plan_cache_key_factory.cpp index 88d8f8a3665..c3a01418819 100644 --- a/src/mongo/db/query/plan_cache_key_factory.cpp +++ b/src/mongo/db/query/plan_cache_key_factory.cpp @@ -106,7 +106,7 @@ boost::optional<Timestamp> computeNewestVisibleIndexTimestamp(OperationContext* Timestamp currentNewestVisible = Timestamp::min(); auto ii = collection->getIndexCatalog()->getIndexIterator( - opCtx, IndexCatalog::InclusionPolicy::kReady | IndexCatalog::InclusionPolicy::kUnfinished); + opCtx, IndexCatalog::InclusionPolicy::kReady); while (ii->more()) { const IndexCatalogEntry* ice = ii->next(); auto minVisibleSnapshot = ice->getMinimumVisibleSnapshot(); diff --git a/src/mongo/db/query/query_knobs.idl b/src/mongo/db/query/query_knobs.idl index fb8d93b9027..869ecbd1695 100644 --- a/src/mongo/db/query/query_knobs.idl +++ b/src/mongo/db/query/query_knobs.idl @@ -963,6 +963,15 @@ server_parameters: validator: callback: telemetry_util::validateTelemetryCacheSize + internalQueryDisableExclusionProjectionFastPath: + description: "If true, then ExclusionProjectionExecutor won't use fast path implementation. 
This + is needed to pass generational fuzzers that are sensitive to field order and other corner cases + when switching from Document to BSONObj." + set_at: [ startup ] + cpp_varname: "internalQueryDisableExclusionProjectionFastPath" + cpp_vartype: bool + default: false + test_only: true # Note for adding additional query knobs: # diff --git a/src/mongo/db/query/query_planner.cpp b/src/mongo/db/query/query_planner.cpp index 6902650a1ef..a3a63e66449 100644 --- a/src/mongo/db/query/query_planner.cpp +++ b/src/mongo/db/query/query_planner.cpp @@ -399,6 +399,11 @@ StatusWith<std::unique_ptr<QuerySolution>> tryToBuildColumnScan( } return Status{ErrorCodes::Error{6298502}, "columnstore index is not applicable for this query"}; } + +bool collscanIsBounded(const CollectionScanNode* collscan) { + return collscan->minRecord || collscan->maxRecord; +} + } // namespace using std::numeric_limits; @@ -615,13 +620,23 @@ static BSONObj finishMaxObj(const IndexEntry& indexEntry, } } +std::pair<std::unique_ptr<QuerySolution>, const CollectionScanNode*> buildCollscanSolnWithNode( + const CanonicalQuery& query, + bool tailable, + const QueryPlannerParams& params, + int direction = 1) { + std::unique_ptr<QuerySolutionNode> solnRoot( + QueryPlannerAccess::makeCollectionScan(query, tailable, params, direction)); + const auto* collscanNode = checked_cast<const CollectionScanNode*>(solnRoot.get()); + return std::make_pair( + QueryPlannerAnalysis::analyzeDataAccess(query, params, std::move(solnRoot)), collscanNode); +} + std::unique_ptr<QuerySolution> buildCollscanSoln(const CanonicalQuery& query, bool tailable, const QueryPlannerParams& params, int direction = 1) { - std::unique_ptr<QuerySolutionNode> solnRoot( - QueryPlannerAccess::makeCollectionScan(query, tailable, params, direction)); - return QueryPlannerAnalysis::analyzeDataAccess(query, params, std::move(solnRoot)); + return buildCollscanSolnWithNode(query, tailable, params, direction).first; } std::unique_ptr<QuerySolution> buildWholeIXSoln( @@ -1518,6 +1533,8 @@ StatusWith<std::vector<std::unique_ptr<QuerySolution>>> QueryPlanner::plan( "No indexed plans available, and running with 'notablescan'"); } + bool clusteredCollection = params.clusteredInfo.has_value(); + // geoNear and text queries *require* an index. // Also, if a hint is specified it indicates that we MUST use it. 
bool possibleToCollscan = @@ -1527,21 +1544,23 @@ StatusWith<std::vector<std::unique_ptr<QuerySolution>>> QueryPlanner::plan( return Status(ErrorCodes::NoQueryExecutionPlans, "No query solutions"); } - if (possibleToCollscan && (collscanRequested || collScanRequired)) { - auto collscan = buildCollscanSoln(query, isTailable, params); - if (!collscan && collScanRequired) { + if (possibleToCollscan && (collscanRequested || collScanRequired || clusteredCollection)) { + auto [collscanSoln, collscanNode] = buildCollscanSolnWithNode(query, isTailable, params); + if (!collscanSoln && collScanRequired) { return Status(ErrorCodes::NoQueryExecutionPlans, "Failed to build collection scan soln"); } - if (collscan) { + + if (collscanSoln && + (collscanRequested || collScanRequired || collscanIsBounded(collscanNode))) { LOGV2_DEBUG(20984, 5, "Planner: outputting a collection scan", - "collectionScan"_attr = redact(collscan->toString())); + "collectionScan"_attr = redact(collscanSoln->toString())); SolutionCacheData* scd = new SolutionCacheData(); scd->solnType = SolutionCacheData::COLLSCAN_SOLN; - collscan->cacheData.reset(scd); - out.push_back(std::move(collscan)); + collscanSoln->cacheData.reset(scd); + out.push_back(std::move(collscanSoln)); } } diff --git a/src/mongo/db/query/query_solution.h b/src/mongo/db/query/query_solution.h index a24cff08f2f..d33794340fe 100644 --- a/src/mongo/db/query/query_solution.h +++ b/src/mongo/db/query/query_solution.h @@ -1396,15 +1396,15 @@ struct GroupNode : public QuerySolutionNode { shouldProduceBson(shouldProduceBson) { // Use the DepsTracker to extract the fields that the 'groupByExpression' and accumulator // expressions depend on. - for (auto& groupByExprField : expression::getDependencies(groupByExpression.get()).fields) { - requiredFields.insert(groupByExprField); - } + DepsTracker deps; + expression::addDependencies(groupByExpression.get(), &deps); for (auto&& acc : accumulators) { - auto argExpr = acc.expr.argument; - for (auto& argExprField : expression::getDependencies(argExpr.get()).fields) { - requiredFields.insert(argExprField); - } + expression::addDependencies(acc.expr.argument.get(), &deps); } + + requiredFields = deps.fields; + needWholeDocument = deps.needWholeDocument; + needsAnyMetadata = deps.getNeedsAnyMetadata(); } StageType getType() const override { @@ -1438,7 +1438,9 @@ struct GroupNode : public QuerySolutionNode { // Carries the fields this GroupNode depends on. Namely, 'requiredFields' contains the union of // the fields in the 'groupByExpressions' and the fields in the input Expressions of the // 'accumulators'. - StringSet requiredFields; + OrderedPathSet requiredFields; + bool needWholeDocument = false; + bool needsAnyMetadata = false; // If set to true, generated SBE plan will produce result as BSON object. If false, // 'sbe::Object' is produced instead. 
diff --git a/src/mongo/db/query/sbe_shard_filter_test.cpp b/src/mongo/db/query/sbe_shard_filter_test.cpp index 82585905f6e..a5d49bf4849 100644 --- a/src/mongo/db/query/sbe_shard_filter_test.cpp +++ b/src/mongo/db/query/sbe_shard_filter_test.cpp @@ -234,7 +234,7 @@ TEST_F(SbeShardFilterTest, MissingFieldsAtBottomDottedPathFilledCorrectly) { TEST_F(SbeShardFilterTest, CoveredShardFilterPlan) { auto indexKeyPattern = BSON("a" << 1 << "b" << 1 << "c" << 1 << "d" << 1); - auto projection = BSON("a" << 1 << "c" << 1); + auto projection = BSON("a" << 1 << "c" << 1 << "_id" << 0); auto mockedIndexKeys = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 2 << "b" << 2 << "c" << 2 << "d" << 2)), BSON_ARRAY(BSON("a" << 3 << "b" << 3 << "c" << 3 << "d" << 3))}; diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp index ce8344a0d36..a196821e3dc 100644 --- a/src/mongo/db/query/sbe_stage_builder.cpp +++ b/src/mongo/db/query/sbe_stage_builder.cpp @@ -84,129 +84,9 @@ #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kQuery - namespace mongo::stage_builder { namespace { /** - * For covered projections, each of the projection field paths represent respective index key. To - * rehydrate index keys into the result object, we first need to convert projection AST into - * 'IndexKeyPatternTreeNode' structure. Context structure and visitors below are used for this - * purpose. - */ -struct IndexKeysBuilderContext { - // Contains resulting tree of index keys converted from projection AST. - IndexKeyPatternTreeNode root; - - // Full field path of the currently visited projection node. - std::vector<StringData> currentFieldPath; - - // Each projection node has a vector of field names. This stack contains indexes of the - // currently visited field names for each of the projection nodes. - std::vector<size_t> currentFieldIndex; -}; - -/** - * Covered projections are always inclusion-only, so we ban all other operators. 
- */ -class IndexKeysBuilder : public projection_ast::ProjectionASTConstVisitor { -public: - using projection_ast::ProjectionASTConstVisitor::visit; - - IndexKeysBuilder(IndexKeysBuilderContext* context) : _context{context} {} - - void visit(const projection_ast::ProjectionPositionalASTNode* node) final { - tasserted(5474501, "Positional projection is not allowed in covered projection"); - } - - void visit(const projection_ast::ProjectionSliceASTNode* node) final { - tasserted(5474502, "$slice is not allowed in covered projection"); - } - - void visit(const projection_ast::ProjectionElemMatchASTNode* node) final { - tasserted(5474503, "$elemMatch is not allowed in covered projection"); - } - - void visit(const projection_ast::ExpressionASTNode* node) final { - tasserted(5474504, "Expressions are not allowed in covered projection"); - } - - void visit(const projection_ast::MatchExpressionASTNode* node) final { - tasserted(5474505, - "$elemMatch and positional projection are not allowed in covered projection"); - } - - void visit(const projection_ast::BooleanConstantASTNode* node) override {} - -protected: - IndexKeysBuilderContext* _context; -}; - -class IndexKeysPreBuilder final : public IndexKeysBuilder { -public: - using IndexKeysBuilder::IndexKeysBuilder; - using IndexKeysBuilder::visit; - - void visit(const projection_ast::ProjectionPathASTNode* node) final { - _context->currentFieldIndex.push_back(0); - _context->currentFieldPath.emplace_back(node->fieldNames().front()); - } -}; - -class IndexKeysInBuilder final : public IndexKeysBuilder { -public: - using IndexKeysBuilder::IndexKeysBuilder; - using IndexKeysBuilder::visit; - - void visit(const projection_ast::ProjectionPathASTNode* node) final { - auto& currentIndex = _context->currentFieldIndex.back(); - currentIndex++; - _context->currentFieldPath.back() = node->fieldNames()[currentIndex]; - } -}; - -class IndexKeysPostBuilder final : public IndexKeysBuilder { -public: - using IndexKeysBuilder::IndexKeysBuilder; - using IndexKeysBuilder::visit; - - void visit(const projection_ast::ProjectionPathASTNode* node) final { - _context->currentFieldIndex.pop_back(); - _context->currentFieldPath.pop_back(); - } - - void visit(const projection_ast::BooleanConstantASTNode* constantNode) final { - if (!constantNode->value()) { - // Even though only inclusion is allowed in covered projection, there still can be - // {_id: 0} component. We do not need to generate any nodes for it. - return; - } - - // Insert current field path into the index keys tree if it does not exist yet. - auto* node = &_context->root; - for (const auto& part : _context->currentFieldPath) { - if (auto it = node->children.find(part); it != node->children.end()) { - node = it->second.get(); - } else { - node = node->emplace(part); - } - } - } -}; - -sbe::value::SlotVector getSlotsToForward(const PlanStageReqs& reqs, const PlanStageSlots& outputs) { - std::vector<std::pair<StringData, sbe::value::SlotId>> pairs; - outputs.forEachSlot( - reqs, [&](auto&& slot, const StringData& name) { pairs.emplace_back(name, slot); }); - std::sort(pairs.begin(), pairs.end()); - - auto outputSlots = sbe::makeSV(); - for (auto&& p : pairs) { - outputSlots.emplace_back(p.second); - } - return outputSlots; -} - -/** * Generates an EOF plan. Note that even though this plan will return nothing, it will still define * the slots specified by 'reqs'. 
*/ @@ -271,6 +151,26 @@ std::unique_ptr<sbe::RuntimeEnvironment> makeRuntimeEnvironment( return env; } +sbe::value::SlotVector getSlotsToForward(const PlanStageReqs& reqs, + const PlanStageSlots& outputs, + const sbe::value::SlotVector& exclude) { + auto excludeSet = sbe::value::SlotSet{exclude.begin(), exclude.end()}; + + std::vector<std::pair<PlanStageSlots::Name, sbe::value::SlotId>> pairs; + outputs.forEachSlot(reqs, [&](auto&& slot, const PlanStageSlots::Name& name) { + if (!excludeSet.count(slot)) { + pairs.emplace_back(name, slot); + } + }); + std::sort(pairs.begin(), pairs.end()); + + auto outputSlots = sbe::makeSV(); + for (auto&& p : pairs) { + outputSlots.emplace_back(p.second); + } + return outputSlots; +} + void prepareSlotBasedExecutableTree(OperationContext* opCtx, sbe::PlanStage* root, PlanStageData* data, @@ -524,22 +424,22 @@ std::unique_ptr<sbe::PlanStage> SlotBasedStageBuilder::build(const QuerySolution std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildCollScan( const QuerySolutionNode* root, const PlanStageReqs& reqs) { - invariant(!reqs.getIndexKeyBitset()); + tassert(6023400, "buildCollScan() does not support kKey", !reqs.hasKeys()); + auto fields = reqs.getFields(); auto csn = static_cast<const CollectionScanNode*>(root); auto [stage, outputs] = generateCollScan(_state, getCurrentCollection(reqs), csn, + fields, _yieldPolicy, reqs.getIsTailableCollScanResumeBranch()); if (reqs.has(kReturnKey)) { // Assign the 'returnKeySlot' to be the empty object. outputs.set(kReturnKey, _slotIdGenerator.generate()); - stage = sbe::makeProjectStage(std::move(stage), - root->nodeId(), - outputs.get(kReturnKey), - sbe::makeE<sbe::EFunction>("newObj", sbe::makeEs())); + stage = sbe::makeProjectStage( + std::move(stage), root->nodeId(), outputs.get(kReturnKey), makeFunction("newObj"_sd)); } // Don't advertize the RecordId output if none of our ancestors are going to use it. if (!reqs.has(kRecordId)) { @@ -552,12 +452,20 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildVirtualScan( const QuerySolutionNode* root, const PlanStageReqs& reqs) { using namespace std::literals; + auto vsn = static_cast<const VirtualScanNode*>(root); - // The caller should only have requested components of the index key if the virtual scan is - // mocking an index scan. - if (vsn->scanType == VirtualScanNode::ScanType::kCollScan) { - invariant(!reqs.getIndexKeyBitset()); - } + auto reqKeys = reqs.getKeys(); + + // The caller should only request kKey slots if the virtual scan is mocking an index scan. + tassert(6023401, + "buildVirtualScan() does not support kKey when 'scanType' is not ixscan", + vsn->scanType == VirtualScanNode::ScanType::kIxscan || reqKeys.empty()); + + tassert(6023423, + "buildVirtualScan() does not support dotted paths for kKey slots", + std::all_of(reqKeys.begin(), reqKeys.end(), [](auto&& s) { + return s.find('.') == std::string::npos; + })); auto [inputTag, inputVal] = sbe::value::makeNewArray(); sbe::value::ValueGuard inputGuard{inputTag, inputVal}; @@ -572,7 +480,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder } inputGuard.reset(); - auto [scanSlots, stage] = + auto [scanSlots, scanStage] = generateVirtualScanMulti(&_slotIdGenerator, vsn->hasRecordId ? 
2 : 1, inputTag, inputVal); sbe::value::SlotId resultSlot; @@ -586,41 +494,27 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder PlanStageSlots outputs; - if (reqs.has(kResult)) { + if (reqs.has(kResult) || reqs.hasFields()) { outputs.set(kResult, resultSlot); - } else if (reqs.getIndexKeyBitset()) { - // The caller wanted individual slots for certain components of a mock index scan. Use a - // project stage to produce those slots. Since the test will represent index keys as BSON - // objects, we use 'getField' expressions to extract the necessary fields. - invariant(!vsn->indexKeyPattern.isEmpty()); - - sbe::value::SlotVector indexKeySlots; - sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projections; - - size_t indexKeyPos = 0; - for (auto&& field : vsn->indexKeyPattern) { - if (reqs.getIndexKeyBitset()->test(indexKeyPos)) { - indexKeySlots.push_back(_slotIdGenerator.generate()); - projections.emplace(indexKeySlots.back(), - makeFunction("getField"_sd, - sbe::makeE<sbe::EVariable>(resultSlot), - makeConstant(field.fieldName()))); - } - ++indexKeyPos; - } - - stage = - sbe::makeS<sbe::ProjectStage>(std::move(stage), std::move(projections), root->nodeId()); - - outputs.setIndexKeySlots(indexKeySlots); } - if (reqs.has(kRecordId)) { invariant(vsn->hasRecordId); invariant(scanSlots.size() == 2); outputs.set(kRecordId, scanSlots[0]); } + auto stage = std::move(scanStage); + + // The caller wants individual slots for certain components of the mock index scan. Retrieve + // the values for these paths and project them to slots. + auto [projectStage, slots] = projectTopLevelFields( + std::move(stage), reqKeys, resultSlot, root->nodeId(), &_slotIdGenerator); + stage = std::move(projectStage); + + for (size_t i = 0; i < reqKeys.size(); ++i) { + outputs.set(std::make_pair(PlanStageSlots::kKey, std::move(reqKeys[i])), slots[i]); + } + return {std::move(stage), std::move(outputs)}; } @@ -629,17 +523,35 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder auto ixn = static_cast<const IndexScanNode*>(root); invariant(reqs.has(kReturnKey) || !ixn->addKeyMetadata); - sbe::IndexKeysInclusionSet indexKeyBitset; + auto reqKeys = reqs.getKeys(); + auto reqKeysSet = StringDataSet{reqKeys.begin(), reqKeys.end()}; - if (reqs.has(PlanStageSlots::kReturnKey) || reqs.has(PlanStageSlots::kResult)) { - // If either 'reqs.result' or 'reqs.returnKey' is true, we need to get all parts of the - // index key (regardless of what was requested by 'reqs.indexKeyBitset') so that we can - // create the inflated index key (keyExpr). - for (int i = 0; i < ixn->index.keyPattern.nFields(); ++i) { - indexKeyBitset.set(i); + std::vector<StringData> keys; + sbe::IndexKeysInclusionSet keysBitset; + StringDataSet indexKeyPatternSet; + size_t i = 0; + for (const auto& elt : ixn->index.keyPattern) { + StringData name = elt.fieldNameStringData(); + indexKeyPatternSet.emplace(name); + if (reqKeysSet.count(name)) { + keysBitset.set(i); + keys.emplace_back(name); } - } else if (reqs.getIndexKeyBitset()) { - indexKeyBitset = *reqs.getIndexKeyBitset(); + ++i; + } + + auto additionalKeys = + filterVector(reqKeys, [&](const std::string& s) { return !indexKeyPatternSet.count(s); }); + + sbe::IndexKeysInclusionSet indexKeyBitset; + if (reqs.has(kReturnKey) || reqs.has(kResult) || reqs.hasFields()) { + // If either 'reqs.result' or 'reqs.returnKey' or 'reqs.hasFields()' is true, we need to + // get all parts of the index key so that we can create the inflated index key. 
+ for (int j = 0; j < ixn->index.keyPattern.nFields(); ++j) { + indexKeyBitset.set(j); + } + } else { + indexKeyBitset = keysBitset; } // If the slots necessary for performing an index consistency check were not requested in @@ -652,28 +564,32 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder const auto generateIndexScanFunc = ixn->iets.empty() ? generateIndexScan : generateIndexScanWithDynamicBounds; - auto&& [stage, outputs] = generateIndexScanFunc(_state, - getCurrentCollection(reqs), - ixn, - indexKeyBitset, - _yieldPolicy, - iamMap, - reqs.has(kIndexKeyPattern)); + auto&& [scanStage, scanOutputs] = generateIndexScanFunc(_state, + getCurrentCollection(reqs), + ixn, + indexKeyBitset, + _yieldPolicy, + iamMap, + reqs.has(kIndexKeyPattern)); + + auto stage = std::move(scanStage); + auto outputs = std::move(scanOutputs); // Remove the RecordId from the output if we were not requested to produce it. if (!reqs.has(PlanStageSlots::kRecordId) && outputs.has(kRecordId)) { outputs.clear(kRecordId); } - if (reqs.has(PlanStageSlots::kReturnKey)) { - sbe::EExpression::Vector mkObjArgs; - size_t i = 0; + if (reqs.has(PlanStageSlots::kReturnKey)) { + sbe::EExpression::Vector args; for (auto&& elem : ixn->index.keyPattern) { - mkObjArgs.emplace_back(sbe::makeE<sbe::EConstant>(elem.fieldNameStringData())); - mkObjArgs.emplace_back(sbe::makeE<sbe::EVariable>((*outputs.getIndexKeySlots())[i++])); + StringData name = elem.fieldNameStringData(); + args.emplace_back(sbe::makeE<sbe::EConstant>(name)); + args.emplace_back( + makeVariable(outputs.get(std::make_pair(PlanStageSlots::kKey, name)))); } - auto rawKeyExpr = sbe::makeE<sbe::EFunction>("newObj", std::move(mkObjArgs)); + auto rawKeyExpr = sbe::makeE<sbe::EFunction>("newObj"_sd, std::move(args)); outputs.set(PlanStageSlots::kReturnKey, _slotIdGenerator.generate()); stage = sbe::makeProjectStage(std::move(stage), ixn->nodeId(), @@ -681,25 +597,30 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder std::move(rawKeyExpr)); } - if (reqs.has(PlanStageSlots::kResult)) { - outputs.set(PlanStageSlots::kResult, _slotIdGenerator.generate()); - stage = rehydrateIndexKey(std::move(stage), - ixn->index.keyPattern, - ixn->nodeId(), - *outputs.getIndexKeySlots(), - outputs.get(PlanStageSlots::kResult)); + if (reqs.has(kResult) || reqs.hasFields()) { + auto indexKeySlots = sbe::makeSV(); + for (auto&& elem : ixn->index.keyPattern) { + StringData name = elem.fieldNameStringData(); + indexKeySlots.emplace_back(outputs.get(std::make_pair(PlanStageSlots::kKey, name))); + } + + auto resultSlot = _slotIdGenerator.generate(); + outputs.set(kResult, resultSlot); + + stage = rehydrateIndexKey( + std::move(stage), ixn->index.keyPattern, ixn->nodeId(), indexKeySlots, resultSlot); } - if (reqs.getIndexKeyBitset()) { - outputs.setIndexKeySlots( - makeIndexKeyOutputSlotsMatchingParentReqs(ixn->index.keyPattern, - *reqs.getIndexKeyBitset(), - indexKeyBitset, - *outputs.getIndexKeySlots())); - } else { - outputs.setIndexKeySlots(boost::none); + auto [outStage, nothingSlots] = projectNothingToSlots( + std::move(stage), additionalKeys.size(), root->nodeId(), &_slotIdGenerator); + stage = std::move(outStage); + for (size_t i = 0; i < additionalKeys.size(); ++i) { + outputs.set(std::make_pair(PlanStageSlots::kKey, std::move(additionalKeys[i])), + nothingSlots[i]); } + outputs.clearNonRequiredSlots(reqs); + return {std::move(stage), std::move(outputs)}; } @@ -838,9 +759,7 @@ std::unique_ptr<sbe::EExpression> 
generatePerColumnFilterExpr(StageBuilderState& std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildColumnScan( const QuerySolutionNode* root, const PlanStageReqs& reqs) { - tassert(6312404, - "Unexpected index key bitset provided for column scan stage", - !reqs.getIndexKeyBitset()); + tassert(6023403, "buildColumnScan() does not support kKey", !reqs.hasKeys()); auto csn = static_cast<const ColumnIndexScanNode*>(root); tassert(6312405, @@ -976,19 +895,22 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder const QuerySolutionNode* root, const PlanStageReqs& reqs) { auto fn = static_cast<const FetchNode*>(root); - // The child must produce all of the slots required by the parent of this FetchNode, except for - // 'resultSlot' which will be produced by the call to makeLoopJoinForFetch() below. In addition - // to that, the child must always produce a 'recordIdSlot' because it's needed for the call to - // makeLoopJoinForFetch() below. + // The child must produce a kRecordId slot, as well as all the kMeta and kKey slots required + // by the parent of this FetchNode except for 'resultSlot'. Note that the child does _not_ + // need to produce any kField slots. Any kField requests by the parent will be handled by the + // logic below. + auto child = fn->children[0].get(); + auto childReqs = reqs.copy() .clear(kResult) + .clearAllFields() .set(kRecordId) .set(kSnapshotId) .set(kIndexId) .set(kIndexKey) .set(kIndexKeyPattern); - auto [stage, outputs] = build(fn->children[0].get(), childReqs); + auto [stage, outputs] = build(child, childReqs); auto iamMap = _data.iamMap; uassert(4822880, "RecordId slot is not defined", outputs.has(kRecordId)); @@ -999,44 +921,48 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder uassert(5290711, "Index key slot is not defined", outputs.has(kIndexKey)); uassert(5113713, "Index key pattern slot is not defined", outputs.has(kIndexKeyPattern)); - auto forwardingReqs = reqs.copy().clear(kResult).clear(kRecordId); + auto forwardingReqs = reqs.copy().clear(kResult).clear(kRecordId).clearAllFields(); auto relevantSlots = getSlotsToForward(forwardingReqs, outputs); - // Forward slots for components of the index key if our parent requested them. 
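Editorial note on the slot-forwarding helper used throughout this hunk: buildFetch() forwards the child's remaining outputs with getSlotsToForward(forwardingReqs, outputs), and the three-argument overload added near the top of this file additionally drops slots the caller will consume itself (for example the orderBy slots in the sort builders). A standalone sketch of that filtering, using simplified stand-in types for the slot registry rather than the real PlanStageSlots/SlotVector classes, might look like this:

#include <algorithm>
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

using SlotId = int64_t;
// Hypothetical stand-in for PlanStageSlots: slots addressed by a (kind, name) pair.
using SlotName = std::pair<int, std::string>;
using NamedSlots = std::map<SlotName, SlotId>;

// Mirrors the shape of the three-argument getSlotsToForward(): return every named slot
// except those listed in 'exclude', ordered deterministically by slot name.
std::vector<SlotId> slotsToForward(const NamedSlots& outputs, const std::vector<SlotId>& exclude) {
    std::set<SlotId> excludeSet(exclude.begin(), exclude.end());
    std::vector<std::pair<SlotName, SlotId>> pairs;
    for (const auto& [name, slot] : outputs) {
        if (!excludeSet.count(slot)) {
            pairs.emplace_back(name, slot);
        }
    }
    std::sort(pairs.begin(), pairs.end());  // name-based order keeps the plan shape deterministic
    std::vector<SlotId> result;
    for (const auto& p : pairs) {
        result.push_back(p.second);
    }
    return result;
}

The sort by name (rather than by slot id) is what makes the forwarded slot order independent of slot allocation order.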
- if (auto indexKeySlots = outputs.getIndexKeySlots()) { - relevantSlots.insert(relevantSlots.end(), indexKeySlots->begin(), indexKeySlots->end()); - } - - sbe::value::SlotId fetchResultSlot, fetchRecordIdSlot; - std::tie(fetchResultSlot, fetchRecordIdSlot, stage) = - makeLoopJoinForFetch(std::move(stage), - outputs.get(kRecordId), - outputs.get(kSnapshotId), - outputs.get(kIndexId), - outputs.get(kIndexKey), - outputs.get(kIndexKeyPattern), - getCurrentCollection(reqs), - std::move(iamMap), - root->nodeId(), - std::move(relevantSlots), - _slotIdGenerator); + auto fields = reqs.getFields(); + + auto childRecordId = outputs.get(kRecordId); + auto fetchResultSlot = _slotIdGenerator.generate(); + auto fetchRecordIdSlot = _slotIdGenerator.generate(); + auto fieldSlots = _slotIdGenerator.generateMultiple(fields.size()); + + stage = makeLoopJoinForFetch(std::move(stage), + fetchResultSlot, + fetchRecordIdSlot, + fields, + fieldSlots, + childRecordId, + outputs.get(kSnapshotId), + outputs.get(kIndexId), + outputs.get(kIndexKey), + outputs.get(kIndexKeyPattern), + getCurrentCollection(reqs), + std::move(iamMap), + root->nodeId(), + std::move(relevantSlots)); outputs.set(kResult, fetchResultSlot); - // Propagate the RecordId output only if requested. + + // Only propagate kRecordId if requested. if (reqs.has(kRecordId)) { outputs.set(kRecordId, fetchRecordIdSlot); } else { outputs.clear(kRecordId); } + for (size_t i = 0; i < fields.size(); ++i) { + outputs.set(std::make_pair(PlanStageSlots::kField, fields[i]), fieldSlots[i]); + } + if (fn->filter) { forwardingReqs = reqs.copy().set(kResult); - auto relevantSlots = getSlotsToForward(forwardingReqs, outputs); - // Forward slots for components of the index key if our parent requested them. - if (auto indexKeySlots = outputs.getIndexKeySlots()) { - relevantSlots.insert(relevantSlots.end(), indexKeySlots->begin(), indexKeySlots->end()); - } + auto relevantSlots = getSlotsToForward(forwardingReqs, outputs); auto [_, outputStage] = generateFilter(_state, fn->filter.get(), @@ -1046,6 +972,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder stage = outputStage.extractStage(root->nodeId()); } + outputs.clearNonRequiredSlots(reqs); + return {std::move(stage), std::move(outputs)}; } @@ -1197,22 +1125,12 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder sortPattern.size() > 0); auto child = sn->children[0].get(); - if (reqs.getIndexKeyBitset().has_value() || - (!reqs.has(kResult) && child->getType() == STAGE_IXSCAN)) { - // We decide to IndexKeySlots when building the child if: - // 1) The query is covered; or - // 2) This is a SORT->IXSCAN and kResult wasn't requested. + + if (auto [ixn, ct] = getFirstNodeByType(root, STAGE_IXSCAN); + !sn->fetched() && !reqs.has(kResult) && ixn && ct == 1) { return buildSortCovered(root, reqs); } - // The child must produce the kResult slot as well as all slots required by the parent of - // this SortNode. 
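Editorial note on the pattern the rewritten builders share: the child advertises each output under a typed name (kField for top-level fields of the result object, kKey for index key components, plus fixed names such as kResult and kRecordId), and the parent trims anything it did not request with clearNonRequiredSlots(reqs). A minimal model of that registry, using invented simplified types instead of the real PlanStageSlots/PlanStageReqs classes, could be:

#include <cstdint>
#include <iterator>
#include <map>
#include <set>
#include <string>
#include <utility>

using SlotId = int64_t;

// Hypothetical, simplified stand-ins for PlanStageSlots / PlanStageReqs.
enum SlotKind { kMeta, kField, kKey };
using Name = std::pair<SlotKind, std::string>;

struct Reqs {
    std::set<Name> wanted;
    bool has(const Name& n) const { return wanted.count(n) != 0; }
};

struct Slots {
    std::map<Name, SlotId> slots;
    void set(const Name& n, SlotId s) { slots[n] = s; }
    SlotId get(const Name& n) const { return slots.at(n); }
    bool has(const Name& n) const { return slots.count(n) != 0; }
    // Drop every slot the parent did not request, so unused outputs do not leak upward.
    void clearNonRequiredSlots(const Reqs& reqs) {
        for (auto it = slots.begin(); it != slots.end();) {
            it = reqs.has(it->first) ? std::next(it) : slots.erase(it);
        }
    }
};

Compared with the removed index-key bitsets, addressing slots by name lets a parent ask for "the slot holding key 'a'" without knowing the child's key-pattern positions.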
- auto childReqs = reqs.copy().set(kResult); - - auto [stage, outputs] = build(child, childReqs); - - auto collatorSlot = _data.env->getSlotIfExists("collator"_sd); - StringDataSet prefixSet; bool hasPartsWithCommonPrefix = false; for (const auto& part : sortPattern) { @@ -1226,14 +1144,18 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder } } + auto childReqs = reqs.copy().set(kResult); + auto [stage, childOutputs] = build(child, childReqs); + auto outputs = std::move(childOutputs); + + auto collatorSlot = _data.env->getSlotIfExists("collator"_sd); + sbe::value::SlotVector orderBy; std::vector<sbe::value::SortDirection> direction; sbe::value::SlotId outputSlotId = outputs.get(kResult); if (!hasPartsWithCommonPrefix) { // Handle the case where we are using kResult and there are no common prefixes. - sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projectMap; - orderBy.reserve(sortPattern.size()); // Sorting has a limitation where only one of the sort patterns can involve arrays. @@ -1329,8 +1251,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder makeConstant(sbe::value::TypeTags::sortSpec, sbe::value::bitcastFrom<sbe::value::SortSpec*>(sortSpec.release())); - auto collatorSlot = _data.env->getSlotIfExists("collator"_sd); - // generateSortKey() will handle the parallel arrays check and sort key traversal for us, // so we don't need to generate our own sort key traversal logic in the SBE plan. stage = sbe::makeProjectStage(std::move(stage), @@ -1347,7 +1267,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder // Slots for sort stage to forward to parent stage. Values in these slots are not used during // sorting. - auto forwardedSlots = getSlotsToForward(childReqs, outputs); + auto forwardedSlots = getSlotsToForward(childReqs, outputs, orderBy); stage = sbe::makeS<sbe::SortStage>(std::move(stage), @@ -1359,52 +1279,40 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder _cq.getExpCtx()->allowDiskUse, root->nodeId()); + outputs.clearNonRequiredSlots(reqs); + return {std::move(stage), std::move(outputs)}; } std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildSortCovered( const QuerySolutionNode* root, const PlanStageReqs& reqs) { + tassert(6023404, "buildSortCovered() does not support kResult", !reqs.has(kResult)); + const auto sn = static_cast<const SortNode*>(root); auto sortPattern = SortPattern{sn->pattern, _cq.getExpCtx()}; tassert(7047600, "QueryPlannerAnalysis should not produce a SortNode with an empty sort pattern", sortPattern.size() > 0); + tassert(6023422, "buildSortCovered() expected 'sn' to not be fetched", !sn->fetched()); - // The child must produce all of the slots required by the parent of this SortNode. - auto childReqs = reqs.copy(); auto child = sn->children[0].get(); - - auto parentIndexKeyBitset = reqs.getIndexKeyBitset().get_value_or({}); - BSONObj indexKeyPattern; - sbe::IndexKeysInclusionSet sortPatternKeyBitset; - StringMap<size_t> sortSlotPosMap; - - // Set IndexKeyBitset to request each part of the index requested by the parent and to request - // each part of the sort pattern. 
auto indexScan = static_cast<const IndexScanNode*>(getLoneNodeByType(child, STAGE_IXSCAN)); tassert(7047601, "Expected index scan below sort", indexScan); - indexKeyPattern = indexScan->index.keyPattern; + auto indexKeyPattern = indexScan->index.keyPattern; + + // The child must produce all of the slots required by the parent of this SortNode. + auto childReqs = reqs.copy(); + std::vector<std::string> keys; StringDataSet sortPathsSet; for (const auto& part : sortPattern) { - sortPathsSet.insert(part.fieldPath->fullPath()); + const auto& key = part.fieldPath->fullPath(); + keys.emplace_back(key); + sortPathsSet.emplace(key); } - // 'sortPatternKeyBitset' will contain a bit pattern indicating which parts of the index - // pattern are needed for sorting. 'sortPaths' will contain the field paths used in the - // sort pattern, ordered according to the index pattern. - std::vector<std::string> sortPaths; - std::tie(sortPatternKeyBitset, sortPaths) = - makeIndexKeyInclusionSet(indexKeyPattern, sortPathsSet); - childReqs.getIndexKeyBitset() = sortPatternKeyBitset | parentIndexKeyBitset; - - // Build a map that maps each field path to its position in 'sortPaths'. This will help us - // later when we need to convert the output of makeIndexKeyOutputSlotsMatchingParentReqs() - // from the index pattern's order to sort pattern's order. - for (size_t i = 0; i < sortPaths.size(); ++i) { - sortSlotPosMap.emplace(std::move(sortPaths[i]), i); - } + childReqs.setKeys(std::move(keys)); auto [stage, outputs] = build(child, childReqs); @@ -1412,26 +1320,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder sbe::value::SlotVector orderBy; std::vector<sbe::value::SortDirection> direction; - - // Handle the case where we are using IndexKeySlots. - auto indexKeySlots = *outputs.extractIndexKeySlots(); - - // Currently, 'indexKeySlots' contains slots for two kinds of index keys: - // 1. Keys requested for sort pattern - // 2. Keys requested by parent (if any) - // The first category of slots needs to go into the 'orderBy' vector. The second category - // of slots goes into 'outputs' to be used by the parent. - auto& childIndexKeyBitset = *childReqs.getIndexKeyBitset(); - - // The query planner does not support covered queries or SORT->IXSCAN plans on array fields - // for multikey indexes. Therefore we don't have to generate the parallel arrays check or - // the sort key traversal logic. - auto sortIndexKeySlots = makeIndexKeyOutputSlotsMatchingParentReqs( - indexKeyPattern, sortPatternKeyBitset, childIndexKeyBitset, indexKeySlots); - - // 'sortIndexKeySlots' is ordered according to the index pattern, but 'orderBy' needs to - // be ordered according to the sort pattern. We use 'sortIndexKeySlots' to convert from - // the index pattern's order to the sort pattern's order. orderBy.reserve(sortPattern.size()); direction.reserve(sortPattern.size()); for (const auto& part : sortPattern) { @@ -1439,19 +1327,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder // contains $meta, so this assertion should always be true. 
tassert(7047602, "Sort with $meta is not supported in SBE", part.fieldPath); - auto it = sortSlotPosMap.find(part.fieldPath->fullPath()); - tassert(6843200, - str::stream() << "Did not find sort path '" << part.fieldPath->fullPath() - << "' in sort path map", - it != sortSlotPosMap.end()); - - auto slotPos = it->second; - tassert(6843201, - str::stream() << "Sort path map for '" << part.fieldPath->fullPath() - << "' returned an index '" << slotPos << "' that is out of bounds", - slotPos < sortIndexKeySlots.size()); - - orderBy.push_back(sortIndexKeySlots[slotPos]); + orderBy.push_back( + outputs.get(std::make_pair(PlanStageSlots::kKey, part.fieldPath->fullPath()))); direction.push_back(part.isAscending ? sbe::value::SortDirection::Ascending : sbe::value::SortDirection::Descending); } @@ -1478,24 +1355,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder // Slots for sort stage to forward to parent stage. Values in these slots are not used during // sorting. - auto forwardedSlots = getSlotsToForward(childReqs, outputs); - - if (parentIndexKeyBitset.any()) { - auto parentIndexKeySlots = makeIndexKeyOutputSlotsMatchingParentReqs( - indexKeyPattern, parentIndexKeyBitset, childIndexKeyBitset, indexKeySlots); - - // The 'forwardedSlots' vector should include all slots requested by parent excluding - // any slots that appear in the 'orderBy' vector. - sbe::value::SlotSet orderBySlotsSet(orderBy.begin(), orderBy.end()); - for (auto slot : parentIndexKeySlots) { - if (!orderBySlotsSet.count(slot)) { - forwardedSlots.push_back(slot); - } - } - - // Make sure to store all of the slots requested by the parent into 'outputs'. - outputs.setIndexKeySlots(std::move(parentIndexKeySlots)); - } + auto forwardedSlots = getSlotsToForward(childReqs, outputs, orderBy); stage = sbe::makeS<sbe::SortStage>(std::move(stage), @@ -1507,11 +1367,13 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder _cq.getExpCtx()->allowDiskUse, root->nodeId()); + outputs.clearNonRequiredSlots(reqs); + return {std::move(stage), std::move(outputs)}; } std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> -SlotBasedStageBuilder::buildSortKeyGeneraror(const QuerySolutionNode* root, +SlotBasedStageBuilder::buildSortKeyGenerator(const QuerySolutionNode* root, const PlanStageReqs& reqs) { uasserted(4822883, "Sort key generator in not supported in SBE yet"); } @@ -1534,64 +1396,21 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder std::vector<sbe::value::SlotVector> inputKeys; std::vector<sbe::value::SlotVector> inputVals; - // Children must produce all of the slots required by the parent of this SortMergeNode. In - // addition, children must always produce a 'recordIdSlot' if the 'dedup' flag is true. - auto childReqs = reqs.copy().setIf(kRecordId, mergeSortNode->dedup); + std::vector<std::string> keys; + StringSet sortPatternSet; + for (auto&& sortPart : sortPattern) { + sortPatternSet.emplace(sortPart.fieldPath->fullPath()); + keys.emplace_back(sortPart.fieldPath->fullPath()); + } - // If a parent node has requested an index key bitset, then we will produce index keys - // corresponding to the sort pattern parts needed by said parent node. - bool parentRequestsIdxKeys = reqs.getIndexKeyBitset().has_value(); + // Children must produce all of the slots required by the parent of this SortMergeNode. 
In + // addition, children must always produce a 'recordIdSlot' if the 'dedup' flag is true, and + // they must produce kKey slots for each part of the sort pattern. + auto childReqs = reqs.copy().setIf(kRecordId, mergeSortNode->dedup).setKeys(std::move(keys)); for (auto&& child : mergeSortNode->children) { sbe::value::SlotVector inputKeysForChild; - // Retrieve the sort pattern provided by the subtree rooted at 'child'. In particular, if - // our child is a MERGE_SORT, it will provide the sort directly. If instead it's a tree - // containing a lone IXSCAN, we have to check the key pattern because 'providedSorts()' may - // or may not provide the baseSortPattern depending on the index bounds (in particular, - // if the bounds are fixed, the fields will be marked as 'ignored'). Otherwise, we attempt - // to retrieve it from 'providedSorts'. - auto childSortPattern = [&]() { - if (auto [msn, _] = getFirstNodeByType(child.get(), STAGE_SORT_MERGE); msn) { - auto node = static_cast<const MergeSortNode*>(msn); - return node->sort; - } else { - auto [ixn, ct] = getFirstNodeByType(child.get(), STAGE_IXSCAN); - if (ixn && ct == 1) { - auto node = static_cast<const IndexScanNode*>(ixn); - return node->index.keyPattern; - } - } - auto baseSort = child->providedSorts().getBaseSortPattern(); - tassert(6149600, - str::stream() << "Did not find sort pattern for child " << child->toString(), - !baseSort.isEmpty()); - return baseSort; - }(); - - // Map of field name to position within the index key. This is used to account for - // mismatches between the sort pattern and the index key pattern. For instance, suppose - // the requested sort is {a: 1, b: 1} and the index key pattern is {c: 1, b: 1, a: 1}. - // When the slots for the relevant components of the index key are generated (i.e. - // extract keys for 'b' and 'a'), we wish to insert them into 'inputKeys' in the order - // that they appear in the sort pattern. - StringMap<size_t> indexKeyPositionMap; - - sbe::IndexKeysInclusionSet indexKeyBitset; - size_t i = 0; - for (auto&& elt : childSortPattern) { - for (auto&& sortPart : sortPattern) { - auto path = sortPart.fieldPath->fullPath(); - if (elt.fieldNameStringData() == path) { - indexKeyBitset.set(i); - indexKeyPositionMap.emplace(path, indexKeyPositionMap.size()); - break; - } - } - ++i; - } - childReqs.getIndexKeyBitset() = indexKeyBitset; - // Children must produce a 'resultSlot' if they produce fetched results. auto [stage, outputs] = build(child.get(), childReqs); @@ -1600,30 +1419,9 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder " if the 'dedup' flag is set", !mergeSortNode->dedup || outputs.has(kRecordId)); - // Clear the index key bitset after building the child stage. - childReqs.getIndexKeyBitset() = boost::none; - - // Insert the index key slots in the order of the sort pattern. 
- auto indexKeys = outputs.extractIndexKeySlots(); - tassert(5184302, - "SORT_MERGE must receive index key slots as input from its child stages", - indexKeys); - for (const auto& part : sortPattern) { - auto partPath = part.fieldPath->fullPath(); - auto index = indexKeyPositionMap.find(partPath); - tassert(5184303, - str::stream() << "Could not find index key position for sort key part " - << partPath, - index != indexKeyPositionMap.end()); - auto indexPos = index->second; - tassert(5184304, - str::stream() << "Index position " << indexPos - << " is not less than number of index components " - << indexKeys->size(), - indexPos < indexKeys->size()); - auto indexKeyPart = indexKeys->at(indexPos); - inputKeysForChild.push_back(indexKeyPart); + inputKeysForChild.push_back( + outputs.get(std::make_pair(PlanStageSlots::kKey, part.fieldPath->fullPath()))); } inputKeys.push_back(std::move(inputKeysForChild)); @@ -1631,32 +1429,12 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder auto sv = getSlotsToForward(childReqs, outputs); - // If the parent of 'root' has requested index keys, then we need to pass along our input - // keys as input values, as they will be part of the output 'root' provides its parent. - if (parentRequestsIdxKeys) { - for (auto&& inputKey : inputKeys.back()) { - sv.push_back(inputKey); - } - } inputVals.push_back(std::move(sv)); } PlanStageSlots outputs(childReqs, &_slotIdGenerator); - auto outputVals = getSlotsToForward(childReqs, outputs); - // If the parent of 'root' has requested index keys, then we need to generate output slots to - // hold the index keys that will be used as input to the parent of 'root'. - if (parentRequestsIdxKeys) { - auto idxKeySv = sbe::makeSV(); - for (int idx = 0; idx < mergeSortNode->sort.nFields(); ++idx) { - idxKeySv.emplace_back(_slotIdGenerator.generate()); - } - outputs.setIndexKeySlots(idxKeySv); - - for (auto keySlot : idxKeySv) { - outputVals.push_back(keySlot); - } - } + auto outputVals = getSlotsToForward(childReqs, outputs); auto stage = sbe::makeS<sbe::SortedMergeStage>(std::move(inputStages), std::move(inputKeys), @@ -1674,6 +1452,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder } } + outputs.clearNonRequiredSlots(reqs); + return {std::move(stage), std::move(outputs)}; } @@ -1681,48 +1461,63 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildProjectionSimple(const QuerySolutionNode* root, const PlanStageReqs& reqs) { using namespace std::literals; - invariant(!reqs.getIndexKeyBitset()); + tassert(6023405, "buildProjectionSimple() does not support kKey", !reqs.hasKeys()); auto pn = static_cast<const ProjectionNodeSimple*>(root); - // The child must produce all of the slots required by the parent of this ProjectionNodeSimple. - // In addition to that, the child must always produce a 'resultSlot' because it's needed by the - // projection logic below. - auto childReqs = reqs.copy().set(kResult); - auto [inputStage, outputs] = build(pn->children[0].get(), childReqs); + auto [fields, additionalFields] = splitVector(reqs.getFields(), [&](const std::string& s) { + return pn->proj.type() == projection_ast::ProjectType::kInclusion + ? 
pn->proj.getRequiredFields().count(s) + : !pn->proj.getExcludedPaths().count(s); + }); - const auto childResult = outputs.get(kResult); + auto childReqs = reqs.copy().clearAllFields().setFields(std::move(fields)); + auto [stage, childOutputs] = build(pn->children[0].get(), childReqs); + auto outputs = std::move(childOutputs); - sbe::MakeBsonObjStage::FieldBehavior behaviour; - const OrderedPathSet* fields; - if (pn->proj.type() == projection_ast::ProjectType::kInclusion) { - behaviour = sbe::MakeBsonObjStage::FieldBehavior::keep; - fields = &pn->proj.getRequiredFields(); - } else { - behaviour = sbe::MakeBsonObjStage::FieldBehavior::drop; - fields = &pn->proj.getExcludedPaths(); - } - - outputs.set(kResult, _slotIdGenerator.generate()); - inputStage = sbe::makeS<sbe::MakeBsonObjStage>(std::move(inputStage), - outputs.get(kResult), - childResult, - behaviour, - *fields, - OrderedPathSet{}, - sbe::value::SlotVector{}, - true, - false, - root->nodeId()); + if (reqs.has(kResult)) { + const auto childResult = outputs.get(kResult); + + sbe::MakeBsonObjStage::FieldBehavior behaviour; + const OrderedPathSet* fields; + if (pn->proj.type() == projection_ast::ProjectType::kInclusion) { + behaviour = sbe::MakeBsonObjStage::FieldBehavior::keep; + fields = &pn->proj.getRequiredFields(); + } else { + behaviour = sbe::MakeBsonObjStage::FieldBehavior::drop; + fields = &pn->proj.getExcludedPaths(); + } + + outputs.set(kResult, _slotIdGenerator.generate()); + stage = sbe::makeS<sbe::MakeBsonObjStage>(std::move(stage), + outputs.get(kResult), + childResult, + behaviour, + *fields, + OrderedPathSet{}, + sbe::value::SlotVector{}, + true, + false, + root->nodeId()); + } + + auto [outStage, nothingSlots] = projectNothingToSlots( + std::move(stage), additionalFields.size(), root->nodeId(), &_slotIdGenerator); + for (size_t i = 0; i < additionalFields.size(); ++i) { + outputs.set(std::make_pair(PlanStageSlots::kField, std::move(additionalFields[i])), + nothingSlots[i]); + } + + outputs.clearNonRequiredSlots(reqs); - return {std::move(inputStage), std::move(outputs)}; + return {std::move(outStage), std::move(outputs)}; } std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildProjectionCovered(const QuerySolutionNode* root, const PlanStageReqs& reqs) { using namespace std::literals; - invariant(!reqs.getIndexKeyBitset()); + tassert(6023406, "buildProjectionCovered() does not support kKey", !reqs.hasKeys()); auto pn = static_cast<const ProjectionNodeCovered*>(root); invariant(pn->proj.isSimple()); @@ -1733,51 +1528,86 @@ SlotBasedStageBuilder::buildProjectionCovered(const QuerySolutionNode* root, !pn->children[0]->fetched()); // This is a ProjectionCoveredNode, so we will be pulling all the data we need from one index. - // Prepare a bitset to indicate which parts of the index key we need for the projection. - StringSet requiredFields = {pn->proj.getRequiredFields().begin(), - pn->proj.getRequiredFields().end()}; - - // The child must produce all of the slots required by the parent of this ProjectionNodeSimple, - // except for 'resultSlot' which will be produced by the MakeBsonObjStage below. In addition to - // that, the child must produce the index key slots that are needed by this covered projection. - // // pn->coveredKeyObj is the "index.keyPattern" from the child (which is either an IndexScanNode // or DistinctNode). pn->coveredKeyObj lists all the fields that the index can provide, not the - // fields that the projection wants. 
requiredFields lists all of the fields that the projection - // needs. Since this is a covered projection, we're guaranteed that pn->coveredKeyObj contains - // all of the fields that the projection needs. - auto childReqs = reqs.copy().clear(kResult); - - auto [indexKeyBitset, keyFieldNames] = - makeIndexKeyInclusionSet(pn->coveredKeyObj, requiredFields); - childReqs.getIndexKeyBitset() = std::move(indexKeyBitset); - - auto [inputStage, outputs] = build(pn->children[0].get(), childReqs); - - // Assert that the index scan produced index key slots for this covered projection. - auto indexKeySlots = *outputs.extractIndexKeySlots(); - - outputs.set(kResult, _slotIdGenerator.generate()); - inputStage = sbe::makeS<sbe::MakeBsonObjStage>(std::move(inputStage), - outputs.get(kResult), - boost::none, - boost::none, - std::vector<std::string>{}, - std::move(keyFieldNames), - std::move(indexKeySlots), - true, - false, - root->nodeId()); + // fields that the projection wants. 'pn->proj.getRequiredFields()' lists all of the fields + // that the projection needs. Since this is a simple covered projection, we're guaranteed that + // 'pn->proj.getRequiredFields()' is a subset of pn->coveredKeyObj. + + // List out the projected fields in the order they appear in 'coveredKeyObj'. + std::vector<std::string> keys; + StringDataSet keysSet; + for (auto&& elt : pn->coveredKeyObj) { + std::string key(elt.fieldNameStringData()); + if (pn->proj.getRequiredFields().count(key)) { + keys.emplace_back(std::move(key)); + keysSet.emplace(elt.fieldNameStringData()); + } + } + + // The child must produce all of the slots required by the parent of this ProjectionNodeSimple, + // except for 'resultSlot' which will be produced by the MakeBsonObjStage below if requested by + // the caller. In addition to that, the child must produce the index key slots that are needed + // by this covered projection. 
+ auto childReqs = reqs.copy().clear(kResult).clearAllFields().setKeys(keys); + auto [stage, childOutputs] = build(pn->children[0].get(), childReqs); + auto outputs = std::move(childOutputs); + + if (reqs.has(kResult)) { + auto indexKeySlots = sbe::makeSV(); + std::vector<std::string> keyFieldNames; - return {std::move(inputStage), std::move(outputs)}; + if (keysSet.count("_id"_sd)) { + keyFieldNames.emplace_back("_id"_sd); + indexKeySlots.emplace_back(outputs.get(std::make_pair(PlanStageSlots::kKey, "_id"_sd))); + } + + for (const auto& key : keys) { + if (key != "_id"_sd) { + keyFieldNames.emplace_back(key); + indexKeySlots.emplace_back( + outputs.get(std::make_pair(PlanStageSlots::kKey, StringData(key)))); + } + } + + auto resultSlot = _slotIdGenerator.generate(); + stage = sbe::makeS<sbe::MakeBsonObjStage>(std::move(stage), + resultSlot, + boost::none, + boost::none, + std::vector<std::string>{}, + std::move(keyFieldNames), + std::move(indexKeySlots), + true, + false, + root->nodeId()); + + outputs.set(kResult, resultSlot); + } + + auto [fields, additionalFields] = + splitVector(reqs.getFields(), [&](const std::string& s) { return keysSet.count(s); }); + for (size_t i = 0; i < fields.size(); ++i) { + auto slot = outputs.get(std::make_pair(PlanStageSlots::kKey, StringData(fields[i]))); + outputs.set(std::make_pair(PlanStageSlots::kField, std::move(fields[i])), slot); + } + + auto [outStage, nothingSlots] = projectNothingToSlots( + std::move(stage), additionalFields.size(), root->nodeId(), &_slotIdGenerator); + for (size_t i = 0; i < additionalFields.size(); ++i) { + outputs.set(std::make_pair(PlanStageSlots::kField, std::move(additionalFields[i])), + nothingSlots[i]); + } + + outputs.clearNonRequiredSlots(reqs); + + return {std::move(outStage), std::move(outputs)}; } std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildProjectionDefault(const QuerySolutionNode* root, const PlanStageReqs& reqs) { - tassert(7055400, - "buildProjectionDefault() does not support index key bitsets", - !reqs.getIndexKeyBitset()); + tassert(6023407, "buildProjectionDefault() does not support kKey", !reqs.hasKeys()); auto pn = static_cast<const ProjectionNodeDefault*>(root); const auto& projection = pn->proj; @@ -1791,7 +1621,7 @@ SlotBasedStageBuilder::buildProjectionDefault(const QuerySolutionNode* root, // The child must produce all of the slots required by the parent of this ProjectionNodeDefault. // In addition to that, the child must always produce 'kResult' because it's needed by the // projection logic below. 
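Editorial note on the field partitioning used by buildProjectionSimple() and buildProjectionCovered(): the parent's requested fields are split into those the projection (or the covered index) can provide and those that must be bound to Nothing via projectNothingToSlots(). A small self-contained sketch of that partitioning step, assuming splitVector() behaves as it appears to in this diff (first vector gets elements matching the predicate, second gets the rest, relative order preserved), could be:

#include <set>
#include <string>
#include <utility>
#include <vector>

template <typename T, typename Pred>
std::pair<std::vector<T>, std::vector<T>> splitVector(std::vector<T> v, Pred pred) {
    std::pair<std::vector<T>, std::vector<T>> out;
    for (auto& elem : v) {
        (pred(elem) ? out.first : out.second).push_back(std::move(elem));
    }
    return out;
}

int main() {
    // An inclusion projection {a: 1, b: 1} asked for fields {"a", "c"} can provide "a"
    // directly; "c" ends up in 'additionalFields' and is projected to Nothing afterwards.
    std::set<std::string> provided{"a", "b"};
    auto [fields, additionalFields] =
        splitVector(std::vector<std::string>{"a", "c"},
                    [&](const std::string& s) { return provided.count(s) != 0; });
    return (fields.size() == 1 && additionalFields.size() == 1) ? 0 : 1;
}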
- auto childReqs = reqs.copy().set(kResult); + auto childReqs = reqs.copy().set(kResult).clearAllFields(); auto [stage, outputs] = build(pn->children[0].get(), childReqs); @@ -1805,8 +1635,10 @@ SlotBasedStageBuilder::buildProjectionDefault(const QuerySolutionNode* root, root->nodeId()); stage = resultStage.extractStage(root->nodeId()); - outputs.set(kResult, resultSlot); + + outputs.clearNonRequiredSlots(reqs); + return {std::move(stage), std::move(outputs)}; } @@ -1814,9 +1646,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildProjectionDefaultCovered(const QuerySolutionNode* root, const PlanStageReqs& reqs, const IndexScanNode* ixn) { - tassert(7055401, - "buildProjectionDefaultCovered() does not support index key bitsets", - !reqs.getIndexKeyBitset()); + tassert(6023408, "buildProjectionDefaultCovered() does not support kKey", !reqs.hasKeys()); auto pn = static_cast<const ProjectionNodeDefault*>(root); const auto& projection = pn->proj; @@ -1827,64 +1657,60 @@ SlotBasedStageBuilder::buildProjectionDefaultCovered(const QuerySolutionNode* ro tassert( 7055403, "buildProjectionDefaultCovered() expected 'pn' to not be fetched", !pn->fetched()); - // Convert projection fieldpaths into the tree of 'IndexKeyPatternTreeNode'. - IndexKeysBuilderContext context; - IndexKeysPreBuilder preVisitor{&context}; - IndexKeysInBuilder inVisitor{&context}; - IndexKeysPostBuilder postVisitor{&context}; - projection_ast::ProjectionASTConstWalker walker{&preVisitor, &inVisitor, &postVisitor}; - tree_walker::walk<true, projection_ast::ASTNode>(projection.root(), &walker); - - IndexKeyPatternTreeNode patternRoot = std::move(context.root); - - // Construct a bitset requesting slots from the underlying index scan. These slots - // correspond to index keys for projection fieldpaths. if (!ixn) { ixn = static_cast<const IndexScanNode*>(getLoneNodeByType(root, STAGE_IXSCAN)); } + auto& indexKeyPattern = ixn->index.keyPattern; - sbe::IndexKeysInclusionSet patternBitSet; + auto patternRoot = buildPatternTree(pn->proj); + std::vector<std::string> keys; + StringDataSet keysSet; std::vector<IndexKeyPatternTreeNode*> patternNodesForSlots; - size_t i = 0; for (const auto& element : indexKeyPattern) { sbe::MatchPath fieldRef{element.fieldNameStringData()}; // Projection field paths are always leaf nodes. In other words, projection like // {a: 1, 'a.b': 1} would produce a path collision error. if (auto node = patternRoot.findLeafNode(fieldRef); node) { - patternBitSet.set(i); + keys.emplace_back(element.fieldNameStringData()); + keysSet.emplace(element.fieldNameStringData()); patternNodesForSlots.push_back(node); } - - ++i; } - // We do not need index scan to restore the entire object. Instead, we will restore only - // necessary parts of it below. - auto childReqs = reqs.copy().clear(kResult); - childReqs.getIndexKeyBitset() = patternBitSet; + auto childReqs = reqs.copy().clear(kResult).clearAllFields().setKeys(keys); auto [stage, outputs] = build(pn->children[0].get(), childReqs); - auto indexKeySlots = *outputs.extractIndexKeySlots(); + auto [fields, additionalFields] = + splitVector(reqs.getFields(), [&](const std::string& s) { return keysSet.count(s); }); + for (size_t i = 0; i < fields.size(); ++i) { + auto slot = outputs.get(std::make_pair(PlanStageSlots::kKey, StringData(fields[i]))); + outputs.set(std::make_pair(PlanStageSlots::kField, std::move(fields[i])), slot); + } - // Extract slots corresponding to each of the projection fieldpaths. 
- invariant(indexKeySlots.size() == patternNodesForSlots.size()); - for (size_t i = 0; i < indexKeySlots.size(); i++) { - patternNodesForSlots[i]->indexKeySlot = indexKeySlots[i]; + if (reqs.has(kResult) || !additionalFields.empty()) { + // Extract slots corresponding to each of the projection fieldpaths. + for (size_t i = 0; i < keys.size(); i++) { + patternNodesForSlots[i]->indexKeySlot = + outputs.get(std::make_pair(PlanStageSlots::kKey, StringData(keys[i]))); + } + + // Finally, build the expression to create object with requested projection fieldpaths. + auto resultSlot = _slotIdGenerator.generate(); + outputs.set(kResult, resultSlot); + + stage = sbe::makeProjectStage( + std::move(stage), root->nodeId(), resultSlot, buildNewObjExpr(&patternRoot)); } - // Finally, build the expression to create object with requested projection fieldpaths. - auto resultSlot = _slotIdGenerator.generate(); - stage = sbe::makeProjectStage( - std::move(stage), root->nodeId(), resultSlot, buildNewObjExpr(&patternRoot)); + outputs.clearNonRequiredSlots(reqs); - outputs.set(kResult, resultSlot); return {std::move(stage), std::move(outputs)}; } std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildOr( const QuerySolutionNode* root, const PlanStageReqs& reqs) { - invariant(!reqs.getIndexKeyBitset()); + tassert(6023409, "buildOr() does not support kKey", !reqs.hasKeys()); sbe::PlanStage::Vector inputStages; std::vector<sbe::value::SlotVector> inputSlots; @@ -1941,7 +1767,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder auto textNode = static_cast<const TextMatchNode*>(root); const auto& coll = getCurrentCollection(reqs); tassert(5432212, "no collection object", coll); - tassert(5432213, "index keys requested for text match node", !reqs.getIndexKeyBitset()); + tassert(6023410, "buildTextMatch() does not support kKey", !reqs.hasKeys()); tassert(5432215, str::stream() << "text match node must have one child, but got " << root->children.size(), @@ -1992,7 +1818,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildReturnKey( const QuerySolutionNode* root, const PlanStageReqs& reqs) { - invariant(!reqs.getIndexKeyBitset()); + tassert(6023411, "buildReturnKey() does not support kKey", !reqs.hasKeys()); // TODO SERVER-49509: If the projection includes {$meta: "sortKey"}, the result of this stage // should also include the sort key. Everything else in the projection is ignored. @@ -2002,7 +1828,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder // for 'resultSlot'. In addition to that, the child must always produce a 'returnKeySlot'. // After build() returns, we take the 'returnKeySlot' produced by the child and store it into // 'resultSlot' for the parent of this ReturnKeyNode to consume. 
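Editorial note on the key selection in the buildProjectionDefaultCovered() hunk above: the builder walks the index key pattern and requests a kKey slot only for components whose path is a leaf of the projection's pattern tree. A simplified, standalone version of that selection step, treating paths as plain strings instead of the real MatchPath/IndexKeyPatternTreeNode machinery, might be:

#include <set>
#include <string>
#include <vector>

// Select the index key components a covered projection needs. 'keyPattern' lists the
// index fields in key order; 'projectedLeaves' holds the projection's leaf paths.
// Only leaf matches qualify, loosely mirroring patternRoot.findLeafNode() in the diff.
std::vector<std::string> coveredProjectionKeys(const std::vector<std::string>& keyPattern,
                                               const std::set<std::string>& projectedLeaves) {
    std::vector<std::string> keys;
    for (const auto& field : keyPattern) {
        if (projectedLeaves.count(field)) {
            keys.push_back(field);  // request a kKey slot for this component
        }
    }
    return keys;  // later turned into childReqs.setKeys(keys)
}

The resulting keys stay in key-pattern order, which is why the builder can later zip them against patternNodesForSlots without any extra bookkeeping.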
- auto childReqs = reqs.copy().clear(kResult).set(kReturnKey); + auto childReqs = reqs.copy().clear(kResult).clearAllFields().set(kReturnKey); auto [stage, outputs] = build(returnKeyNode->children[0].get(), childReqs); outputs.set(kResult, outputs.get(kReturnKey)); @@ -2020,9 +1846,10 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder const QuerySolutionNode* root, const PlanStageReqs& reqs) { auto andHashNode = static_cast<const AndHashNode*>(root); + tassert(6023412, "buildAndHash() does not support kKey", !reqs.hasKeys()); tassert(5073711, "need at least two children for AND_HASH", andHashNode->children.size() >= 2); - auto childReqs = reqs.copy().set(kResult).set(kRecordId); + auto childReqs = reqs.copy().set(kResult).set(kRecordId).clearAllFields(); auto outerChild = andHashNode->children[0].get(); auto innerChild = andHashNode->children[1].get(); @@ -2049,13 +1876,13 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder auto collatorSlot = _data.env->getSlotIfExists("collator"_sd); // Designate outputs. - PlanStageSlots outputs(reqs, &_slotIdGenerator); + PlanStageSlots outputs; + + outputs.set(kResult, innerResultSlot); + if (reqs.has(kRecordId)) { outputs.set(kRecordId, innerIdSlot); } - if (reqs.has(kResult)) { - outputs.set(kResult, innerResultSlot); - } if (reqs.has(kSnapshotId) && innerSnapshotIdSlot) { auto slot = *innerSnapshotIdSlot; innerProjectSlots.push_back(slot); @@ -2071,26 +1898,25 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder innerProjectSlots.push_back(slot); outputs.set(kIndexKey, slot); } - if (reqs.has(kIndexKeyPattern) && innerIndexKeyPatternSlot) { auto slot = *innerIndexKeyPatternSlot; innerProjectSlots.push_back(slot); outputs.set(kIndexKeyPattern, slot); } - auto hashJoinStage = sbe::makeS<sbe::HashJoinStage>(std::move(outerStage), - std::move(innerStage), - outerCondSlots, - outerProjectSlots, - innerCondSlots, - innerProjectSlots, - collatorSlot, - root->nodeId()); + auto stage = sbe::makeS<sbe::HashJoinStage>(std::move(outerStage), + std::move(innerStage), + outerCondSlots, + outerProjectSlots, + innerCondSlots, + innerProjectSlots, + collatorSlot, + root->nodeId()); // If there are more than 2 children, iterate all remaining children and hash // join together. for (size_t i = 2; i < andHashNode->children.size(); i++) { - auto [stage, outputs] = build(andHashNode->children[i].get(), childReqs); + auto [childStage, outputs] = build(andHashNode->children[i].get(), childReqs); tassert(5073714, "outputs must contain kRecordId slot", outputs.has(kRecordId)); tassert(5073715, "outputs must contain kResult slot", outputs.has(kResult)); auto idSlot = outputs.get(kRecordId); @@ -2100,32 +1926,30 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder // The previous HashJoinStage is always set as the inner stage, so that we can reuse the // innerIdSlot and innerResultSlot that have been designated as outputs. - hashJoinStage = sbe::makeS<sbe::HashJoinStage>(std::move(stage), - std::move(hashJoinStage), - condSlots, - projectSlots, - innerCondSlots, - innerProjectSlots, - collatorSlot, - root->nodeId()); - } - // Stop propagating the RecordId output if none of our ancestors are going to use it. 
- if (!reqs.has(kRecordId)) { - outputs.clear(kRecordId); + stage = sbe::makeS<sbe::HashJoinStage>(std::move(childStage), + std::move(stage), + condSlots, + projectSlots, + innerCondSlots, + innerProjectSlots, + collatorSlot, + root->nodeId()); } - return {std::move(hashJoinStage), std::move(outputs)}; + return {std::move(stage), std::move(outputs)}; } std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildAndSorted( const QuerySolutionNode* root, const PlanStageReqs& reqs) { + tassert(6023413, "buildAndSorted() does not support kKey", !reqs.hasKeys()); + auto andSortedNode = static_cast<const AndSortedNode*>(root); // Need at least two children. tassert( 5073706, "need at least two children for AND_SORTED", andSortedNode->children.size() >= 2); - auto childReqs = reqs.copy().set(kResult).set(kRecordId); + auto childReqs = reqs.copy().set(kResult).set(kRecordId).clearAllFields(); auto outerChild = andSortedNode->children[0].get(); auto innerChild = andSortedNode->children[1].get(); @@ -2161,13 +1985,14 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder auto innerKeySlots = sbe::makeSV(innerIdSlot); auto innerProjectSlots = sbe::makeSV(innerResultSlot); - PlanStageSlots outputs(reqs, &_slotIdGenerator); + // Designate outputs. + PlanStageSlots outputs; + + outputs.set(kResult, innerResultSlot); + if (reqs.has(kRecordId)) { outputs.set(kRecordId, innerIdSlot); } - if (reqs.has(kResult)) { - outputs.set(kResult, innerResultSlot); - } if (reqs.has(kSnapshotId)) { auto innerSnapshotSlot = innerOutputs.get(kSnapshotId); innerProjectSlots.push_back(innerSnapshotSlot); @@ -2183,7 +2008,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder innerProjectSlots.push_back(innerIndexKeySlot); outputs.set(kIndexKey, innerIndexKeySlot); } - if (reqs.has(kIndexKeyPattern)) { auto innerIndexKeyPatternSlot = innerOutputs.get(kIndexKeyPattern); innerProjectSlots.push_back(innerIndexKeyPatternSlot); @@ -2193,19 +2017,19 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder std::vector<sbe::value::SortDirection> sortDirs(outerKeySlots.size(), sbe::value::SortDirection::Ascending); - auto mergeJoinStage = sbe::makeS<sbe::MergeJoinStage>(std::move(outerStage), - std::move(innerStage), - outerKeySlots, - outerProjectSlots, - innerKeySlots, - innerProjectSlots, - sortDirs, - root->nodeId()); + auto stage = sbe::makeS<sbe::MergeJoinStage>(std::move(outerStage), + std::move(innerStage), + outerKeySlots, + outerProjectSlots, + innerKeySlots, + innerProjectSlots, + sortDirs, + root->nodeId()); // If there are more than 2 children, iterate all remaining children and merge // join together. 
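Editorial note on the multi-child join rewrite in buildAndHash() and buildAndSorted(): each additional child is folded in as the new outer side while the accumulated join stays on the inner side, so the designated inner output slots remain valid without re-deriving them per iteration. A small sketch of that left-fold shape over an invented binary join constructor (Plan/joinStages are illustrative names, not SBE's) might be:

#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical plan node used only to illustrate the folding shape.
struct Plan {
    std::string label;
    std::unique_ptr<Plan> outer, inner;
};

std::unique_ptr<Plan> joinStages(std::unique_ptr<Plan> outer, std::unique_ptr<Plan> inner) {
    auto join = std::make_unique<Plan>();
    join->label = "join";
    join->outer = std::move(outer);
    join->inner = std::move(inner);
    return join;
}

// Children [c0, c1, c2] become join(c2, join(c0, c1)): the running result always stays
// on the inner side, so its output slots stay stable. Callers guarantee >= 2 children,
// just as the tasserts in the diff do.
std::unique_ptr<Plan> foldJoins(std::vector<std::unique_ptr<Plan>> children) {
    auto stage = joinStages(std::move(children[0]), std::move(children[1]));
    for (size_t i = 2; i < children.size(); ++i) {
        stage = joinStages(std::move(children[i]), std::move(stage));
    }
    return stage;
}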
for (size_t i = 2; i < andSortedNode->children.size(); i++) { - auto [stage, outputs] = build(andSortedNode->children[i].get(), childReqs); + auto [childStage, outputs] = build(andSortedNode->children[i].get(), childReqs); tassert(5073709, "outputs must contain kRecordId slot", outputs.has(kRecordId)); tassert(5073710, "outputs must contain kResult slot", outputs.has(kResult)); auto idSlot = outputs.get(kRecordId); @@ -2213,21 +2037,17 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder auto keySlots = sbe::makeSV(idSlot); auto projectSlots = sbe::makeSV(resultSlot); - mergeJoinStage = sbe::makeS<sbe::MergeJoinStage>(std::move(stage), - std::move(mergeJoinStage), - keySlots, - projectSlots, - innerKeySlots, - innerProjectSlots, - sortDirs, - root->nodeId()); - } - // Stop propagating the RecordId output if none of our ancestors are going to use it. - if (!reqs.has(kRecordId)) { - outputs.clear(kRecordId); + stage = sbe::makeS<sbe::MergeJoinStage>(std::move(childStage), + std::move(stage), + keySlots, + projectSlots, + innerKeySlots, + innerProjectSlots, + sortDirs, + root->nodeId()); } - return {std::move(mergeJoinStage), std::move(outputs)}; + return {std::move(stage), std::move(outputs)}; } namespace { @@ -2313,39 +2133,6 @@ void walkAndActOnFieldPaths(Expression* expr, const F& fn) { } /** - * Checks whether all field paths in 'idExpr' and all accumulator expressions are top-level ones. - */ -bool areAllFieldPathsOptimizable(const boost::intrusive_ptr<Expression>& idExpr, - const std::vector<AccumulationStatement>& accStmts) { - auto areFieldPathsOptimizable = true; - - auto checkFieldPath = [&](const ExpressionFieldPath* fieldExpr, int32_t nestedCondLevel) { - // We optimize neither a field path for the top-level document itself (getPathLength() == 1) - // nor a field path that refers to a variable. We can optimize only top-level fields - // (getPathLength() == 2). - // - // The 'nestedCondLevel' being > 0 means that a field path is refered to below conditional - // expressions at the parent $group node, when we cannot optimize field path access and - // therefore, cannot avoid materialization. - if (nestedCondLevel > 0 || fieldExpr->getFieldPath().getPathLength() != 2 || - fieldExpr->isVariableReference()) { - areFieldPathsOptimizable = false; - return; - } - }; - - // Checks field paths from the group-by expression. - walkAndActOnFieldPaths(idExpr.get(), checkFieldPath); - - // Checks field paths from the accumulator expressions. - for (auto&& accStmt : accStmts) { - walkAndActOnFieldPaths(accStmt.expr.argument.get(), checkFieldPath); - } - - return areFieldPathsOptimizable; -} - -/** * If there are adjacent $group stages in a pipeline and two $group stages are pushed down together, * the first $group becomes a child GROUP node and the second $group becomes a parent GROUP node in * a query solution tree. In the case that all field paths are top-level fields for the parent GROUP @@ -2362,10 +2149,7 @@ EvalStage optimizeFieldPaths(StageBuilderState& state, const PlanStageSlots& childOutputs, PlanNodeId nodeId) { using namespace fmt::literals; - auto optionalRootSlot = childOutputs.getIfExists(SlotBasedStageBuilder::kResult); - // Absent root slot means that we optimized away mkbson stage and so, we need to search - // top-level field names in child outputs. 
- auto searchInChildOutputs = !optionalRootSlot.has_value(); + auto rootSlot = childOutputs.getIfExists(PlanStageSlots::kResult); auto retEvalStage = std::move(childEvalStage); walkAndActOnFieldPaths(expr.get(), [&](const ExpressionFieldPath* fieldExpr, int32_t) { @@ -2376,24 +2160,10 @@ EvalStage optimizeFieldPaths(StageBuilderState& state, } auto fieldPathStr = fieldExpr->getFieldPath().fullPath(); - if (searchInChildOutputs) { - if (auto optionalFieldPathSlot = childOutputs.getIfExists(fieldPathStr); - optionalFieldPathSlot) { - state.preGeneratedExprs.emplace(fieldPathStr, *optionalFieldPathSlot); - } else { - // getField('fieldPathStr') on a slot containing a BSONObj would have produced - // 'Nothing' if mkbson stage removal optimization didn't occur. So, generates a - // 'Nothing' const expression to simulate such a result. - state.preGeneratedExprs.emplace( - fieldPathStr, sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Nothing, 0)); - } - - return; - } if (!state.preGeneratedExprs.contains(fieldPathStr)) { auto [curEvalExpr, curEvalStage] = generateExpression( - state, fieldExpr, std::move(retEvalStage), optionalRootSlot, nodeId); + state, fieldExpr, std::move(retEvalStage), rootSlot, nodeId, &childOutputs); auto [slot, stage] = projectEvalExpr( std::move(curEvalExpr), std::move(curEvalStage), nodeId, state.slotIdGenerator); @@ -2418,7 +2188,7 @@ std::pair<EvalExpr, EvalStage> generateGroupByKeyImpl( optimizeFieldPaths(state, idExpr, std::move(childEvalStage), childOutputs, nodeId); auto [groupByEvalExpr, groupByEvalStage] = stage_builder::generateExpression( - state, idExpr.get(), std::move(evalStage), optionalRootSlot, nodeId); + state, idExpr.get(), std::move(evalStage), optionalRootSlot, nodeId, &childOutputs); return {std::move(groupByEvalExpr), std::move(groupByEvalStage)}; } @@ -2430,7 +2200,7 @@ std::tuple<sbe::value::SlotVector, EvalStage, std::unique_ptr<sbe::EExpression>> std::unique_ptr<sbe::PlanStage> childStage, PlanNodeId nodeId, sbe::value::SlotIdGenerator* slotIdGenerator) { - auto optionalRootSlot = childOutputs.getIfExists(SlotBasedStageBuilder::kResult); + auto optionalRootSlot = childOutputs.getIfExists(PlanStageSlots::kResult); EvalStage retEvalStage{std::move(childStage), optionalRootSlot ? sbe::value::SlotVector{*optionalRootSlot} : sbe::value::SlotVector{}}; @@ -2510,10 +2280,10 @@ std::tuple<sbe::value::SlotVector, EvalStage> generateAccumulator( PlanNodeId nodeId, sbe::value::SlotIdGenerator* slotIdGenerator, sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>>& accSlotToExprMap) { - // Input fields may need field traversal which ends up being a complex tree. + // Input fields may need field traversal. 
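Editorial note on the field-path optimization in this $group hunk: optimizeFieldPaths() projects each distinct top-level field path to a slot once and records it in state.preGeneratedExprs, so later accumulator expressions reuse the slot instead of re-traversing the document. A toy, self-contained version of that memoization, with a plain callback standing in for generateExpression()/projectEvalExpr(), might be:

#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>

using SlotId = int64_t;

// Hypothetical memo: field path -> slot already holding that path's value.
struct FieldPathCache {
    std::unordered_map<std::string, SlotId> preGeneratedSlots;
    SlotId nextSlot = 0;

    // 'generate' is only invoked the first time a given field path is seen.
    SlotId slotFor(const std::string& fieldPath, const std::function<void(SlotId)>& generate) {
        auto it = preGeneratedSlots.find(fieldPath);
        if (it != preGeneratedSlots.end()) {
            return it->second;  // reuse the slot projected for an earlier reference
        }
        SlotId slot = nextSlot++;
        generate(slot);  // emit the project stage for this path exactly once
        preGeneratedSlots.emplace(fieldPath, slot);
        return slot;
    }
};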
auto evalStage = optimizeFieldPaths( state, accStmt.expr.argument, std::move(childEvalStage), childOutputs, nodeId); - auto optionalRootSlot = childOutputs.getIfExists(SlotBasedStageBuilder::kResult); + auto optionalRootSlot = childOutputs.getIfExists(PlanStageSlots::kResult); auto [argExpr, accArgEvalStage] = stage_builder::buildArgument( state, accStmt, std::move(evalStage), optionalRootSlot, nodeId); @@ -2642,6 +2412,7 @@ sbe::value::SlotVector dedupGroupBySlots(const sbe::value::SlotVector& groupBySl std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildGroup( const QuerySolutionNode* root, const PlanStageReqs& reqs) { using namespace fmt::literals; + tassert(6023414, "buildGroup() does not support kKey", !reqs.hasKeys()); auto groupNode = static_cast<const GroupNode*>(root); auto nodeId = groupNode->nodeId(); @@ -2657,14 +2428,20 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder const auto& childNode = groupNode->children[0].get(); const auto& accStmts = groupNode->accumulators; - auto childStageType = childNode->getType(); - auto childReqs = reqs.copy(); - if (childStageType == StageType::STAGE_GROUP && areAllFieldPathsOptimizable(idExpr, accStmts)) { - // Does not ask the GROUP child for the result slot to avoid unnecessary materialization if - // all fields are top-level fields. See the end of this function. For example, GROUP - GROUP - // - COLLSCAN case. + auto childReqs = reqs.copy().set(kResult).clearAllFields(); + + // Don't ask the GROUP child for the result slot to avoid unnecessary materialization if it's + // possible to get everything we need from top-level field slots. + if (childNode->getType() == StageType::STAGE_GROUP && !groupNode->needWholeDocument && + !groupNode->needsAnyMetadata) { childReqs.clear(kResult); + + for (auto&& pathStr : groupNode->requiredFields) { + auto path = sbe::MatchPath{pathStr}; + const auto& topLevelField = path.getPart(0); + childReqs.set(std::make_pair(PlanStageSlots::kField, topLevelField)); + } } // Builds the child and gets the child result slot. @@ -2733,6 +2510,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder aggSlotsVec, nodeId, &_slotIdGenerator); + auto stage = groupFinalEvalStage.extractStage(nodeId); tassert(5851605, "The number of final slots must be as 1 (the final group-by slot) + the number of acc " @@ -2742,21 +2520,37 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder // Cleans up optimized expressions. _state.preGeneratedExprs.clear(); + auto fieldNamesSet = StringDataSet{fieldNames.begin(), fieldNames.end()}; + auto [fields, additionalFields] = + splitVector(reqs.getFields(), [&](const std::string& s) { return fieldNamesSet.count(s); }); + auto fieldsSet = StringDataSet{fields.begin(), fields.end()}; + PlanStageSlots outputs; - std::unique_ptr<sbe::PlanStage> outStage; - // Builds a stage to create a result object out of a group-by slot and gathered accumulator - // result slots if the parent node requests so. Otherwise, returns field names and associated - // slots so that a parent stage above can directly refer to a slot by its name because there's - // no returned object. 
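Editorial note on the child requirements computed in buildGroup() above: when the child is itself a GROUP stage and the group needs neither the whole document nor metadata, the builder skips kResult and instead requests one kField slot per top-level prefix of the group's required dotted paths. A small standalone sketch of that prefix extraction, using plain string handling in place of sbe::MatchPath, could be:

#include <set>
#include <string>
#include <vector>

// For required paths like {"a.b", "a.c", "d"}, the child only needs kField slots for
// the top-level fields {"a", "d"}; deeper traversal happens in the group expressions.
std::set<std::string> topLevelFieldRequests(const std::vector<std::string>& requiredPaths) {
    std::set<std::string> topLevel;
    for (const auto& path : requiredPaths) {
        topLevel.insert(path.substr(0, path.find('.')));  // npos -> whole string
    }
    return topLevel;
}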
+ for (size_t i = 0; i < fieldNames.size(); ++i) { + if (fieldsSet.count(fieldNames[i])) { + outputs.set(std::make_pair(PlanStageSlots::kField, fieldNames[i]), finalSlots[i]); + } + }; + + auto [outStage, nothingSlots] = projectNothingToSlots( + std::move(stage), additionalFields.size(), root->nodeId(), &_slotIdGenerator); + for (size_t i = 0; i < additionalFields.size(); ++i) { + outputs.set(std::make_pair(PlanStageSlots::kField, std::move(additionalFields[i])), + nothingSlots[i]); + } + + // Builds a outStage to create a result object out of a group-by slot and gathered accumulator + // result slots if the parent node requests so. if (reqs.has(kResult)) { - outputs.set(kResult, _slotIdGenerator.generate()); + auto resultSlot = _slotIdGenerator.generate(); + outputs.set(kResult, resultSlot); // This mkbson stage combines 'finalSlots' into a bsonObject result slot which has // 'fieldNames' fields. if (groupNode->shouldProduceBson) { - outStage = sbe::makeS<sbe::MakeBsonObjStage>(groupFinalEvalStage.extractStage(nodeId), - outputs.get(kResult), // objSlot - boost::none, // rootSlot - boost::none, // fieldBehavior + outStage = sbe::makeS<sbe::MakeBsonObjStage>(std::move(outStage), + resultSlot, // objSlot + boost::none, // rootSlot + boost::none, // fieldBehavior std::vector<std::string>{}, // fields std::move(fieldNames), // projectFields std::move(finalSlots), // projectVars @@ -2764,8 +2558,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder false, // returnOldObject nodeId); } else { - outStage = sbe::makeS<sbe::MakeObjStage>(groupFinalEvalStage.extractStage(nodeId), - outputs.get(kResult), // objSlot + outStage = sbe::makeS<sbe::MakeObjStage>(std::move(outStage), + resultSlot, // objSlot boost::none, // rootSlot boost::none, // fieldBehavior std::vector<std::string>{}, // fields @@ -2775,12 +2569,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder false, // returnOldObject nodeId); } - } else { - for (size_t i = 0; i < finalSlots.size(); ++i) { - outputs.set("CURRENT." + fieldNames[i], finalSlots[i]); - }; - - outStage = groupFinalEvalStage.extractStage(nodeId); } return {std::move(outStage), std::move(outputs)}; @@ -2790,7 +2578,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::makeUnionForTailableCollScan(const QuerySolutionNode* root, const PlanStageReqs& reqs) { using namespace std::literals; - invariant(!reqs.getIndexKeyBitset()); + tassert(6023415, "makeUnionForTailableCollScan() does not support kKey", !reqs.hasKeys()); // Register a SlotId in the global environment which would contain a recordId to resume a // tailable collection scan from. 
A PlanStage executor will track the last seen recordId and @@ -2814,7 +2602,7 @@ SlotBasedStageBuilder::makeUnionForTailableCollScan(const QuerySolutionNode* roo childReqs.setIsTailableCollScanResumeBranch(isTailableCollScanResumeBranch); auto [branch, outputs] = build(root, childReqs); - auto branchSlots = getSlotsToForward(reqs, outputs); + auto branchSlots = getSlotsToForward(childReqs, outputs); return {std::move(branchSlots), std::move(branch)}; }; @@ -2877,57 +2665,50 @@ auto buildShardFilterGivenShardKeySlot(sbe::value::SlotId shardKeySlot, } // namespace std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> -SlotBasedStageBuilder::buildShardFilterCovered(const ShardingFilterNode* filterNode, - sbe::value::SlotId shardFiltererSlot, - BSONObj shardKeyPattern, - BSONObj indexKeyPattern, - const QuerySolutionNode* child, - PlanStageReqs childReqs) { - StringDataSet shardKeyFields; - for (auto&& shardKeyElt : shardKeyPattern) { - shardKeyFields.insert(shardKeyElt.fieldNameStringData()); - } +SlotBasedStageBuilder::buildShardFilterCovered(const QuerySolutionNode* root, + const PlanStageReqs& reqs) { + // Constructs an optimized SBE plan for 'filterNode' in the case that the fields of the + // 'shardKeyPattern' are provided by 'child'. In this case, the SBE tree for 'child' will + // fill out slots for the necessary components of the index key. These slots can be read + // directly in order to determine the shard key that should be passed to the + // 'shardFiltererSlot'. + const auto filterNode = static_cast<const ShardingFilterNode*>(root); + auto child = filterNode->children[0].get(); + tassert(6023416, + "buildShardFilterCovered() expects ixscan below shard filter", + child->getType() == STAGE_IXSCAN || child->getType() == STAGE_VIRTUAL_SCAN); - // Save the bit vector describing the fields from the index that our parent requires. The shard - // filtering process may require additional fields that are not needed by the parent (for - // example, if the parent is projecting field "a" but the shard key is {a: 1, b: 1}). We will - // need the parent's reqs later on so that we can hand the correct slot vector for these fields - // back to our parent. - auto parentIndexKeyReqs = childReqs.getIndexKeyBitset(); + // Extract the child's key pattern. + BSONObj indexKeyPattern = child->getType() == STAGE_IXSCAN + ? static_cast<const IndexScanNode*>(child)->index.keyPattern + : static_cast<const VirtualScanNode*>(child)->indexKeyPattern; - // Determine the set of fields from the index required to obtain the shard key and union those - // with the set of fields from the index required by the parent stage. - auto [shardKeyIndexReqs, _] = makeIndexKeyInclusionSet(indexKeyPattern, shardKeyFields); - const auto ixKeyBitset = - parentIndexKeyReqs.value_or(sbe::IndexKeysInclusionSet{}) | shardKeyIndexReqs; - childReqs.getIndexKeyBitset() = ixKeyBitset; + auto childReqs = reqs.copy(); + + // If we're sharded make sure that we don't return data that isn't owned by the shard. This + // situation can occur when pending documents from in-progress migrations are inserted and when + // there are orphaned documents from aborted migrations. To check if the document is owned by + // the shard, we need to own a 'ShardFilterer', and extract the document's shard key as a + // BSONObj. 
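Aside (editor's note): just below, a "shardFilterer" slot is registered but deliberately left empty until the tree is prepared for execution. That register-now, construct-later pattern can be modeled with a tiny placeholder registry; RuntimeEnvironmentSketch is an invented stand-in, not the real sbe::RuntimeEnvironment:

#include <iostream>
#include <map>
#include <optional>
#include <string>
#include <utility>

// Invented stand-in for a runtime environment keyed by slot name.
struct RuntimeEnvironmentSketch {
    std::map<std::string, std::optional<std::string>> slots;

    // Build time: reserve the slot but leave it empty ("Nothing").
    void registerSlot(const std::string& name) {
        slots.emplace(name, std::nullopt);
    }

    // Prepare time: construct the expensive / side-effecting object and bind it.
    void bind(const std::string& name, std::string value) {
        slots.at(name) = std::move(value);
    }
};

int main() {
    RuntimeEnvironmentSketch env;
    env.registerSlot("shardFilterer");  // nothing is constructed yet
    // ... the SBE tree is built and possibly cached here ...
    env.bind("shardFilterer", "ShardFilterer for key pattern {a: 1}");  // just before execution
    std::cout << *env.slots.at("shardFilterer") << "\n";
}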
+ auto shardKeyPattern = _collections.getMainCollection().getShardKeyPattern(); + // We register the "shardFilterer" slot but we don't construct the ShardFilterer here, because + // once constructed the ShardFilterer will prevent orphaned documents from being deleted. We + // will construct the ShardFilterer later while preparing the SBE tree for execution. + auto shardFiltererSlot = _data.env->registerSlot( + "shardFilterer"_sd, sbe::value::TypeTags::Nothing, 0, false, &_slotIdGenerator); + + for (auto&& shardKeyElt : shardKeyPattern) { + childReqs.set(std::make_pair(PlanStageSlots::kKey, shardKeyElt.fieldNameStringData())); + } auto [stage, outputs] = build(child, childReqs); - tassert(5562302, "Expected child to produce index key slots", outputs.getIndexKeySlots()); - - // Maps from key name -> (index in outputs.getIndexKeySlots(), is hashed). - auto ixKeyPatternFieldToSlotIdx = [&, childOutputs = std::ref(outputs)]() { - StringDataMap<std::pair<sbe::value::SlotId, bool>> ret; - - // Keeps track of which component we're reading in the index key pattern. - size_t i = 0; - // Keeps track of the index we are in the slot vector produced by the ix scan. The slot - // vector produced by the ix scan may be a subset of the key pattern. - size_t slotIdx = 0; - for (auto&& ixPatternElt : indexKeyPattern) { - if (shardKeyFields.count(ixPatternElt.fieldNameStringData())) { - const bool isHashed = ixPatternElt.valueStringData() == IndexNames::HASHED; - const auto slotId = (*childOutputs.get().getIndexKeySlots())[slotIdx]; - ret.emplace(ixPatternElt.fieldNameStringData(), std::make_pair(slotId, isHashed)); - } - if (ixKeyBitset[i]) { - ++slotIdx; - } - ++i; - } - return ret; - }(); + // Maps from key name to a bool that indicates whether the key is hashed. + StringDataMap<bool> indexKeyPatternMap; + for (auto&& ixPatternElt : indexKeyPattern) { + indexKeyPatternMap.emplace(ixPatternElt.fieldNameStringData(), + ShardKeyPattern::isHashedPatternEl(ixPatternElt)); + } // Build a project stage to deal with hashed shard keys. This step *could* be skipped if we're // dealing with non-hashed sharding, but it's done this way for sake of simplicity. @@ -2935,33 +2716,28 @@ SlotBasedStageBuilder::buildShardFilterCovered(const ShardingFilterNode* filterN sbe::value::SlotVector fieldSlots; std::vector<std::string> projectFields; for (auto&& shardKeyPatternElt : shardKeyPattern) { - auto it = ixKeyPatternFieldToSlotIdx.find(shardKeyPatternElt.fieldNameStringData()); - tassert(5562303, "Could not find element", it != ixKeyPatternFieldToSlotIdx.end()); - auto [slotId, ixKeyEltHashed] = it->second; + auto it = indexKeyPatternMap.find(shardKeyPatternElt.fieldNameStringData()); + tassert(5562303, "Could not find element", it != indexKeyPatternMap.end()); + const auto ixKeyEltHashed = it->second; + const auto slotId = outputs.get( + std::make_pair(PlanStageSlots::kKey, shardKeyPatternElt.fieldNameStringData())); // Get the value stored in the index for this component of the shard key. We may have to // hash it. auto elem = makeVariable(slotId); + // Handle the case where the index key or shard key is hashed. const bool shardKeyEltHashed = ShardKeyPattern::isHashedPatternEl(shardKeyPatternElt); - - if (shardKeyEltHashed) { - if (ixKeyEltHashed) { - // (1) The index stores hashed data and the shard key field is hashed. - // Nothing to do here. We can apply shard filtering with no other changes. - } else { - // (2) The shard key field is hashed but the index stores unhashed data. 
We must - // apply the hash function before passing this off to the shard filter. - elem = makeFunction("shardHash"_sd, std::move(elem)); - } - } else { - if (ixKeyEltHashed) { - // (3) The index stores hashed data but the shard key is not hashed. This is a bug. - MONGO_UNREACHABLE_TASSERT(5562300); - } else { - // (4) The shard key field is not hashed, and the index does not store hashed data. - // Again, we do nothing here. - } + if (ixKeyEltHashed) { + // If the index stores hashed data, then we know the shard key field is hashed as + // well. Nothing to do here. We can apply shard filtering with no other changes. + tassert(6023421, + "Index key is hashed, expected corresponding shard key to be hashed", + shardKeyEltHashed); + } else if (shardKeyEltHashed) { + // The shard key field is hashed but the index stores unhashed data. We must apply + // the hash function before passing this off to the shard filter. + elem = makeFunction("shardHash"_sd, std::move(elem)); } fieldSlots.push_back(_slotIdGenerator.generate()); @@ -2987,46 +2763,17 @@ SlotBasedStageBuilder::buildShardFilterCovered(const ShardingFilterNode* filterN auto filterStage = buildShardFilterGivenShardKeySlot( shardKeySlot, std::move(mkObjStage), shardFiltererSlot, filterNode->nodeId()); - outputs.setIndexKeySlots(!parentIndexKeyReqs ? boost::none - : boost::optional<sbe::value::SlotVector>{ - makeIndexKeyOutputSlotsMatchingParentReqs( - indexKeyPattern, - *parentIndexKeyReqs, - *childReqs.getIndexKeyBitset(), - *outputs.getIndexKeySlots())}); + outputs.clearNonRequiredSlots(reqs); return {std::move(filterStage), std::move(outputs)}; } std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildShardFilter( const QuerySolutionNode* root, const PlanStageReqs& reqs) { - const auto filterNode = static_cast<const ShardingFilterNode*>(root); - - // If we're sharded make sure that we don't return data that isn't owned by the shard. This - // situation can occur when pending documents from in-progress migrations are inserted and when - // there are orphaned documents from aborted migrations. To check if the document is owned by - // the shard, we need to own a 'ShardFilterer', and extract the document's shard key as a - // BSONObj. - auto shardKeyPattern = _collections.getMainCollection().getShardKeyPattern(); - // We register the "shardFilterer" slot but not construct the ShardFilterer here is because once - // constructed the ShardFilterer will prevent orphaned documents from being deleted. We will - // construct the 'ShardFiltered' later while preparing the SBE tree for execution. - auto shardFiltererSlot = _data.env->registerSlot( - "shardFilterer"_sd, sbe::value::TypeTags::Nothing, 0, false, &_slotIdGenerator); - - // Determine if our child is an index scan and extract it's key pattern, or empty BSONObj if our - // child is not an IXSCAN node. 
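Aside (editor's note): the hashed-key handling in the hunk above reduces to one decision per shard key component: a hashed index entry can only back a hashed shard key component (nothing extra to do), while a hashed shard key over an unhashed index entry requires hashing the raw value first. A self-contained sketch of that decision follows; fakeShardHash() is an invented placeholder, not the server's shardHash builtin:

#include <cassert>
#include <functional>
#include <iostream>
#include <string>

// Invented placeholder for the hash the server applies to shard key values.
std::string fakeShardHash(const std::string& rawValue) {
    return "hashed(" + std::to_string(std::hash<std::string>{}(rawValue)) + ")";
}

// Picks the value to hand to the shard filter for one shard key component.
std::string shardKeyComponent(bool indexStoresHashed,
                              bool shardKeyIsHashed,
                              const std::string& indexValue) {
    if (indexStoresHashed) {
        // A hashed index entry can only back a hashed shard key component; the
        // combination "hashed index, unhashed shard key" is invalid.
        assert(shardKeyIsHashed);
        return indexValue;                 // already in the right form
    }
    if (shardKeyIsHashed) {
        return fakeShardHash(indexValue);  // raw index value must be hashed first
    }
    return indexValue;                     // plain key over a plain index: use as-is
}

int main() {
    std::cout << shardKeyComponent(false, true, "abc") << "\n";    // hashes the raw value
    std::cout << shardKeyComponent(true, true, "h(abc)") << "\n";  // passes it through
}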
- BSONObj indexKeyPattern = [&]() { - auto childNode = filterNode->children[0].get(); - switch (childNode->getType()) { - case StageType::STAGE_IXSCAN: - return static_cast<const IndexScanNode*>(childNode)->index.keyPattern; - case StageType::STAGE_VIRTUAL_SCAN: - return static_cast<const VirtualScanNode*>(childNode)->indexKeyPattern; - default: - return BSONObj{}; - } - }(); + auto child = root->children[0].get(); + bool childIsIndexScan = child->getType() == STAGE_IXSCAN || + (child->getType() == STAGE_VIRTUAL_SCAN && + !static_cast<const VirtualScanNode*>(child)->indexKeyPattern.isEmpty()); // If we're not required to fill out the 'kResult' slot, then instead we can request a slot from // the child for each of the fields which constitute the shard key. This allows us to avoid @@ -3036,17 +2783,25 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder // We only apply this optimization in the special case that the child QSN is an IXSCAN, since in // this case we can request exactly the fields we need according to their position in the index // key pattern. - auto childReqs = reqs.copy().setIf(kResult, indexKeyPattern.isEmpty()); - if (!childReqs.has(kResult)) { - return buildShardFilterCovered(filterNode, - shardFiltererSlot, - shardKeyPattern, - std::move(indexKeyPattern), - filterNode->children[0].get(), - std::move(childReqs)); + if (!reqs.has(kResult) && childIsIndexScan) { + return buildShardFilterCovered(root, reqs); } - auto [stage, outputs] = build(filterNode->children[0].get(), childReqs); + auto childReqs = reqs.copy().set(kResult); + + // If we're sharded make sure that we don't return data that isn't owned by the shard. This + // situation can occur when pending documents from in-progress migrations are inserted and when + // there are orphaned documents from aborted migrations. To check if the document is owned by + // the shard, we need to own a 'ShardFilterer', and extract the document's shard key as a + // BSONObj. + auto shardKeyPattern = _collections.getMainCollection().getShardKeyPattern(); + // We register the "shardFilterer" slot but we don't construct the ShardFilterer here, because + // once constructed the ShardFilterer will prevent orphaned documents from being deleted. We + // will construct the ShardFilterer later while preparing the SBE tree for execution. + auto shardFiltererSlot = _data.env->registerSlot( + "shardFilterer"_sd, sbe::value::TypeTags::Nothing, 0, false, &_slotIdGenerator); + + auto [stage, outputs] = build(child, childReqs); // Build an expression to extract the shard key from the document based on the shard key // pattern. 
To do this, we iterate over the shard key pattern parts and build nested 'getField' @@ -3136,7 +2891,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder {STAGE_SKIP, &SlotBasedStageBuilder::buildSkip}, {STAGE_SORT_SIMPLE, &SlotBasedStageBuilder::buildSort}, {STAGE_SORT_DEFAULT, &SlotBasedStageBuilder::buildSort}, - {STAGE_SORT_KEY_GENERATOR, &SlotBasedStageBuilder::buildSortKeyGeneraror}, + {STAGE_SORT_KEY_GENERATOR, &SlotBasedStageBuilder::buildSortKeyGenerator}, {STAGE_PROJECTION_SIMPLE, &SlotBasedStageBuilder::buildProjectionSimple}, {STAGE_PROJECTION_DEFAULT, &SlotBasedStageBuilder::buildProjectionDefault}, {STAGE_PROJECTION_COVERED, &SlotBasedStageBuilder::buildProjectionCovered}, @@ -3178,6 +2933,29 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder break; } - return std::invoke(kStageBuilders.at(root->getType()), *this, root, reqs); + auto [stage, slots] = std::invoke(kStageBuilders.at(root->getType()), *this, root, reqs); + auto outputs = std::move(slots); + + auto fields = filterVector(reqs.getFields(), [&](const std::string& s) { + return !outputs.has(std::make_pair(PlanStageSlots::kField, StringData(s))); + }); + + if (!fields.empty()) { + tassert(6023424, + str::stream() << "Expected build() for " << stageTypeToString(root->getType()) + << " to either produce a kResult slot or to satisfy all kField reqs", + outputs.has(PlanStageSlots::kResult)); + + auto resultSlot = outputs.get(PlanStageSlots::kResult); + auto [outStage, slots] = projectTopLevelFields( + std::move(stage), fields, resultSlot, root->nodeId(), &_slotIdGenerator); + stage = std::move(outStage); + + for (size_t i = 0; i < fields.size(); ++i) { + outputs.set(std::make_pair(PlanStageSlots::kField, std::move(fields[i])), slots[i]); + } + } + + return {std::move(stage), std::move(outputs)}; } } // namespace mongo::stage_builder diff --git a/src/mongo/db/query/sbe_stage_builder.h b/src/mongo/db/query/sbe_stage_builder.h index 14bffbc1aa3..081660c43b1 100644 --- a/src/mongo/db/query/sbe_stage_builder.h +++ b/src/mongo/db/query/sbe_stage_builder.h @@ -40,6 +40,7 @@ #include "mongo/db/query/sbe_stage_builder_helpers.h" #include "mongo/db/query/shard_filterer_factory_interface.h" #include "mongo/db/query/stage_builder.h" +#include "mongo/util/pair_map.h" namespace mongo::stage_builder { /** @@ -51,6 +52,12 @@ std::unique_ptr<sbe::RuntimeEnvironment> makeRuntimeEnvironment( OperationContext* opCtx, sbe::value::SlotIdGenerator* slotIdGenerator); +class PlanStageReqs; +class PlanStageSlots; +sbe::value::SlotVector getSlotsToForward(const PlanStageReqs& reqs, + const PlanStageSlots& outputs, + const sbe::value::SlotVector& exclude = sbe::makeSV()); + /** * This function prepares the SBE tree for execution, such as attaching the OperationContext, * ensuring that the SBE tree is registered with the PlanYieldPolicySBE and populating the @@ -105,78 +112,86 @@ struct ParameterizedIndexScanSlots { */ class PlanStageSlots { public: - static constexpr StringData kResult = "result"_sd; - static constexpr StringData kRecordId = "recordId"_sd; - static constexpr StringData kReturnKey = "returnKey"_sd; - static constexpr StringData kSnapshotId = "snapshotId"_sd; - static constexpr StringData kIndexId = "indexId"_sd; - static constexpr StringData kIndexKey = "indexKey"_sd; - static constexpr StringData kIndexKeyPattern = "indexKeyPattern"_sd; + // The _slots map is capable of holding different "classes" of slots: + // 1) kMeta slots are used to hold the current document 
(kResult), record ID (kRecordId), and + // various pieces of metadata. + // 2) kField slots represent the values of top-level fields, or in some cases of dotted field + // paths (when we are getting the dotted field from a non-multikey index and we know no array + // traversal is needed). These slots hold the actual values of the fields / field paths (not + // the sort key or collation comparison key for the field). + // 3) kKey slots represent the raw key value that comes from an ixscan / ixseek stage for a + // given field path. This raw key value can be used for sorting / comparison, but it is not + // always equal to the actual value of the field path (for example, if the key is coming from + // an index that has a non-simple collation). + enum class Type { + kMeta, + kField, + kKey, + }; + + using Name = std::pair<Type, StringData>; + using OwnedName = std::pair<Type, std::string>; + + static constexpr auto kField = Type::kField; + static constexpr auto kKey = Type::kKey; + static constexpr auto kMeta = Type::kMeta; + + static constexpr Name kResult = {kMeta, "result"_sd}; + static constexpr Name kRecordId = {kMeta, "recordId"_sd}; + static constexpr Name kReturnKey = {kMeta, "returnKey"_sd}; + static constexpr Name kSnapshotId = {kMeta, "snapshotId"_sd}; + static constexpr Name kIndexId = {kMeta, "indexId"_sd}; + static constexpr Name kIndexKey = {kMeta, "indexKey"_sd}; + static constexpr Name kIndexKeyPattern = {kMeta, "indexKeyPattern"_sd}; PlanStageSlots() = default; PlanStageSlots(const PlanStageReqs& reqs, sbe::value::SlotIdGenerator* slotIdGenerator); - bool has(StringData str) const { + bool has(const Name& str) const { return _slots.count(str); } - sbe::value::SlotId get(StringData str) const { + sbe::value::SlotId get(const Name& str) const { auto it = _slots.find(str); invariant(it != _slots.end()); return it->second; } - boost::optional<sbe::value::SlotId> getIfExists(StringData str) const { + boost::optional<sbe::value::SlotId> getIfExists(const Name& str) const { if (auto it = _slots.find(str); it != _slots.end()) { return it->second; } return boost::none; } - void set(StringData str, sbe::value::SlotId slot) { - _slots[str] = slot; - } - - void clear(StringData str) { - _slots.erase(str); + void set(const Name& str, sbe::value::SlotId slot) { + _slots.insert_or_assign(str, slot); } - const boost::optional<sbe::value::SlotVector>& getIndexKeySlots() const { - return _indexKeySlots; + void set(OwnedName str, sbe::value::SlotId slot) { + _slots.insert_or_assign(std::move(str), slot); } - boost::optional<sbe::value::SlotVector> extractIndexKeySlots() { - ON_BLOCK_EXIT([this] { _indexKeySlots = boost::none; }); - return std::move(_indexKeySlots); - } - - void setIndexKeySlots(sbe::value::SlotVector iks) { - _indexKeySlots = std::move(iks); - } - - void setIndexKeySlots(boost::optional<sbe::value::SlotVector> iks) { - _indexKeySlots = std::move(iks); + void clear(const Name& str) { + _slots.erase(str); } /** - * This method applies an action to some/all of the slots within this struct (excluding index - * key slots). For each slot in this struct, the action is will be applied to the slot if (and - * only if) the corresponding flag in 'reqs' is true. + * This method applies an action to some/all of the slots within this struct. For each slot in + * this struct, the action is will be applied to the slot if (and only if) the corresponding + * flag in 'reqs' is true. 
*/ inline void forEachSlot(const PlanStageReqs& reqs, const std::function<void(sbe::value::SlotId)>& fn) const; - inline void forEachSlot( - const PlanStageReqs& reqs, - const std::function<void(sbe::value::SlotId, const StringData&)>& fn) const; + inline void forEachSlot(const PlanStageReqs& reqs, + const std::function<void(sbe::value::SlotId, const Name&)>& fn) const; -private: - StringMap<sbe::value::SlotId> _slots; + inline void clearNonRequiredSlots(const PlanStageReqs& reqs); - // When an index scan produces parts of an index key for a covered plan, this is where the - // slots for the produced values are stored. - boost::optional<sbe::value::SlotVector> _indexKeySlots; +private: + PairMap<Type, std::string, sbe::value::SlotId> _slots; }; /** @@ -185,38 +200,71 @@ private: */ class PlanStageReqs { public: + using Type = PlanStageSlots::Type; + using Name = PlanStageSlots::Name; + using OwnedName = std::pair<Type, std::string>; + + static constexpr auto kField = PlanStageSlots::Type::kField; + static constexpr auto kKey = PlanStageSlots::Type::kKey; + static constexpr auto kMeta = PlanStageSlots::Type::kMeta; + PlanStageReqs copy() const { return *this; } - bool has(StringData str) const { + bool has(const Name& str) const { auto it = _slots.find(str); return it != _slots.end() && it->second; } - PlanStageReqs& set(StringData str) { - _slots[str] = true; + PlanStageReqs& set(const Name& str) { + _slots.insert_or_assign(str, true); + return *this; + } + + PlanStageReqs& set(OwnedName str) { + _slots.insert_or_assign(std::move(str), true); + return *this; + } + + PlanStageReqs& set(const std::vector<Name>& strs) { + for (size_t i = 0; i < strs.size(); ++i) { + _slots.insert_or_assign(strs[i], true); + } + return *this; + } + + PlanStageReqs& set(std::vector<OwnedName> strs) { + for (size_t i = 0; i < strs.size(); ++i) { + _slots.insert_or_assign(std::move(strs[i]), true); + } return *this; } - PlanStageReqs& setIf(StringData str, bool condition) { + PlanStageReqs& setIf(const Name& str, bool condition) { if (condition) { - _slots[str] = true; + _slots.insert_or_assign(str, true); } return *this; } - PlanStageReqs& clear(StringData str) { - _slots.erase(str); + PlanStageReqs& setFields(std::vector<std::string> strs) { + for (size_t i = 0; i < strs.size(); ++i) { + _slots.insert_or_assign(std::make_pair(kField, std::move(strs[i])), true); + } return *this; } - boost::optional<sbe::IndexKeysInclusionSet>& getIndexKeyBitset() { - return _indexKeyBitset; + PlanStageReqs& setKeys(std::vector<std::string> strs) { + for (size_t i = 0; i < strs.size(); ++i) { + _slots.insert_or_assign(std::make_pair(kKey, std::move(strs[i])), true); + } + return *this; } - const boost::optional<sbe::IndexKeysInclusionSet>& getIndexKeyBitset() const { - return _indexKeyBitset; + PlanStageReqs& clear(const Name& str) { + _slots.erase(str); + return *this; } bool getIsBuildingUnionForTailableCollScan() const { @@ -243,6 +291,52 @@ public: return _targetNamespace; } + bool hasType(Type t) const { + for (auto&& [name, isRequired] : _slots) { + if (isRequired && name.first == t) { + return true; + } + } + return false; + } + bool hasFields() const { + return hasType(kField); + } + bool hasKeys() const { + return hasType(kKey); + } + + std::vector<std::string> getOfType(Type t) const { + std::vector<std::string> res; + for (auto&& [name, isRequired] : _slots) { + if (isRequired && name.first == t) { + res.push_back(name.second); + } + } + std::sort(res.begin(), res.end()); + return res; + } + std::vector<std::string> 
getFields() const { + return getOfType(kField); + } + std::vector<std::string> getKeys() const { + return getOfType(kKey); + } + + PlanStageReqs& clearAllOfType(Type t) { + auto names = getOfType(t); + for (const auto& name : names) { + _slots.erase(std::make_pair(t, StringData(name))); + } + return *this; + } + PlanStageReqs& clearAllFields() { + return clearAllOfType(kField); + } + PlanStageReqs& clearAllKeys() { + return clearAllOfType(kKey); + } + friend PlanStageSlots::PlanStageSlots(const PlanStageReqs& reqs, sbe::value::SlotIdGenerator* slotIdGenerator); @@ -251,14 +345,12 @@ public: friend void PlanStageSlots::forEachSlot( const PlanStageReqs& reqs, - const std::function<void(sbe::value::SlotId, const StringData&)>& fn) const; + const std::function<void(sbe::value::SlotId, const Name&)>& fn) const; -private: - StringMap<bool> _slots; + friend void PlanStageSlots::clearNonRequiredSlots(const PlanStageReqs& reqs); - // A bitset here indicates that we have a covered projection that is expecting to parts of the - // index key from an index scan. - boost::optional<sbe::IndexKeysInclusionSet> _indexKeyBitset; +private: + PairMap<Type, std::string, bool> _slots; // When we're in the middle of building a special union sub-tree implementing a tailable cursor // collection scan, this flag will be set to true. Otherwise this flag will be false. @@ -283,11 +375,11 @@ void PlanStageSlots::forEachSlot(const PlanStageReqs& reqs, // Clang raises an error if we attempt to use 'name' in the tassert() below, because // tassert() is a macro that uses lambdas and 'name' is defined via "local binding". // We work-around this by copying 'name' to a local variable 'slotName'. - auto slotName = StringData(name); + auto slotName = Name(name); auto it = _slots.find(slotName); tassert(7050900, - str::stream() << "Could not find '" << slotName - << "' slot in the map, expected slot to exist", + str::stream() << "Could not find " << static_cast<int>(slotName.first) << ":'" + << slotName.second << "' in the slot map, expected slot to exist", it != _slots.end()); fn(it->second); @@ -297,17 +389,17 @@ void PlanStageSlots::forEachSlot( const PlanStageReqs& reqs, - const std::function<void(sbe::value::SlotId, const StringData&)>& fn) const { + const std::function<void(sbe::value::SlotId, const Name&)>& fn) const { for (auto&& [name, isRequired] : reqs._slots) { if (isRequired) { // Clang raises an error if we attempt to use 'name' in the tassert() below, because // tassert() is a macro that uses lambdas and 'name' is defined via "local binding". // We work-around this by copying 'name' to a local variable 'slotName'. - auto slotName = StringData(name); + auto slotName = Name(name); auto it = _slots.find(slotName); tassert(7050901, - str::stream() << "Could not find '" << slotName - << "' slot in the map, expected slot to exist", + str::stream() << "Could not find " << static_cast<int>(slotName.first) << ":'" + << slotName.second << "' in the slot map, expected slot to exist", it != _slots.end()); fn(it->second, slotName); @@ -315,6 +407,21 @@ void PlanStageSlots::forEachSlot( } } +void PlanStageSlots::clearNonRequiredSlots(const PlanStageReqs& reqs) { + auto it = _slots.begin(); + while (it != _slots.end()) { + auto& name = it->first; + auto reqIt = reqs._slots.find(name); + // We never clear kResult, regardless of whether it is required by 'reqs'.
+ if ((reqIt != reqs._slots.end() && reqIt->second) || + (name.first == kResult.first && name.second == kResult.second)) { + ++it; + } else { + _slots.erase(it++); + } + } +} + using InputParamToSlotMap = stdx::unordered_map<MatchExpression::InputParamId, sbe::value::SlotId>; using VariableIdToSlotMap = stdx::unordered_map<Variables::Id, sbe::value::SlotId>; @@ -443,13 +550,13 @@ private: */ class SlotBasedStageBuilder final : public StageBuilder<sbe::PlanStage> { public: - static constexpr StringData kResult = PlanStageSlots::kResult; - static constexpr StringData kRecordId = PlanStageSlots::kRecordId; - static constexpr StringData kReturnKey = PlanStageSlots::kReturnKey; - static constexpr StringData kSnapshotId = PlanStageSlots::kSnapshotId; - static constexpr StringData kIndexId = PlanStageSlots::kIndexId; - static constexpr StringData kIndexKey = PlanStageSlots::kIndexKey; - static constexpr StringData kIndexKeyPattern = PlanStageSlots::kIndexKeyPattern; + static constexpr auto kResult = PlanStageSlots::kResult; + static constexpr auto kRecordId = PlanStageSlots::kRecordId; + static constexpr auto kReturnKey = PlanStageSlots::kReturnKey; + static constexpr auto kSnapshotId = PlanStageSlots::kSnapshotId; + static constexpr auto kIndexId = PlanStageSlots::kIndexId; + static constexpr auto kIndexKey = PlanStageSlots::kIndexKey; + static constexpr auto kIndexKeyPattern = PlanStageSlots::kIndexKeyPattern; SlotBasedStageBuilder(OperationContext* opCtx, const MultipleCollectionAccessor& collections, @@ -457,6 +564,12 @@ public: const QuerySolution& solution, PlanYieldPolicySBE* yieldPolicy); + /** + * This method will build an SBE PlanStage tree for QuerySolutionNode 'root' and its + * descendants. + * + * This method is a wrapper around 'build(const QuerySolutionNode*, const PlanStageReqs&)'. + */ std::unique_ptr<sbe::PlanStage> build(const QuerySolutionNode* root) final; PlanStageData getPlanStageData() { @@ -464,6 +577,14 @@ public: } private: + /** + * This method will build an SBE PlanStage tree for QuerySolutionNode 'root' and its + * descendants. + * + * Based on the type of 'root', this method will dispatch to the appropriate buildXXX() method. + * This method will also handle generating calls to getField() to satisfy kField reqs that were + * not satisfied by the buildXXX() method. + */ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> build(const QuerySolutionNode* node, const PlanStageReqs& reqs); @@ -494,7 +615,7 @@ private: std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildSortCovered( const QuerySolutionNode* root, const PlanStageReqs& reqs); - std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildSortKeyGeneraror( + std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildSortKeyGenerator( const QuerySolutionNode* root, const PlanStageReqs& reqs); std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildSortMerge( @@ -539,19 +660,14 @@ private: const QuerySolutionNode* root, const PlanStageReqs& reqs); /** - * Constructs an optimized SBE plan for 'filterNode' in the case that the fields of the - * 'shardKeyPattern' are provided by 'childIxscan'. In this case, the SBE plan for the child + * Constructs an optimized SBE plan for 'root' in the case that the fields of the shard key + * pattern are provided by the child index scan. In this case, the SBE plan for the child + * index scan node will fill out slots for the necessary components of the index key.
These * slots can be read directly in order to determine the shard key that should be passed to the * 'shardFiltererSlot'. */ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildShardFilterCovered( - const ShardingFilterNode* filterNode, - sbe::value::SlotId shardFiltererSlot, - BSONObj shardKeyPattern, - BSONObj indexKeyPattern, - const QuerySolutionNode* child, - PlanStageReqs childReqs); + const QuerySolutionNode* root, const PlanStageReqs& reqs); std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildGroup( const QuerySolutionNode* root, const PlanStageReqs& reqs); diff --git a/src/mongo/db/query/sbe_stage_builder_coll_scan.cpp b/src/mongo/db/query/sbe_stage_builder_coll_scan.cpp index c44c088ac87..34a6b8a00a6 100644 --- a/src/mongo/db/query/sbe_stage_builder_coll_scan.cpp +++ b/src/mongo/db/query/sbe_stage_builder_coll_scan.cpp @@ -158,7 +158,7 @@ std::unique_ptr<sbe::PlanStage> buildResumeFromRecordIdSubtree( std::move(seekRecordIdExpression)); // Construct a 'seek' branch of the 'union'. If we're succeeded to reposition the cursor, - // the branch will output the 'seekSlot' to start the real scan from, otherwise it will + // the branch will output the 'seekSlot' to start the real scan from, otherwise it will // produce EOF. auto seekBranch = sbe::makeS<sbe::LoopJoinStage>(std::move(projStage), @@ -232,7 +232,7 @@ std::unique_ptr<sbe::PlanStage> buildResumeFromRecordIdSubtree( * Creates a collection scan sub-tree optimized for oplog scans. We can built an optimized scan * when any of the following scenarios apply: * - * 1. There is a predicted on the 'ts' field of the oplog collection. + * 1. There is a predicate on the 'ts' field of the oplog collection. * 1.1 If a lower bound on 'ts' is present, the collection scan will seek directly to the * RecordId of an oplog entry as close to this lower bound as possible without going higher. * 1.2 If the query is *only* a lower bound on 'ts' on a forward scan, every document in the @@ -248,6 +248,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo StageBuilderState& state, const CollectionPtr& collection, const CollectionScanNode* csn, + const std::vector<std::string>& fields, PlanYieldPolicy* yieldPolicy, bool isTailableResumeBranch) { invariant(collection->ns().isOplog()); @@ -258,6 +259,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo // Oplog scan optimizations can only be done for a forward scan. invariant(csn->direction == CollectionScanParams::FORWARD); + auto fieldSlots = state.slotIdGenerator->generateMultiple(fields.size()); + auto resultSlot = state.slotId(); auto recordIdSlot = state.slotId(); @@ -290,9 +293,16 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo // of if we need to track the latest oplog timestamp. 
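Aside (editor's note): case 1.1 in the comment above (seek directly to the oplog entry closest to the 'ts' lower bound rather than scanning from the start) is conceptually a lower-bound seek over records ordered by timestamp. A simplified in-memory model follows; the real plan seeks by a RecordId bound derived from the predicate, and OplogEntry here is invented for the example:

#include <algorithm>
#include <iostream>
#include <vector>

struct OplogEntry {
    long recordId;
    int ts;  // timestamps are monotonically non-decreasing in the oplog
};

int main() {
    std::vector<OplogEntry> oplog{{1, 3}, {2, 7}, {3, 9}, {4, 15}, {5, 21}};
    const int minTs = 9;  // lower bound extracted from the 'ts' predicate

    // Seek straight to the first entry whose ts satisfies the lower bound instead
    // of scanning and filtering from the beginning of the collection.
    auto it = std::lower_bound(oplog.begin(), oplog.end(), minTs,
                               [](const OplogEntry& e, int ts) { return e.ts < ts; });

    for (; it != oplog.end(); ++it) {
        std::cout << "recordId=" << it->recordId << " ts=" << it->ts << "\n";
    }
}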
const auto shouldTrackLatestOplogTimestamp = (csn->maxRecord || csn->shouldTrackLatestOplogTimestamp); - auto&& [fields, slots, tsSlot] = makeOplogTimestampSlotsIfNeeded( + auto&& [scanFields, scanFieldSlots, tsSlot] = makeOplogTimestampSlotsIfNeeded( state.data->env, state.slotIdGenerator, shouldTrackLatestOplogTimestamp); + bool createScanWithAndWithoutFilter = (csn->filter && csn->stopApplyingFilterAfterFirstMatch); + + if (!createScanWithAndWithoutFilter) { + scanFields.insert(scanFields.end(), fields.begin(), fields.end()); + scanFieldSlots.insert(scanFieldSlots.end(), fieldSlots.begin(), fieldSlots.end()); + } + sbe::ScanCallbacks callbacks({}, {}, makeOpenCallbackIfNeeded(collection, csn)); auto stage = sbe::makeS<sbe::ScanStage>(collection->uuid(), resultSlot, @@ -302,8 +312,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo boost::none /* indexKeySlot */, boost::none /* keyPatternSlot */, tsSlot, - std::move(fields), - std::move(slots), + std::move(scanFields), + std::move(scanFieldSlots), seekRecordIdSlot, true /* forward */, yieldPolicy, @@ -342,7 +352,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo auto oObjSlot = state.slotId(); auto minTsSlot = state.slotId(); sbe::value::SlotVector minTsSlots = {minTsSlot, opTypeSlot, oObjSlot}; - std::vector<std::string> fields = {repl::OpTime::kTimestampFieldName.toString(), "op", "o"}; + std::vector<std::string> minTsFields = { + repl::OpTime::kTimestampFieldName.toString(), "op", "o"}; // If the first entry we see in the oplog is the replset initialization, then it doesn't // matter if its timestamp is later than the specified minTs; no events earlier than the @@ -375,7 +386,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo boost::none /* indexKeySlot */, boost::none /* keyPatternSlot */, boost::none /* oplogTsSlot*/, - std::move(fields), + std::move(minTsFields), minTsSlots, /* don't move this */ boost::none, true /* forward */, @@ -421,6 +432,18 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo tsSlot = state.slotId(); auto outputSlots = sbe::makeSV(resultSlot, recordIdSlot, *tsSlot); + if (!createScanWithAndWithoutFilter) { + auto unusedFieldSlots = state.slotIdGenerator->generateMultiple(fieldSlots.size()); + minTsSlots.insert(minTsSlots.end(), unusedFieldSlots.begin(), unusedFieldSlots.end()); + + realSlots.insert(realSlots.end(), fieldSlots.begin(), fieldSlots.end()); + + size_t numFieldSlots = fieldSlots.size(); + fieldSlots = state.slotIdGenerator->generateMultiple(numFieldSlots); + + outputSlots.insert(outputSlots.end(), fieldSlots.begin(), fieldSlots.end()); + } + // Create the union stage. The left branch, which runs first, is our resumability check. stage = sbe::makeS<sbe::UnionStage>( sbe::makeSs(std::move(minTsBranch), std::move(stage)), @@ -454,6 +477,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo relevantSlots.push_back(*tsSlot); } + relevantSlots.insert(relevantSlots.end(), fieldSlots.begin(), fieldSlots.end()); + auto [_, outputStage] = generateFilter(state, csn->filter.get(), {std::move(stage), std::move(relevantSlots)}, @@ -483,7 +508,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo // matches. This RecordId is then used as a starting point of the collection scan in the // inner branch, and the execution will continue from this point further on, without // applying the filter. 
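Aside (editor's note): the comment above describes the 'stopApplyingFilterAfterFirstMatch' plan shape handled just below: an outer filtered scan limited to one match locates the resume point, and an inner unfiltered scan continues from that RecordId. A plain in-memory equivalent of that two-branch idea (Record and minTs are invented for the example):

#include <cstddef>
#include <iostream>
#include <vector>

struct Record {
    long recordId;
    int ts;
};

int main() {
    std::vector<Record> oplog{{1, 5}, {2, 8}, {3, 12}, {4, 20}};
    const int minTs = 10;  // the filter only needs to hold for the first returned entry

    // "Outer branch": scan with the filter applied, stop after the first match (limit 1).
    std::size_t resumeIdx = 0;
    while (resumeIdx < oplog.size() && oplog[resumeIdx].ts < minTs) {
        ++resumeIdx;
    }

    // "Inner branch": resume from that record and return everything from there on,
    // without re-evaluating the filter.
    for (std::size_t i = resumeIdx; i < oplog.size(); ++i) {
        std::cout << "recordId=" << oplog[i].recordId << " ts=" << oplog[i].ts << "\n";
    }
}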
- if (csn->stopApplyingFilterAfterFirstMatch) { + if (createScanWithAndWithoutFilter) { invariant(!csn->maxRecord); invariant(csn->direction == CollectionScanParams::FORWARD); @@ -491,9 +516,12 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo resultSlot = state.slotId(); recordIdSlot = state.slotId(); - std::tie(fields, slots, tsSlot) = makeOplogTimestampSlotsIfNeeded( + std::tie(scanFields, scanFieldSlots, tsSlot) = makeOplogTimestampSlotsIfNeeded( state.data->env, state.slotIdGenerator, shouldTrackLatestOplogTimestamp); + scanFields.insert(scanFields.end(), fields.begin(), fields.end()); + scanFieldSlots.insert(scanFieldSlots.end(), fieldSlots.begin(), fieldSlots.end()); + stage = sbe::makeS<sbe::LoopJoinStage>( sbe::makeS<sbe::LimitSkipStage>(std::move(stage), 1, boost::none, csn->nodeId()), sbe::makeS<sbe::ScanStage>(collection->uuid(), @@ -504,8 +532,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo boost::none /* indexKeySlot */, boost::none /* keyPatternSlot */, tsSlot, - std::move(fields), - std::move(slots), + std::move(scanFields), + std::move(scanFieldSlots), seekRecordIdSlot, true /* forward */, yieldPolicy, @@ -524,6 +552,9 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateOptimizedOplo PlanStageSlots outputs; outputs.set(PlanStageSlots::kResult, resultSlot); outputs.set(PlanStageSlots::kRecordId, recordIdSlot); + for (size_t i = 0; i < fields.size(); ++i) { + outputs.set(std::make_pair(PlanStageSlots::kField, fields[i]), fieldSlots[i]); + } return {std::move(stage), std::move(outputs)}; } @@ -540,6 +571,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc StageBuilderState& state, const CollectionPtr& collection, const CollectionScanNode* csn, + const std::vector<std::string>& fields, PlanYieldPolicy* yieldPolicy, bool isTailableResumeBranch) { const auto forward = csn->direction == CollectionScanParams::FORWARD; @@ -548,6 +580,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc invariant(!csn->resumeAfterRecordId || forward); invariant(!csn->resumeAfterRecordId || !csn->tailable); + auto fieldSlots = state.slotIdGenerator->generateMultiple(fields.size()); + auto resultSlot = state.slotId(); auto recordIdSlot = state.slotId(); auto [seekRecordIdSlot, seekRecordIdExpression] = @@ -563,9 +597,12 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc }(); // See if we need to project out an oplog latest timestamp. 
- auto&& [fields, slots, tsSlot] = makeOplogTimestampSlotsIfNeeded( + auto&& [scanFields, scanFieldSlots, tsSlot] = makeOplogTimestampSlotsIfNeeded( state.data->env, state.slotIdGenerator, csn->shouldTrackLatestOplogTimestamp); + scanFields.insert(scanFields.end(), fields.begin(), fields.end()); + scanFieldSlots.insert(scanFieldSlots.end(), fieldSlots.begin(), fieldSlots.end()); + sbe::ScanCallbacks callbacks({}, {}, makeOpenCallbackIfNeeded(collection, csn)); auto stage = sbe::makeS<sbe::ScanStage>(collection->uuid(), resultSlot, @@ -575,8 +612,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc boost::none /* indexKeySlot */, boost::none /* keyPatternSlot */, tsSlot, - std::move(fields), - std::move(slots), + std::move(scanFields), + std::move(scanFieldSlots), seekRecordIdSlot, forward, yieldPolicy, @@ -602,6 +639,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc invariant(!csn->stopApplyingFilterAfterFirstMatch); auto relevantSlots = sbe::makeSV(resultSlot, recordIdSlot); + relevantSlots.insert(relevantSlots.end(), fieldSlots.begin(), fieldSlots.end()); auto [_, outputStage] = generateFilter(state, csn->filter.get(), @@ -614,6 +652,9 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateGenericCollSc PlanStageSlots outputs; outputs.set(PlanStageSlots::kResult, resultSlot); outputs.set(PlanStageSlots::kRecordId, recordIdSlot); + for (size_t i = 0; i < fields.size(); ++i) { + outputs.set(std::make_pair(PlanStageSlots::kField, fields[i]), fieldSlots[i]); + } return {std::move(stage), std::move(outputs)}; } @@ -623,13 +664,15 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateCollScan( StageBuilderState& state, const CollectionPtr& collection, const CollectionScanNode* csn, + const std::vector<std::string>& fields, PlanYieldPolicy* yieldPolicy, bool isTailableResumeBranch) { if (csn->minRecord || csn->maxRecord || csn->stopApplyingFilterAfterFirstMatch) { return generateOptimizedOplogScan( - state, collection, csn, yieldPolicy, isTailableResumeBranch); + state, collection, csn, fields, yieldPolicy, isTailableResumeBranch); } else { - return generateGenericCollScan(state, collection, csn, yieldPolicy, isTailableResumeBranch); + return generateGenericCollScan( + state, collection, csn, fields, yieldPolicy, isTailableResumeBranch); } } } // namespace mongo::stage_builder diff --git a/src/mongo/db/query/sbe_stage_builder_coll_scan.h b/src/mongo/db/query/sbe_stage_builder_coll_scan.h index b204c5487a8..3b6e1b861ae 100644 --- a/src/mongo/db/query/sbe_stage_builder_coll_scan.h +++ b/src/mongo/db/query/sbe_stage_builder_coll_scan.h @@ -41,7 +41,10 @@ namespace mongo::stage_builder { class PlanStageSlots; /** - * Generates an SBE plan stage sub-tree implementing an collection scan. + * Generates an SBE plan stage sub-tree implementing an collection scan. 'fields' can be used to + * specify top-level fields that should be retrieved during the scan. For each name in 'fields', + * there will be a corresponding kField slot in the PlanStageSlots object returned with the same + * name. 
* * On success, a tuple containing the following data is returned: * * A slot to access a fetched document (a resultSlot) @@ -56,6 +59,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateCollScan( StageBuilderState& state, const CollectionPtr& collection, const CollectionScanNode* csn, + const std::vector<std::string>& fields, PlanYieldPolicy* yieldPolicy, bool isTailableResumeBranch); diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index d96f7ea6a92..c1566970b6d 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -76,8 +76,9 @@ struct ExpressionVisitorContext { ExpressionVisitorContext(StageBuilderState& state, EvalStage inputStage, boost::optional<sbe::value::SlotId> optionalRootSlot, - PlanNodeId planNodeId) - : state(state), optionalRootSlot(optionalRootSlot), planNodeId(planNodeId) { + PlanNodeId planNodeId, + const PlanStageSlots* slots = nullptr) + : state(state), optionalRootSlot(optionalRootSlot), planNodeId(planNodeId), slots(slots) { evalStack.emplaceFrame(std::move(inputStage)); } @@ -142,6 +143,9 @@ struct ExpressionVisitorContext { std::stack<VarsFrame> varsFrameStack; // The id of the QuerySolutionNode to which the expression we are converting to SBE is attached. const PlanNodeId planNodeId; + + const PlanStageSlots* slots = nullptr; + // This stack contains slot id for the current element variable of $filter expression. std::stack<sbe::value::SlotId> filterExprSlotIdStack; // We use this counter to track which children of $filter we've already processed. @@ -149,17 +153,23 @@ struct ExpressionVisitorContext { }; std::unique_ptr<sbe::EExpression> generateTraverseHelper( - const sbe::EVariable& inputVar, + std::unique_ptr<sbe::EExpression> inputExpr, const FieldPath& fp, size_t level, - sbe::value::FrameIdGenerator* frameIdGenerator) { + sbe::value::FrameIdGenerator* frameIdGenerator, + boost::optional<sbe::value::SlotId> topLevelFieldSlot = boost::none) { using namespace std::literals; invariant(level < fp.getPathLength()); + tassert(6023417, + "Expected an input expression or top level field", + inputExpr.get() || topLevelFieldSlot.has_value()); // Generate an expression to read a sub-field at the current nested level. auto fieldName = sbe::makeE<sbe::EConstant>(fp.getFieldName(level)); - auto fieldExpr = makeFunction("getField"_sd, inputVar.clone(), std::move(fieldName)); + auto fieldExpr = topLevelFieldSlot + ? makeVariable(*topLevelFieldSlot) + : makeFunction("getField"_sd, std::move(inputExpr), std::move(fieldName)); if (level == fp.getPathLength() - 1) { // For the last level, we can just return the field slot without the need for a @@ -171,7 +181,7 @@ std::unique_ptr<sbe::EExpression> generateTraverseHelper( auto lambdaFrameId = frameIdGenerator->generate(); auto lambdaParam = sbe::EVariable{lambdaFrameId, 0}; - auto resultExpr = generateTraverseHelper(lambdaParam, fp, level + 1, frameIdGenerator); + auto resultExpr = generateTraverseHelper(lambdaParam.clone(), fp, level + 1, frameIdGenerator); auto lambdaExpr = sbe::makeE<sbe::ELocalLambda>(lambdaFrameId, std::move(resultExpr)); @@ -186,28 +196,32 @@ std::unique_ptr<sbe::EExpression> generateTraverseHelper( * For the given MatchExpression 'expr', generates a path traversal SBE plan stage sub-tree * implementing the comparison expression. 
*/ -std::unique_ptr<sbe::EExpression> generateTraverse(const sbe::EVariable& inputVar, - bool expectsDocumentInputOnly, - const FieldPath& fp, - sbe::value::FrameIdGenerator* frameIdGenerator) { +std::unique_ptr<sbe::EExpression> generateTraverse( + std::unique_ptr<sbe::EExpression> inputExpr, + bool expectsDocumentInputOnly, + const FieldPath& fp, + sbe::value::FrameIdGenerator* frameIdGenerator, + boost::optional<sbe::value::SlotId> topLevelFieldSlot = boost::none) { size_t level = 0; if (expectsDocumentInputOnly) { - // When we know for sure that 'inputVar' will be a document and _not_ an array (such as - // when traversing the root document), we can generate a simpler expression. - return generateTraverseHelper(inputVar, fp, level, frameIdGenerator); + // When we know for sure that 'inputExpr' will be a document and _not_ an array (such as + // when accessing a field on the root document), we can generate a simpler expression. + return generateTraverseHelper( + std::move(inputExpr), fp, level, frameIdGenerator, topLevelFieldSlot); } else { - // The general case: the value in the 'inputVar' may be an array that will require + tassert(6023418, "Expected an input expression", inputExpr.get()); + // The general case: the value in the 'inputExpr' may be an array that will require // traversal. auto lambdaFrameId = frameIdGenerator->generate(); auto lambdaParam = sbe::EVariable{lambdaFrameId, 0}; - auto resultExpr = generateTraverseHelper(lambdaParam, fp, level, frameIdGenerator); + auto resultExpr = generateTraverseHelper(lambdaParam.clone(), fp, level, frameIdGenerator); auto lambdaExpr = sbe::makeE<sbe::ELocalLambda>(lambdaFrameId, std::move(resultExpr)); return makeFunction("traverseP", - inputVar.clone(), + std::move(inputExpr), std::move(lambdaExpr), makeConstant(sbe::value::TypeTags::NumberInt32, 1)); } @@ -1951,28 +1965,44 @@ public: void visit(const ExpressionFieldPath* expr) final { // There's a chance that we've already generated a SBE plan stage tree for this field path, // in which case we avoid regeneration of the same plan stage tree. 
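Aside (editor's note): generateTraverse()/generateTraverseHelper() above assemble nested getField calls and traversal lambdas so that a dotted path is followed with implicit array traversal at each level. The standalone sketch below models only that idea over a toy document type; Node is invented and the real SBE traversal semantics are more involved:

#include <cstddef>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy document model: every node is a scalar, an object (named children), or an array.
struct Node {
    std::string scalar;                                   // used when the node is a leaf
    std::map<std::string, std::shared_ptr<Node>> fields;  // used when the node is an object
    std::vector<std::shared_ptr<Node>> elements;          // used when the node is an array
};

// Follows one dotted path, implicitly traversing arrays at every level, and collects
// the scalars it reaches (roughly what the nested getField/traverse lambdas do).
void traverse(const std::shared_ptr<Node>& node,
              const std::vector<std::string>& path,
              std::size_t level,
              std::vector<std::string>& out) {
    if (!node) {
        return;
    }
    if (!node->elements.empty()) {
        // Implicit array traversal: apply the remainder of the path to every element.
        for (const auto& elem : node->elements) {
            traverse(elem, path, level, out);
        }
        return;
    }
    if (level == path.size()) {
        out.push_back(node->scalar);
        return;
    }
    if (auto it = node->fields.find(path[level]); it != node->fields.end()) {
        traverse(it->second, path, level + 1, out);  // "getField", then recurse one level deeper
    }
}

int main() {
    // Builds {a: [{b: "x"}, {b: "y"}]} and evaluates the field path "a.b".
    auto leafX = std::make_shared<Node>();
    leafX->scalar = "x";
    auto leafY = std::make_shared<Node>();
    leafY->scalar = "y";
    auto obj1 = std::make_shared<Node>();
    obj1->fields["b"] = leafX;
    auto obj2 = std::make_shared<Node>();
    obj2->fields["b"] = leafY;
    auto arr = std::make_shared<Node>();
    arr->elements = {obj1, obj2};
    auto root = std::make_shared<Node>();
    root->fields["a"] = arr;

    std::vector<std::string> results;
    traverse(root, {"a", "b"}, 0, results);
    for (const auto& r : results) {
        std::cout << r << "\n";  // prints "x" then "y"
    }
}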
- if (auto it = _context->state.preGeneratedExprs.find(expr->getFieldPath().fullPath()); - it != _context->state.preGeneratedExprs.end()) { - tassert(6089301, - "Expressions for top-level document or a variable must not be pre-generated", - expr->getFieldPath().getPathLength() != 1 && !expr->isVariableReference()); - if (auto optionalSlot = it->second.getSlot(); optionalSlot) { - _context->pushExpr(*optionalSlot); - } else { - auto preGeneratedExpr = it->second.extractExpr(); - _context->pushExpr(preGeneratedExpr->clone()); - it->second = std::move(preGeneratedExpr); + if (!_context->state.preGeneratedExprs.empty()) { + if (auto it = _context->state.preGeneratedExprs.find(expr->getFieldPath().fullPath()); + it != _context->state.preGeneratedExprs.end()) { + tassert(6089301, + "Expressions for top-level documents / variables must not be pre-generated", + expr->getFieldPath().getPathLength() != 1 && !expr->isVariableReference()); + if (auto optionalSlot = it->second.getSlot(); optionalSlot) { + _context->pushExpr(*optionalSlot); + } else { + auto preGeneratedExpr = it->second.extractExpr(); + _context->pushExpr(preGeneratedExpr->clone()); + it->second = std::move(preGeneratedExpr); + } + return; } - return; } - tassert(6075901, "Must have a valid root slot", _context->optionalRootSlot.has_value()); + boost::optional<sbe::value::SlotId> slotId; + boost::optional<sbe::value::SlotId> topLevelFieldSlot; + boost::optional<FieldPath> fp; + bool expectsDocumentInputOnly = false; - sbe::value::SlotId slotId; + if (expr->getFieldPath().getPathLength() > 1) { + fp = expr->getFieldPathWithoutCurrentPrefix(); + } + + if (expr->getVariableId() == Variables::kRootId && + expr->getFieldPath().getPathLength() > 1) { + slotId = _context->optionalRootSlot; + expectsDocumentInputOnly = true; - if (!Variables::isUserDefinedVariable(expr->getVariableId())) { + if (_context->slots) { + auto topLevelField = std::make_pair(PlanStageSlots::kField, fp->front()); + topLevelFieldSlot = _context->slots->getIfExists(topLevelField); + } + } else if (!Variables::isUserDefinedVariable(expr->getVariableId())) { if (expr->getVariableId() == Variables::kRootId) { - slotId = *(_context->optionalRootSlot); + slotId = _context->optionalRootSlot; } else if (expr->getVariableId() == Variables::kRemoveId) { // For the field paths that begin with "$$REMOVE", we always produce Nothing, // so no traversal is necessary. @@ -1990,7 +2020,7 @@ public: << "Builtin variable '$$" << it->second << "' is not available", variableSlot.has_value()); - slotId = *variableSlot; + slotId = variableSlot; } } else { auto it = _context->environment.find(expr->getVariableId()); @@ -2002,19 +2032,27 @@ public: } if (expr->getFieldPath().getPathLength() == 1) { + tassert(6023419, "Must have a valid slot", slotId.has_value()); + // A solo variable reference (e.g.: "$$ROOT" or "$$myvar") that doesn't need any // traversal. - _context->pushExpr(slotId); + _context->pushExpr(*slotId); return; } - // Dereference a dotted path, which may contain arrays requiring implicit traversal. - const bool expectsDocumentInputOnly = slotId == *(_context->optionalRootSlot); + tassert(6023420, + "Must have a valid root slot or field slot", + slotId.has_value() || topLevelFieldSlot.has_value()); + + auto inputExpr = + slotId ? sbe::makeE<sbe::EVariable>(*slotId) : std::unique_ptr<sbe::EExpression>{}; - auto resultExpr = generateTraverse(sbe::EVariable{slotId}, + // Dereference a dotted path, which may contain arrays requiring implicit traversal. 
+ auto resultExpr = generateTraverse(std::move(inputExpr), expectsDocumentInputOnly, - expr->getFieldPathWithoutCurrentPrefix(), - _context->state.frameIdGenerator); + *fp, + _context->state.frameIdGenerator, + topLevelFieldSlot); _context->pushExpr(std::move(resultExpr)); } @@ -4126,8 +4164,9 @@ EvalExprStagePair generateExpression(StageBuilderState& state, const Expression* expr, EvalStage stage, boost::optional<sbe::value::SlotId> optionalRootSlot, - PlanNodeId planNodeId) { - ExpressionVisitorContext context(state, std::move(stage), optionalRootSlot, planNodeId); + PlanNodeId planNodeId, + const PlanStageSlots* slots) { + ExpressionVisitorContext context(state, std::move(stage), optionalRootSlot, planNodeId, slots); ExpressionPreVisitor preVisitor{&context}; ExpressionInVisitor inVisitor{&context}; diff --git a/src/mongo/db/query/sbe_stage_builder_expression.h b/src/mongo/db/query/sbe_stage_builder_expression.h index 3d148113f89..78b4774f807 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.h +++ b/src/mongo/db/query/sbe_stage_builder_expression.h @@ -38,6 +38,8 @@ #include "mongo/db/query/sbe_stage_builder_helpers.h" namespace mongo::stage_builder { +class PlanStageSlots; + /** * Translates an input Expression into an SBE EExpression. The 'stage' parameter provides the input * subtree to build on top of. @@ -46,7 +48,8 @@ EvalExprStagePair generateExpression(StageBuilderState& state, const Expression* expr, EvalStage stage, boost::optional<sbe::value::SlotId> optionalRootSlot, - PlanNodeId planNodeId); + PlanNodeId planNodeId, + const PlanStageSlots* slots = nullptr); /** * Generate an EExpression that converts a value (contained in a variable bound to 'branchRef') that diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp index affc05691b1..c8acb608a42 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp +++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp @@ -958,27 +958,25 @@ bool indexKeyConsistencyCheckCallback(OperationContext* opCtx, return true; } -std::tuple<sbe::value::SlotId, sbe::value::SlotId, std::unique_ptr<sbe::PlanStage>> -makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage, - sbe::value::SlotId seekKeySlot, - sbe::value::SlotId snapshotIdSlot, - sbe::value::SlotId indexIdSlot, - sbe::value::SlotId indexKeySlot, - sbe::value::SlotId indexKeyPatternSlot, - const CollectionPtr& collToFetch, - StringMap<const IndexAccessMethod*> iamMap, - PlanNodeId planNodeId, - sbe::value::SlotVector slotsToForward, - sbe::value::SlotIdGenerator& slotIdGenerator) { +std::unique_ptr<sbe::PlanStage> makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage, + sbe::value::SlotId resultSlot, + sbe::value::SlotId recordIdSlot, + std::vector<std::string> fields, + sbe::value::SlotVector fieldSlots, + sbe::value::SlotId seekKeySlot, + sbe::value::SlotId snapshotIdSlot, + sbe::value::SlotId indexIdSlot, + sbe::value::SlotId indexKeySlot, + sbe::value::SlotId indexKeyPatternSlot, + const CollectionPtr& collToFetch, + StringMap<const IndexAccessMethod*> iamMap, + PlanNodeId planNodeId, + sbe::value::SlotVector slotsToForward) { // It is assumed that we are generating a fetch loop join over the main collection. If we are // generating a fetch over a secondary collection, it is the responsibility of a parent node // in the QSN tree to indicate which collection we are fetching over. 
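Aside (editor's note): the fetch loop join assembled here pairs an outer stage producing recordIds (typically an index scan) with an inner limit-1 seek into the collection. Stripped of slots and callbacks, the control flow looks like this sketch; the recordStore map is an invented stand-in for the collection:

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
    // Invented stand-in for the collection's record store: recordId -> document.
    std::map<long, std::string> recordStore{
        {10, "{_id: 1, a: 'x'}"}, {11, "{_id: 2, a: 'y'}"}, {12, "{_id: 3, a: 'z'}"}};

    // Stand-in for the outer side (e.g. an index scan) producing recordIds to fetch.
    std::vector<long> outerRecordIds{12, 10};

    // The loop join: for every outer row, reopen the inner side positioned on that
    // recordId and take at most one row (the limit-1 seek).
    for (long rid : outerRecordIds) {
        auto it = recordStore.find(rid);
        if (it == recordStore.end()) {
            continue;  // the real plan runs index consistency checks at this point
        }
        std::cout << "recordId=" << rid << " -> " << it->second << "\n";
    }
}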
tassert(6355301, "Cannot fetch from a collection that doesn't exist", collToFetch); - auto resultSlot = slotIdGenerator.generate(); - auto recordIdSlot = slotIdGenerator.generate(); - - using namespace std::placeholders; sbe::ScanCallbacks callbacks( indexKeyCorruptionCheckCallback, [=](auto&& arg1, auto&& arg2, auto&& arg3, auto&& arg4, auto&& arg5, auto&& arg6) { @@ -995,8 +993,8 @@ makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage, indexKeySlot, indexKeyPatternSlot, boost::none, - std::vector<std::string>{}, - sbe::makeSV(), + std::move(fields), + std::move(fieldSlots), seekKeySlot, true, nullptr, @@ -1005,15 +1003,13 @@ makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage, // Get the recordIdSlot from the outer side (e.g., IXSCAN) and feed it to the inner side, // limiting the result set to 1 row. - auto stage = sbe::makeS<sbe::LoopJoinStage>( + return sbe::makeS<sbe::LoopJoinStage>( std::move(inputStage), sbe::makeS<sbe::LimitSkipStage>(std::move(scanStage), 1, boost::none, planNodeId), std::move(slotsToForward), sbe::makeSV(seekKeySlot, snapshotIdSlot, indexIdSlot, indexKeySlot, indexKeyPatternSlot), nullptr, planNodeId); - - return {resultSlot, recordIdSlot, std::move(stage)}; } sbe::value::SlotId StageBuilderState::registerInputParamSlot( @@ -1121,4 +1117,172 @@ std::unique_ptr<sbe::PlanStage> rehydrateIndexKey(std::unique_ptr<sbe::PlanStage return sbe::makeProjectStage(std::move(stage), nodeId, resultSlot, std::move(keyExpr)); } +/** + * For covered projections, each of the projection field paths represent respective index key. To + * rehydrate index keys into the result object, we first need to convert projection AST into + * 'IndexKeyPatternTreeNode' structure. Context structure and visitors below are used for this + * purpose. + */ +struct IndexKeysBuilderContext { + // Contains resulting tree of index keys converted from projection AST. + IndexKeyPatternTreeNode root; + + // Full field path of the currently visited projection node. + std::vector<StringData> currentFieldPath; + + // Each projection node has a vector of field names. This stack contains indexes of the + // currently visited field names for each of the projection nodes. + std::vector<size_t> currentFieldIndex; +}; + +/** + * Covered projections are always inclusion-only, so we ban all other operators. 
+ */ +class IndexKeysBuilder : public projection_ast::ProjectionASTConstVisitor { +public: + using projection_ast::ProjectionASTConstVisitor::visit; + + IndexKeysBuilder(IndexKeysBuilderContext* context) : _context{context} {} + + void visit(const projection_ast::ProjectionPositionalASTNode* node) final { + tasserted(5474501, "Positional projection is not allowed in simple or covered projections"); + } + + void visit(const projection_ast::ProjectionSliceASTNode* node) final { + tasserted(5474502, "$slice is not allowed in simple or covered projections"); + } + + void visit(const projection_ast::ProjectionElemMatchASTNode* node) final { + tasserted(5474503, "$elemMatch is not allowed in simple or covered projections"); + } + + void visit(const projection_ast::ExpressionASTNode* node) final { + tasserted(5474504, "Expressions are not allowed in simple or covered projections"); + } + + void visit(const projection_ast::MatchExpressionASTNode* node) final { + tasserted( + 5474505, + "$elemMatch / positional projection are not allowed in simple or covered projections"); + } + + void visit(const projection_ast::BooleanConstantASTNode* node) override {} + +protected: + IndexKeysBuilderContext* _context; +}; + +class IndexKeysPreBuilder final : public IndexKeysBuilder { +public: + using IndexKeysBuilder::IndexKeysBuilder; + using IndexKeysBuilder::visit; + + void visit(const projection_ast::ProjectionPathASTNode* node) final { + _context->currentFieldIndex.push_back(0); + _context->currentFieldPath.emplace_back(node->fieldNames().front()); + } +}; + +class IndexKeysInBuilder final : public IndexKeysBuilder { +public: + using IndexKeysBuilder::IndexKeysBuilder; + using IndexKeysBuilder::visit; + + void visit(const projection_ast::ProjectionPathASTNode* node) final { + auto& currentIndex = _context->currentFieldIndex.back(); + currentIndex++; + _context->currentFieldPath.back() = node->fieldNames()[currentIndex]; + } +}; + +class IndexKeysPostBuilder final : public IndexKeysBuilder { +public: + using IndexKeysBuilder::IndexKeysBuilder; + using IndexKeysBuilder::visit; + + void visit(const projection_ast::ProjectionPathASTNode* node) final { + _context->currentFieldIndex.pop_back(); + _context->currentFieldPath.pop_back(); + } + + void visit(const projection_ast::BooleanConstantASTNode* constantNode) final { + if (!constantNode->value()) { + // Even though only inclusion is allowed in covered projection, there still can be + // {_id: 0} component. We do not need to generate any nodes for it. + return; + } + + // Insert current field path into the index keys tree if it does not exist yet. 
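Aside (editor's note): the insertion step that follows walks (or creates) one tree node per path component, so projection paths that share a prefix share nodes. The same trie-style insertion in isolation, with an invented PatternNode type instead of the real IndexKeyPatternTreeNode:

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Stand-in for the pattern tree: one node per path component.
struct PatternNode {
    std::map<std::string, std::unique_ptr<PatternNode>> children;
};

// Inserts one field path (already split into parts) into the tree, reusing
// nodes that already exist, much like the visitor does per projection leaf.
void insertPath(PatternNode* root, const std::vector<std::string>& parts) {
    PatternNode* node = root;
    for (const auto& part : parts) {
        auto it = node->children.find(part);
        if (it == node->children.end()) {
            it = node->children.emplace(part, std::make_unique<PatternNode>()).first;
        }
        node = it->second.get();
    }
}

void print(const PatternNode& node, std::string indent = "") {
    for (const auto& [name, child] : node.children) {
        std::cout << indent << name << "\n";
        print(*child, indent + "  ");
    }
}

int main() {
    PatternNode root;
    insertPath(&root, {"a", "b"});  // e.g. projection {"a.b": 1, "a.c": 1, "d": 1}
    insertPath(&root, {"a", "c"});
    insertPath(&root, {"d"});
    print(root);
}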
+ auto* node = &_context->root; + for (const auto& part : _context->currentFieldPath) { + if (auto it = node->children.find(part); it != node->children.end()) { + node = it->second.get(); + } else { + node = node->emplace(part); + } + } + } +}; + +IndexKeyPatternTreeNode buildPatternTree(const projection_ast::Projection& projection) { + IndexKeysBuilderContext context; + IndexKeysPreBuilder preVisitor{&context}; + IndexKeysInBuilder inVisitor{&context}; + IndexKeysPostBuilder postVisitor{&context}; + + projection_ast::ProjectionASTConstWalker walker{&preVisitor, &inVisitor, &postVisitor}; + + tree_walker::walk<true, projection_ast::ASTNode>(projection.root(), &walker); + + return std::move(context.root); +} + +std::pair<std::unique_ptr<sbe::PlanStage>, sbe::value::SlotVector> projectTopLevelFields( + std::unique_ptr<sbe::PlanStage> stage, + const std::vector<std::string>& fields, + sbe::value::SlotId resultSlot, + PlanNodeId nodeId, + sbe::value::SlotIdGenerator* slotIdGenerator) { + // 'outputSlots' will match the order of 'fields'. + sbe::value::SlotVector outputSlots; + outputSlots.reserve(fields.size()); + + sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projects; + for (size_t i = 0; i < fields.size(); ++i) { + const auto& field = fields[i]; + auto slot = slotIdGenerator->generate(); + auto getFieldExpr = + makeFunction("getField"_sd, makeVariable(resultSlot), makeConstant(field)); + projects.insert({slot, std::move(getFieldExpr)}); + outputSlots.emplace_back(slot); + } + + if (!projects.empty()) { + stage = sbe::makeS<sbe::ProjectStage>(std::move(stage), std::move(projects), nodeId); + } + + return {std::move(stage), std::move(outputSlots)}; +} + +std::pair<std::unique_ptr<sbe::PlanStage>, sbe::value::SlotVector> projectNothingToSlots( + std::unique_ptr<sbe::PlanStage> stage, + size_t n, + PlanNodeId nodeId, + sbe::value::SlotIdGenerator* slotIdGenerator) { + if (n == 0) { + return {std::move(stage), sbe::value::SlotVector{}}; + } + + auto outputSlots = slotIdGenerator->generateMultiple(n); + + sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projects; + for (size_t i = 0; i < n; ++i) { + projects.insert( + {outputSlots[i], sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Nothing, 0)}); + } + + stage = sbe::makeS<sbe::ProjectStage>(std::move(stage), std::move(projects), nodeId); + + return {std::move(stage), std::move(outputSlots)}; +} } // namespace mongo::stage_builder diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h index f3b5d9ca455..634005c504c 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.h +++ b/src/mongo/db/query/sbe_stage_builder_helpers.h @@ -44,6 +44,10 @@ #include "mongo/db/query/sbe_stage_builder_eval_frame.h" #include "mongo/db/query/stage_types.h" +namespace mongo::projection_ast { +class Projection; +} + namespace mongo::stage_builder { std::unique_ptr<sbe::EExpression> makeUnaryOp(sbe::EPrimUnary::Op unaryOp, @@ -524,19 +528,20 @@ std::unique_ptr<sbe::EExpression> makeLocalBind(sbe::value::FrameIdGenerator* fr return sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(innerExpr)); } - -std::tuple<sbe::value::SlotId, sbe::value::SlotId, std::unique_ptr<sbe::PlanStage>> -makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage, - sbe::value::SlotId seekKeySlot, - sbe::value::SlotId snapshotIdSlot, - sbe::value::SlotId indexIdSlot, - sbe::value::SlotId indexKeySlot, - sbe::value::SlotId indexKeyPatternSlot, - const CollectionPtr& collToFetch, - StringMap<const 
IndexAccessMethod*> iamMap, - PlanNodeId planNodeId, - sbe::value::SlotVector slotsToForward, - sbe::value::SlotIdGenerator& slotIdGenerator); +std::unique_ptr<sbe::PlanStage> makeLoopJoinForFetch(std::unique_ptr<sbe::PlanStage> inputStage, + sbe::value::SlotId resultSlot, + sbe::value::SlotId recordIdSlot, + std::vector<std::string> fields, + sbe::value::SlotVector fieldSlots, + sbe::value::SlotId seekKeySlot, + sbe::value::SlotId snapshotIdSlot, + sbe::value::SlotId indexIdSlot, + sbe::value::SlotId indexKeySlot, + sbe::value::SlotId indexKeyPatternSlot, + const CollectionPtr& collToFetch, + StringMap<const IndexAccessMethod*> iamMap, + PlanNodeId planNodeId, + sbe::value::SlotVector slotsToForward); /** * Trees generated by 'generateFilter' maintain state during execution. There are two types of state @@ -1010,4 +1015,60 @@ std::unique_ptr<sbe::PlanStage> rehydrateIndexKey(std::unique_ptr<sbe::PlanStage const sbe::value::SlotVector& indexKeySlots, sbe::value::SlotId resultSlot); +IndexKeyPatternTreeNode buildPatternTree(const projection_ast::Projection& projection); + +/** + * This method retrieves the values of the specified top-level fields ('fields') from 'resultSlot' + * and stores the values into slots. + * + * This method returns a pair containing: (1) the updated SBE plan stage tree and; (2) a vector of + * the slots ('slots') containing the field values. + * + * The order of slots in 'slots' will match the order of fields in 'fields'. + */ +std::pair<std::unique_ptr<sbe::PlanStage>, sbe::value::SlotVector> projectTopLevelFields( + std::unique_ptr<sbe::PlanStage> stage, + const std::vector<std::string>& fields, + sbe::value::SlotId resultSlot, + PlanNodeId nodeId, + sbe::value::SlotIdGenerator* slotIdGenerator); + +/** + * This method projects the constant value Nothing to multiple slots (the specific number of slots + * being specified by parameter 'n'). + * + * This method returns a pair containing: (1) the updated SBE plan stage tree and; (2) a vector of + * slots ('slots') containing Nothing. + * + * The number of slots in 'slots' will always be equal to parameter 'n'. 
+ */ +std::pair<std::unique_ptr<sbe::PlanStage>, sbe::value::SlotVector> projectNothingToSlots( + std::unique_ptr<sbe::PlanStage> stage, + size_t n, + PlanNodeId nodeId, + sbe::value::SlotIdGenerator* slotIdGenerator); + +template <typename T, typename FuncT> +std::vector<T> filterVector(std::vector<T> vec, FuncT fn) { + std::vector<T> result; + std::copy_if(std::make_move_iterator(vec.begin()), + std::make_move_iterator(vec.end()), + std::back_inserter(result), + fn); + return result; +} + +template <typename T, typename FuncT> +std::pair<std::vector<T>, std::vector<T>> splitVector(std::vector<T> vec, FuncT fn) { + std::pair<std::vector<T>, std::vector<T>> result; + for (size_t i = 0; i < vec.size(); ++i) { + if (fn(vec[i])) { + result.first.emplace_back(std::move(vec[i])); + } else { + result.second.emplace_back(std::move(vec[i])); + } + } + return result; +} + } // namespace mongo::stage_builder diff --git a/src/mongo/db/query/sbe_stage_builder_index_scan.cpp b/src/mongo/db/query/sbe_stage_builder_index_scan.cpp index 51d37ceccf4..bfc52d055c5 100644 --- a/src/mongo/db/query/sbe_stage_builder_index_scan.cpp +++ b/src/mongo/db/query/sbe_stage_builder_index_scan.cpp @@ -696,7 +696,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScan( PlanYieldPolicy* yieldPolicy, StringMap<const IndexAccessMethod*>* iamMap, bool needsCorruptionCheck) { - auto indexName = ixn->index.identifier.catalogName; auto descriptor = collection->getIndexCatalog()->findIndexByName(state.opCtx, indexName); tassert(5483200, @@ -855,8 +854,16 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScan( stage = outputStage.extractStage(ixn->nodeId()); } - outputs.setIndexKeySlots(makeIndexKeyOutputSlotsMatchingParentReqs( - ixn->index.keyPattern, originalIndexKeyBitset, indexKeyBitset, indexKeySlots)); + size_t i = 0; + size_t slotIdx = 0; + for (const auto& elt : ixn->index.keyPattern) { + StringData name = elt.fieldNameStringData(); + if (indexKeyBitset.test(i)) { + outputs.set(std::make_pair(PlanStageSlots::kKey, name), indexKeySlots[slotIdx]); + ++slotIdx; + } + ++i; + } return {std::move(stage), std::move(outputs)}; } @@ -981,8 +988,9 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScanWith // Whenever possible we should prefer building simplified single interval index scan plans in // order to get the best performance. 
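
A note on the filterVector and splitVector templates added to sbe_stage_builder_helpers.h above: both take the vector by value and move elements into the result, so callers hand ownership in and get the filtered or partitioned vectors back. The following standalone sketch (made-up std::string data, not MongoDB code; the two template definitions are copied from the header above, with std:: qualification, so the example compiles on its own) shows the intended call pattern:

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <string>
#include <utility>
#include <vector>

template <typename T, typename FuncT>
std::vector<T> filterVector(std::vector<T> vec, FuncT fn) {
    std::vector<T> result;
    std::copy_if(std::make_move_iterator(vec.begin()),
                 std::make_move_iterator(vec.end()),
                 std::back_inserter(result),
                 fn);
    return result;
}

template <typename T, typename FuncT>
std::pair<std::vector<T>, std::vector<T>> splitVector(std::vector<T> vec, FuncT fn) {
    std::pair<std::vector<T>, std::vector<T>> result;
    for (std::size_t i = 0; i < vec.size(); ++i) {
        if (fn(vec[i])) {
            result.first.emplace_back(std::move(vec[i]));
        } else {
            result.second.emplace_back(std::move(vec[i]));
        }
    }
    return result;
}

int main() {
    std::vector<std::string> fields{"a", "_id", "b", "_id.x"};

    // Keep everything except the plain "_id" field.
    auto kept = filterVector(std::move(fields),
                             [](const std::string& f) { return f != "_id"; });

    // Partition into dotted paths (first) and top-level fields (second).
    auto [dotted, topLevel] = splitVector(
        std::move(kept), [](const std::string& f) { return f.find('.') != std::string::npos; });

    std::cout << dotted.size() << " dotted, " << topLevel.size() << " top-level\n";  // 1 dotted, 2 top-level
    return 0;
}
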
if (canGenerateSingleIntervalIndexScan(ixn->iets)) { - auto makeSlot = [&](const bool cond, - const StringData slotKey) -> boost::optional<sbe::value::SlotId> { + auto makeSlot = + [&](const bool cond, + const PlanStageSlots::Name slotKey) -> boost::optional<sbe::value::SlotId> { if (!cond) return boost::none; @@ -1024,10 +1032,9 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScanWith auto optimizedIndexScanSlots = optimizedIndexKeySlots; auto branchOutputSlots = outputIndexKeySlots; - auto makeSlotsForThenElseBranches = - [&](const bool cond, - const StringData slotKey) -> std::tuple<boost::optional<sbe::value::SlotId>, - boost::optional<sbe::value::SlotId>> { + auto makeSlotsForThenElseBranches = [&](const bool cond, const PlanStageSlots::Name slotKey) + -> std::tuple<boost::optional<sbe::value::SlotId>, + boost::optional<sbe::value::SlotId>> { if (!cond) return {boost::none, boost::none}; @@ -1145,9 +1152,6 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScanWith stage = outputStage.extractStage(ixn->nodeId()); } - outputs.setIndexKeySlots(makeIndexKeyOutputSlotsMatchingParentReqs( - ixn->index.keyPattern, originalIndexKeyBitset, indexKeyBitset, outputIndexKeySlots)); - state.data->indexBoundsEvaluationInfos.emplace_back( IndexBoundsEvaluationInfo{ixn->index, accessMethod->getSortedDataInterface()->getKeyStringVersion(), @@ -1156,6 +1160,17 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScanWith std::move(ixn->iets), std::move(parameterizedScanSlots)}); + size_t i = 0; + size_t slotIdx = 0; + for (const auto& elt : ixn->index.keyPattern) { + StringData name = elt.fieldNameStringData(); + if (indexKeyBitset.test(i)) { + outputs.set(std::make_pair(PlanStageSlots::kKey, name), outputIndexKeySlots[slotIdx]); + ++slotIdx; + } + ++i; + } + return {std::move(stage), std::move(outputs)}; } } // namespace mongo::stage_builder diff --git a/src/mongo/db/query/sbe_stage_builder_index_scan.h b/src/mongo/db/query/sbe_stage_builder_index_scan.h index 29949cc1169..cded81563ac 100644 --- a/src/mongo/db/query/sbe_stage_builder_index_scan.h +++ b/src/mongo/db/query/sbe_stage_builder_index_scan.h @@ -48,14 +48,10 @@ using IndexIntervals = std::vector<std::pair<std::unique_ptr<KeyString::Value>, std::unique_ptr<KeyString::Value>>>; /** - * This method generates an SBE plan stage tree implementing an index scan. It returns a tuple - * containing: (1) a slot produced by the index scan that holds the record ID ('recordIdSlot'); - * (2) a slot vector produced by the index scan which hold parts of the index key ('indexKeySlots'); - * and (3) the SBE plan stage tree. 'indexKeySlots' will only contain slots for the parts of the - * index key specified by the 'indexKeysToInclude' bitset. - * - * If the caller provides a slot ID for the 'returnKeySlot' parameter, this method will populate - * the specified slot with the rehydrated index key for each record. + * This method returns a pair containing: (1) an SBE plan stage tree implementing an index scan; + * and (2) a PlanStageSlots object containing a kRecordId slot, possibly some other kMeta slots, + * and slots produced by the index scan that correspond to parts of the index key specified by + * the 'indexKeyBitset' bitset. */ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScan( StageBuilderState& state, @@ -87,40 +83,12 @@ std::pair<sbe::value::TypeTags, sbe::value::Value> packIndexIntervalsInSbeArray( /** * Constructs a generic multi-interval index scan. 
Depending on the intervals will either execute - * the optimized or the generic index scan subplan. The generated subtree will have - * the following form: + * the optimized or the generic index scan subplan. * - * branch {isGenericScanSlot} [recordIdSlot, resultSlot, ...] - * then - * filter {isRecordId(resultSlot)} - * lspool sp1 [resultSlot] {!isRecordId(resultSlot)} - * union [resultSlot] - project [startKeySlot = anchorSlot, unusedVarSlot0 = Nothing, ...] - * limit 1 - * coscan - * [checkBoundsSlot] - * nlj [] [seekKeySlot] - * left - * sspool sp1 [seekKeySlot] - * right - * chkbounds resultSlot recordIdSlot checkBoundsSlot - * nlj [] [lowKeySlot] - * left - * project [lowKeySlot = seekKeySlot] - * limit 1 - * coscan - * right - * ixseek lowKeySlot resultSlot recordIdSlot [] @coll @index - * else - * nlj [] [lowKeySlot, highKeySlot] - * left - * project [lowKeySlot = getField (unwindSlot, "l"), - * highKeySlot = getField (unwindSlot, "h")] - * unwind unwindSlot indexSlot boundsSlot false - * limit 1 - * coscan - * right - * ixseek lowKeySlot highKeySlot recordIdSlot [] @coll @index + * This method returns a pair containing: (1) an SBE plan stage tree implementing a generic multi- + * interval index scan; and (2) a PlanStageSlots object containing a kRecordId slot, possibly some + * other kMeta slots, and slots produced by the index scan that correspond to parts of the index + * key specified by the 'indexKeyBitset' bitset. */ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> generateIndexScanWithDynamicBounds( StageBuilderState& state, diff --git a/src/mongo/db/query/sbe_stage_builder_lookup.cpp b/src/mongo/db/query/sbe_stage_builder_lookup.cpp index 775090c9cdd..b4c431f26f5 100644 --- a/src/mongo/db/query/sbe_stage_builder_lookup.cpp +++ b/src/mongo/db/query/sbe_stage_builder_lookup.cpp @@ -920,17 +920,21 @@ std::pair<SlotId, std::unique_ptr<sbe::PlanStage>> buildIndexJoinLookupStage( // stored in 'foreignRecordSlot'. We also pass in 'snapshotIdSlot', 'indexIdSlot', // 'indexKeySlot' and 'indexKeyPatternSlot' to perform index consistency check during the // seek. - auto [foreignRecordSlot, __, scanNljStage] = makeLoopJoinForFetch(std::move(ixScanNljStage), - foreignRecordIdSlot, - snapshotIdSlot, - indexIdSlot, - indexKeySlot, - indexKeyPatternSlot, - foreignColl, - iamMap, - nodeId, - makeSV() /* slotsToForward */, - slotIdGenerator); + auto foreignRecordSlot = slotIdGenerator.generate(); + auto scanNljStage = makeLoopJoinForFetch(std::move(ixScanNljStage), + foreignRecordSlot, + slotIdGenerator.generate() /* unused recordId slot */, + std::vector<std::string>{}, + makeSV(), + foreignRecordIdSlot, + snapshotIdSlot, + indexIdSlot, + indexKeySlot, + indexKeyPatternSlot, + foreignColl, + iamMap, + nodeId, + makeSV() /* slotsToForward */); // 'buildForeignMatches()' filters the foreign records, returned by the index scan, to match // those in 'localKeysSetSlot'. 
This is necessary because some values are encoded with the same diff --git a/src/mongo/db/repl/SConscript b/src/mongo/db/repl/SConscript index dd0aac54126..ad8be47af0a 100644 --- a/src/mongo/db/repl/SConscript +++ b/src/mongo/db/repl/SConscript @@ -1488,6 +1488,7 @@ env.Library( ], LIBDEPS=[ '$BUILD_DIR/mongo/client/fetcher', + '$BUILD_DIR/mongo/executor/async_rpc', 'primary_only_service', 'tenant_migration_access_blocker', 'tenant_migration_statistics', @@ -1717,6 +1718,7 @@ if wiredtiger: 'tenant_migration_recipient_access_blocker_test.cpp', 'tenant_migration_recipient_entry_helpers_test.cpp', 'tenant_migration_recipient_service_test.cpp', + 'tenant_migration_recipient_service_shard_merge_test.cpp', ], LIBDEPS=[ '$BUILD_DIR/mongo/bson/mutable/mutable_bson', diff --git a/src/mongo/db/repl/dbcheck.cpp b/src/mongo/db/repl/dbcheck.cpp index aa295ce9ca7..0b0be849a5d 100644 --- a/src/mongo/db/repl/dbcheck.cpp +++ b/src/mongo/db/repl/dbcheck.cpp @@ -364,14 +364,26 @@ Status dbCheckBatchOnSecondary(OperationContext* opCtx, // Set up the hasher, boost::optional<DbCheckHasher> hasher; try { - auto lockMode = MODE_S; - if (entry.getReadTimestamp()) { - lockMode = MODE_IS; - opCtx->recoveryUnit()->setTimestampReadSource(RecoveryUnit::ReadSource::kProvided, - entry.getReadTimestamp()); + // We may not have a read timestamp if the dbCheck command was run on an older version of + // the server with snapshotRead:false. Since we don't implement this feature, we'll log an + // error about skipping the batch to ensure an operator notices. + if (!entry.getReadTimestamp().has_value()) { + auto logEntry = + dbCheckErrorHealthLogEntry(entry.getNss(), + "dbCheck failed", + OplogEntriesEnum::Batch, + Status{ErrorCodes::Error(6769502), + "no readTimestamp in oplog entry. Ensure dbCheck " + "command is not using snapshotRead:false"}, + entry.toBSON()); + HealthLog::get(opCtx).log(*logEntry); + return Status::OK(); } - AutoGetCollection coll(opCtx, entry.getNss(), lockMode); + opCtx->recoveryUnit()->setTimestampReadSource(RecoveryUnit::ReadSource::kProvided, + entry.getReadTimestamp()); + + AutoGetCollection coll(opCtx, entry.getNss(), MODE_IS); const auto& collection = coll.getCollection(); if (!collection) { @@ -434,12 +446,11 @@ Status dbCheckOplogCommand(OperationContext* opCtx, if (!opCtx->writesAreReplicated()) { opTime = entry.getOpTime(); } - auto type = OplogEntries_parse(IDLParserContext("type"), cmd.getStringField("type")); - IDLParserContext ctx("o"); - + const auto type = OplogEntries_parse(IDLParserContext("type"), cmd.getStringField("type")); + const IDLParserContext ctx("o", false /*apiStrict*/, entry.getTid()); switch (type) { case OplogEntriesEnum::Batch: { - auto invocation = DbCheckOplogBatch::parse(ctx, cmd); + const auto invocation = DbCheckOplogBatch::parse(ctx, cmd); return dbCheckBatchOnSecondary(opCtx, opTime, invocation); } case OplogEntriesEnum::Collection: { diff --git a/src/mongo/db/repl/dbcheck.idl b/src/mongo/db/repl/dbcheck.idl index 9d3b20b68f5..5ded9ce6daf 100644 --- a/src/mongo/db/repl/dbcheck.idl +++ b/src/mongo/db/repl/dbcheck.idl @@ -38,42 +38,6 @@ imports: - "mongo/db/write_concern_options.idl" server_parameters: - dbCheckCollectionTryLockTimeoutMillis: - description: 'Timeout to acquire the collection for processing a dbCheck batch. 
Each subsequent attempt doubles the timeout' - set_at: [ startup, runtime ] - cpp_vartype: 'AtomicWord<int>' - cpp_varname: gDbCheckCollectionTryLockTimeoutMillis - default: 10 - validator: - gte: 1 - lte: 10000 - dbCheckCollectionTryLockMaxAttempts: - description: 'Maximum number of attempts with backoff to acquire the collection lock for processing a dbCheck batch' - set_at: [ startup, runtime ] - cpp_vartype: 'AtomicWord<int>' - cpp_varname: gDbCheckCollectionTryLockMaxAttempts - default: 5 - validator: - gte: 1 - lte: 20 - dbCheckCollectionTryLockMinBackoffMillis: - description: 'Initial backoff on failure to acquire the collection lock for processing a dbCheck batch. Grows exponentially' - set_at: [ startup, runtime ] - cpp_vartype: 'AtomicWord<int>' - cpp_varname: gDbCheckCollectionTryLockMinBackoffMillis - default: 10 - validator: - gte: 2 - lte: 60000 - dbCheckCollectionTryLockMaxBackoffMillis: - description: 'Maximum exponential backoff on failure to acquire the collection lock for processing a dbCheck batch.' - set_at: [ startup, runtime ] - cpp_vartype: 'AtomicWord<int>' - cpp_varname: gDbCheckCollectionTryLockMaxBackoffMillis - default: 60000 - validator: - gte: 20 - lte: 120000 dbCheckHealthLogEveryNBatches: description: 'Emit an info-severity health log batch every N batches processed' set_at: [ startup, runtime ] diff --git a/src/mongo/db/repl/noop_writer.cpp b/src/mongo/db/repl/noop_writer.cpp index 0b94f0dc60a..c34fc425030 100644 --- a/src/mongo/db/repl/noop_writer.cpp +++ b/src/mongo/db/repl/noop_writer.cpp @@ -137,9 +137,9 @@ Status NoopWriter::startWritingPeriodicNoops(OpTime lastKnownOpTime) { _noopRunner = std::make_unique<PeriodicNoopRunner>(_writeInterval, [this](OperationContext* opCtx) { // Noop writes are critical for the cluster stability, so we mark it as having Immediate - // priority. As a result it will skip both flow control and normal ticket acquisition. - SetTicketAquisitionPriorityForLock priority(opCtx, - AdmissionContext::Priority::kImmediate); + // priority. As a result it will skip both flow control and waiting for ticket + // acquisition. + SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate); _writeNoop(opCtx); }); return Status::OK(); diff --git a/src/mongo/db/repl/oplog_applier_impl.cpp b/src/mongo/db/repl/oplog_applier_impl.cpp index 6b1121b059a..f7f01d12988 100644 --- a/src/mongo/db/repl/oplog_applier_impl.cpp +++ b/src/mongo/db/repl/oplog_applier_impl.cpp @@ -317,9 +317,9 @@ void OplogApplierImpl::_run(OplogBuffer* oplogBuffer) { OperationContext& opCtx = *opCtxPtr; // The oplog applier is crucial for stability of the replica set. As a result we mark it as - // having Immediate priority. This makes the operation skip ticket acquisition and flow - // control. - SetTicketAquisitionPriorityForLock priority(&opCtx, AdmissionContext::Priority::kImmediate); + // having Immediate priority. This makes the operation skip waiting for ticket acquisition + // and flow control. + SetAdmissionPriorityForLock priority(&opCtx, AdmissionContext::Priority::kImmediate); // For pausing replication in tests. if (MONGO_unlikely(rsSyncApplyStop.shouldFail())) { @@ -432,9 +432,10 @@ void scheduleWritesToOplogAndChangeCollection(OperationContext* opCtx, auto opCtx = cc().makeOperationContext(); // Oplog writes are crucial to the stability of the replica set. We mark the operations - // as having Immediate priority so that it skips ticket acquisition and flow control. 
- SetTicketAquisitionPriorityForLock priority(opCtx.get(), - AdmissionContext::Priority::kImmediate); + // as having Immediate priority so that it skips waiting for ticket acquisition and flow + // control. + SetAdmissionPriorityForLock priority(opCtx.get(), + AdmissionContext::Priority::kImmediate); UnreplicatedWritesBlock uwb(opCtx.get()); ShouldNotConflictWithSecondaryBatchApplicationBlock shouldNotConflictBlock( @@ -589,10 +590,10 @@ StatusWith<OpTime> OplogApplierImpl::_applyOplogBatch(OperationContext* opCtx, auto opCtx = cc().makeOperationContext(); // Applying an Oplog batch is crucial to the stability of the Replica Set. We - // mark it as having Immediate priority so that it skips ticket acquisition and - // flow control. - SetTicketAquisitionPriorityForLock priority( - opCtx.get(), AdmissionContext::Priority::kImmediate); + // mark it as having Immediate priority so that it skips waiting for ticket + // acquisition and flow control. + SetAdmissionPriorityForLock priority(opCtx.get(), + AdmissionContext::Priority::kImmediate); opCtx->setEnforceConstraints(false); diff --git a/src/mongo/db/repl/repl_set_request_votes.cpp b/src/mongo/db/repl/repl_set_request_votes.cpp index f80027813f9..5f07c9e4d98 100644 --- a/src/mongo/db/repl/repl_set_request_votes.cpp +++ b/src/mongo/db/repl/repl_set_request_votes.cpp @@ -61,9 +61,9 @@ private: uassertStatusOK(status); // Operations that are part of Replica Set elections are crucial to the stability of the - // cluster. Marking it as having Immediate priority will make it skip ticket acquisition and - // Flow Control. - SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate); + // cluster. Marking it as having Immediate priority will make it skip waiting for ticket + // acquisition and Flow Control. + SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate); ReplSetRequestVotesResponse response; status = ReplicationCoordinator::get(opCtx)->processReplSetRequestVotes( opCtx, parsedArgs, &response); diff --git a/src/mongo/db/repl/replication_consistency_markers_impl.cpp b/src/mongo/db/repl/replication_consistency_markers_impl.cpp index b06238cb0d9..1e44889baf1 100644 --- a/src/mongo/db/repl/replication_consistency_markers_impl.cpp +++ b/src/mongo/db/repl/replication_consistency_markers_impl.cpp @@ -418,12 +418,11 @@ ReplicationConsistencyMarkersImpl::refreshOplogTruncateAfterPointIfPrimary( } ON_BLOCK_EXIT([&] { opCtx->recoveryUnit()->setPrepareConflictBehavior(originalBehavior); }); - // Exempt storage ticket acquisition in order to avoid starving upstream requests waiting - // for durability. SERVER-60682 is an example with more pending prepared transactions than - // storage tickets; the transaction coordinator could not persist the decision and - // had to unnecessarily wait for prepared transactions to expire to make forward progress. - SetTicketAquisitionPriorityForLock setTicketAquisition(opCtx, - AdmissionContext::Priority::kImmediate); + // Exempt waiting for storage ticket acquisition in order to avoid starving upstream requests + // waiting for durability. SERVER-60682 is an example with more pending prepared transactions + // than storage tickets; the transaction coordinator could not persist the decision and had to + // unnecessarily wait for prepared transactions to expire to make forward progress. 
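
The changes in this area all follow the same shape: a scoped guard marks the operation as having Immediate admission priority so it does not wait on storage tickets or flow control while it holds that scope. A minimal, hypothetical sketch of that RAII shape (stand-in names, not the real SetAdmissionPriorityForLock or AdmissionContext classes), in which the guard saves and restores the previous priority:

#include <cassert>

enum class Priority { kNormal, kImmediate };

struct AdmissionCtx {
    Priority priority = Priority::kNormal;
};

class ScopedAdmissionPriority {
public:
    ScopedAdmissionPriority(AdmissionCtx* ctx, Priority p)
        : _ctx(ctx), _previous(ctx->priority) {
        _ctx->priority = p;
    }
    ~ScopedAdmissionPriority() {
        _ctx->priority = _previous;  // Restore whatever was in effect before this scope.
    }
    ScopedAdmissionPriority(const ScopedAdmissionPriority&) = delete;
    ScopedAdmissionPriority& operator=(const ScopedAdmissionPriority&) = delete;

private:
    AdmissionCtx* _ctx;
    Priority _previous;
};

int main() {
    AdmissionCtx ctx;
    {
        // E.g. around an oplog write or election vote that must not wait for
        // storage tickets or flow control.
        ScopedAdmissionPriority guard(&ctx, Priority::kImmediate);
        assert(ctx.priority == Priority::kImmediate);
    }
    assert(ctx.priority == Priority::kNormal);
    return 0;
}
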
+ SetAdmissionPriorityForLock setTicketAquisition(opCtx, AdmissionContext::Priority::kImmediate); // The locks necessary to write to the oplog truncate after point's collection and read from the // oplog collection must be taken up front so that the mutex can also be taken around both diff --git a/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp b/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp index eea20ed70ee..8fc86f695cf 100644 --- a/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp +++ b/src/mongo/db/repl/replication_coordinator_external_state_impl.cpp @@ -855,7 +855,7 @@ void ReplicationCoordinatorExternalStateImpl::_stopAsyncUpdatesOfAndClearOplogTr // As opCtx does not expose a method to allow skipping flow control on purpose we mark the // operation as having Immediate priority. This will skip flow control and ticket acquisition. // It is fine to do this since the system is essentially shutting down at this point. - SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate); + SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate); // Tell the system to stop updating the oplogTruncateAfterPoint asynchronously and to go // back to using last applied to update repl's durable timestamp instead of the truncate diff --git a/src/mongo/db/repl/replication_coordinator_impl_elect_v1.cpp b/src/mongo/db/repl/replication_coordinator_impl_elect_v1.cpp index e50508e9f35..57959db018b 100644 --- a/src/mongo/db/repl/replication_coordinator_impl_elect_v1.cpp +++ b/src/mongo/db/repl/replication_coordinator_impl_elect_v1.cpp @@ -339,8 +339,7 @@ void ReplicationCoordinatorImpl::ElectionState::_writeLastVoteForMyElection( // Any operation that occurs as part of an election process is critical to the operation of // the cluster. We mark the operation as having Immediate priority to skip ticket // acquisition and flow control. - SetTicketAquisitionPriorityForLock priority(opCtx.get(), - AdmissionContext::Priority::kImmediate); + SetAdmissionPriorityForLock priority(opCtx.get(), AdmissionContext::Priority::kImmediate); LOGV2(6015300, "Storing last vote document in local storage for my election", diff --git a/src/mongo/db/repl/replication_coordinator_test_fixture.cpp b/src/mongo/db/repl/replication_coordinator_test_fixture.cpp index cd6d0e60381..7fae176024f 100644 --- a/src/mongo/db/repl/replication_coordinator_test_fixture.cpp +++ b/src/mongo/db/repl/replication_coordinator_test_fixture.cpp @@ -438,7 +438,7 @@ void ReplCoordTest::simulateSuccessfulV1ElectionAt(Date_t electionTime) { void ReplCoordTest::signalDrainComplete(OperationContext* opCtx) noexcept { // Writes that occur in code paths that call signalDrainComplete are expected to have Immediate // priority. 
- SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate); + SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kImmediate); getExternalState()->setFirstOpTimeOfMyTerm(OpTime(Timestamp(1, 1), getReplCoord()->getTerm())); getReplCoord()->signalDrainComplete(opCtx, getReplCoord()->getTerm()); } diff --git a/src/mongo/db/repl/replication_info.cpp b/src/mongo/db/repl/replication_info.cpp index ed400698279..020cd913348 100644 --- a/src/mongo/db/repl/replication_info.cpp +++ b/src/mongo/db/repl/replication_info.cpp @@ -316,6 +316,10 @@ public: return true; } + bool allowedWithSecurityToken() const final { + return true; + } + AllowedOnSecondary secondaryAllowed(ServiceContext*) const final { return AllowedOnSecondary::kAlways; } diff --git a/src/mongo/db/repl/storage_interface_impl.cpp b/src/mongo/db/repl/storage_interface_impl.cpp index 24cd1419970..b1697942c24 100644 --- a/src/mongo/db/repl/storage_interface_impl.cpp +++ b/src/mongo/db/repl/storage_interface_impl.cpp @@ -1483,9 +1483,8 @@ Status StorageInterfaceImpl::isAdminDbValid(OperationContext* opCtx) { void StorageInterfaceImpl::waitForAllEarlierOplogWritesToBeVisible(OperationContext* opCtx, bool primaryOnly) { // Waiting for oplog writes to be visible in the oplog does not use any storage engine resources - // and must skip ticket acquisition to avoid deadlocks with updating oplog visibility. - SetTicketAquisitionPriorityForLock setTicketAquisition(opCtx, - AdmissionContext::Priority::kImmediate); + // and must not wait for ticket acquisition to avoid deadlocks with updating oplog visibility. + SetAdmissionPriorityForLock setTicketAquisition(opCtx, AdmissionContext::Priority::kImmediate); AutoGetOplog oplogRead(opCtx, OplogAccessMode::kRead); if (primaryOnly && @@ -1501,8 +1500,7 @@ void StorageInterfaceImpl::oplogDiskLocRegister(OperationContext* opCtx, bool orderedCommit) { // Setting the oplog visibility does not use any storage engine resources and must skip ticket // acquisition to avoid deadlocks with updating oplog visibility. 
- SetTicketAquisitionPriorityForLock setTicketAquisition(opCtx, - AdmissionContext::Priority::kImmediate); + SetAdmissionPriorityForLock setTicketAquisition(opCtx, AdmissionContext::Priority::kImmediate); AutoGetOplog oplogRead(opCtx, OplogAccessMode::kRead); fassert(28557, diff --git a/src/mongo/db/repl/tenant_migration_donor_service.cpp b/src/mongo/db/repl/tenant_migration_donor_service.cpp index 9470ac2b036..920538fadd6 100644 --- a/src/mongo/db/repl/tenant_migration_donor_service.cpp +++ b/src/mongo/db/repl/tenant_migration_donor_service.cpp @@ -30,6 +30,7 @@ #include "mongo/db/repl/tenant_migration_donor_service.h" +#include "mongo/client/async_remote_command_targeter_adapter.h" #include "mongo/client/connection_string.h" #include "mongo/client/replica_set_monitor.h" #include "mongo/config.h" @@ -49,6 +50,8 @@ #include "mongo/db/repl/tenant_migration_state_machine_gen.h" #include "mongo/db/repl/tenant_migration_statistics.h" #include "mongo/db/repl/wait_for_majority_service.h" +#include "mongo/executor/async_rpc.h" +#include "mongo/executor/async_rpc_retry_policy.h" #include "mongo/executor/connection_pool.h" #include "mongo/executor/network_interface_factory.h" #include "mongo/logv2/log.h" @@ -91,25 +94,50 @@ const ReadPreferenceSetting kPrimaryOnlyReadPreference(ReadPreference::PrimaryOn const int kMaxRecipientKeyDocsFindAttempts = 10; -bool shouldStopSendingRecipientForgetMigrationCommand(Status status) { - return status.isOK() || - !(ErrorCodes::isRetriableError(status) || ErrorCodes::isNetworkTimeoutError(status) || - // Returned if findHost() is unable to target the recipient in 15 seconds, which may - // happen after a failover. - status == ErrorCodes::FailedToSatisfyReadPreference || - ErrorCodes::isInterruption(status)); -} +/** + * Encapsulates the retry logic for sending the ForgetMigration command. + */ +class RecipientForgetMigrationRetryPolicy + : public async_rpc::RetryWithBackoffOnErrorCategories<ErrorCategory::RetriableError, + ErrorCategory::NetworkTimeoutError, + ErrorCategory::Interruption> { +public: + using RetryWithBackoffOnErrorCategories::RetryWithBackoffOnErrorCategories; + bool recordAndEvaluateRetry(Status status) override { + if (status.isOK()) { + return false; + } + auto underlyingError = async_rpc::unpackRPCStatusIgnoringWriteConcernAndWriteErrors(status); + // Returned if findHost() is unable to target the recipient in 15 seconds, which may + // happen after a failover. + return RetryWithBackoffOnErrorCategories::recordAndEvaluateRetry(underlyingError) || + underlyingError == ErrorCodes::FailedToSatisfyReadPreference; + } +}; -bool shouldStopSendingRecipientSyncDataCommand(Status status, MigrationProtocolEnum protocol) { - if (status.isOK() || protocol == MigrationProtocolEnum::kShardMerge) { - return true; +/** + * Encapsulates the retry logic for sending the SyncData command. 
+ */ +class RecipientSyncDataRetryPolicy + : public async_rpc::RetryWithBackoffOnErrorCategories<ErrorCategory::RetriableError, + ErrorCategory::NetworkTimeoutError> { +public: + RecipientSyncDataRetryPolicy(MigrationProtocolEnum p, Backoff b) + : RetryWithBackoffOnErrorCategories(b), _protocol{p} {} + + /** Returns true if we should retry sending SyncData given the error */ + bool recordAndEvaluateRetry(Status status) { + if (_protocol == MigrationProtocolEnum::kShardMerge || status.isOK()) { + return false; + } + auto underlyingError = async_rpc::unpackRPCStatusIgnoringWriteConcernAndWriteErrors(status); + return RetryWithBackoffOnErrorCategories::recordAndEvaluateRetry(status) || + underlyingError == ErrorCodes::FailedToSatisfyReadPreference; } - return !(ErrorCodes::isRetriableError(status) || ErrorCodes::isNetworkTimeoutError(status) || - // Returned if findHost() is unable to target the recipient in 15 seconds, which may - // happen after a failover. - status == ErrorCodes::FailedToSatisfyReadPreference); -} +private: + MigrationProtocolEnum _protocol; +}; bool shouldStopFetchingRecipientClusterTimeKeyDocs(Status status) { return status.isOK() || @@ -744,90 +772,48 @@ ExecutorFuture<void> TenantMigrationDonorService::Instance::_waitForMajorityWrit }); } -ExecutorFuture<void> TenantMigrationDonorService::Instance::_sendCommandToRecipient( - std::shared_ptr<executor::ScopedTaskExecutor> executor, - std::shared_ptr<RemoteCommandTargeter> recipientTargeterRS, - const BSONObj& cmdObj, - const CancellationToken& token) { - const bool isRecipientSyncDataCmd = cmdObj.hasField(RecipientSyncData::kCommandName); - return AsyncTry( - [this, self = shared_from_this(), executor, recipientTargeterRS, cmdObj, token] { - return recipientTargeterRS->findHost(kPrimaryOnlyReadPreference, token) - .thenRunOn(**executor) - .then([this, self = shared_from_this(), executor, cmdObj, token]( - auto recipientHost) { - executor::RemoteCommandRequest request( - std::move(recipientHost), - NamespaceString::kAdminDb.toString(), - std::move(cmdObj), - rpc::makeEmptyMetadata(), - nullptr); - request.sslMode = _sslMode; - - return (_recipientCmdExecutor) - ->scheduleRemoteCommand(std::move(request), token) - .then([this, - self = shared_from_this()](const auto& response) -> Status { - if (!response.isOK()) { - return response.status; - } - auto commandStatus = getStatusFromCommandResult(response.data); - commandStatus.addContext( - "Tenant migration recipient command failed"); - return commandStatus; - }); - }); - }) - .until([this, self = shared_from_this(), token, cmdObj, isRecipientSyncDataCmd]( - Status status) { - if (isRecipientSyncDataCmd) { - return shouldStopSendingRecipientSyncDataCommand(status, getProtocol()); - } else { - // If the recipient command is not 'recipientSyncData', it must be - // 'recipientForgetMigration'. 
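
The replacement policies above fold the old shouldStopSending...Command checks into async_rpc retry policies: retry on retriable and network-timeout categories (plus interruption for recipientForgetMigration), treat FailedToSatisfyReadPreference as retryable because findHost() can briefly fail to target the recipient after a failover, and back off exponentially between attempts. A simplified, self-contained sketch of that decision logic (stand-in Status and category types, not the real async_rpc interfaces, which additionally unpack write-concern and write errors first):

#include <cassert>
#include <chrono>

enum class ErrCategory { kNone, kRetriable, kNetworkTimeout, kInterruption };

struct Status {
    bool ok = true;
    ErrCategory category = ErrCategory::kNone;
    bool failedToSatisfyReadPreference = false;
};

class SimpleRetryPolicy {
public:
    explicit SimpleRetryPolicy(std::chrono::milliseconds base) : _nextBackoff(base) {}

    // Returns true if the command should be sent again, doubling the backoff
    // each time a retryable error is recorded.
    bool recordAndEvaluateRetry(const Status& s) {
        if (s.ok)
            return false;
        const bool retryable = s.category == ErrCategory::kRetriable ||
            s.category == ErrCategory::kNetworkTimeout ||
            // findHost() may be unable to target the recipient right after a failover.
            s.failedToSatisfyReadPreference;
        if (retryable)
            _nextBackoff *= 2;
        return retryable;
    }

    std::chrono::milliseconds getNextRetryDelay() const {
        return _nextBackoff;
    }

private:
    std::chrono::milliseconds _nextBackoff;
};

int main() {
    SimpleRetryPolicy policy{std::chrono::milliseconds{100}};
    assert(!policy.recordAndEvaluateRetry(Status{}));                                    // success: stop
    assert(policy.recordAndEvaluateRetry(Status{false, ErrCategory::kNetworkTimeout}));  // retry
    assert(policy.getNextRetryDelay() == std::chrono::milliseconds{200});
    return 0;
}
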
- invariant(cmdObj.hasField(RecipientForgetMigration::kCommandName)); - return shouldStopSendingRecipientForgetMigrationCommand(status); - } - }) - .withBackoffBetweenIterations(kExponentialBackoff) - .on(**executor, token); -} - ExecutorFuture<void> TenantMigrationDonorService::Instance::_sendRecipientSyncDataCommand( - std::shared_ptr<executor::ScopedTaskExecutor> executor, + std::shared_ptr<executor::ScopedTaskExecutor> exec, std::shared_ptr<RemoteCommandTargeter> recipientTargeterRS, const CancellationToken& token) { + auto donorConnString = + repl::ReplicationCoordinator::get(_serviceContext)->getConfigConnectionString(); - const auto cmdObj = [&] { - auto donorConnString = - repl::ReplicationCoordinator::get(_serviceContext)->getConfigConnectionString(); - - RecipientSyncData request; - request.setDbName(NamespaceString::kAdminDb); + RecipientSyncData request; + request.setDbName(NamespaceString::kAdminDb); - MigrationRecipientCommonData commonData( - _migrationUuid, donorConnString.toString(), _readPreference); - commonData.setRecipientCertificateForDonor(_recipientCertificateForDonor); - if (_protocol == MigrationProtocolEnum::kMultitenantMigrations) { - commonData.setTenantId(boost::optional<StringData>(_tenantId)); - } + MigrationRecipientCommonData commonData( + _migrationUuid, donorConnString.toString(), _readPreference); + commonData.setRecipientCertificateForDonor(_recipientCertificateForDonor); + if (_protocol == MigrationProtocolEnum::kMultitenantMigrations) { + commonData.setTenantId(boost::optional<StringData>(_tenantId)); + } - stdx::lock_guard<Latch> lg(_mutex); - commonData.setProtocol(_protocol); - request.setMigrationRecipientCommonData(commonData); + commonData.setProtocol(_protocol); + request.setMigrationRecipientCommonData(commonData); + { + stdx::lock_guard<Latch> lg(_mutex); invariant(_stateDoc.getStartMigrationDonorTimestamp()); request.setStartMigrationDonorTimestamp(*_stateDoc.getStartMigrationDonorTimestamp()); request.setReturnAfterReachingDonorTimestamp(_stateDoc.getBlockTimestamp()); - return request.toBSON(BSONObj()); - }(); + } - return _sendCommandToRecipient(executor, recipientTargeterRS, cmdObj, token); + auto asyncTargeter = std::make_unique<async_rpc::AsyncRemoteCommandTargeterAdapter>( + kPrimaryOnlyReadPreference, recipientTargeterRS); + auto retryPolicy = + std::make_shared<RecipientSyncDataRetryPolicy>(getProtocol(), kExponentialBackoff); + auto cmdRes = async_rpc::sendCommand( + request, _serviceContext, std::move(asyncTargeter), **exec, token, retryPolicy); + return std::move(cmdRes).ignoreValue().onError([](Status status) { + return async_rpc::unpackRPCStatusIgnoringWriteConcernAndWriteErrors(status).addContext( + "Tenant migration recipient command failed"); + }); } ExecutorFuture<void> TenantMigrationDonorService::Instance::_sendRecipientForgetMigrationCommand( - std::shared_ptr<executor::ScopedTaskExecutor> executor, + std::shared_ptr<executor::ScopedTaskExecutor> exec, std::shared_ptr<RemoteCommandTargeter> recipientTargeterRS, const CancellationToken& token) { @@ -847,7 +833,15 @@ ExecutorFuture<void> TenantMigrationDonorService::Instance::_sendRecipientForget commonData.setProtocol(_protocol); request.setMigrationRecipientCommonData(commonData); - return _sendCommandToRecipient(executor, recipientTargeterRS, request.toBSON(BSONObj()), token); + auto asyncTargeter = std::make_unique<async_rpc::AsyncRemoteCommandTargeterAdapter>( + kPrimaryOnlyReadPreference, recipientTargeterRS); + auto retryPolicy = 
std::make_shared<RecipientForgetMigrationRetryPolicy>(kExponentialBackoff); + auto cmdRes = async_rpc::sendCommand( + request, _serviceContext, std::move(asyncTargeter), **exec, token, retryPolicy); + return std::move(cmdRes).ignoreValue().onError([](Status status) { + return async_rpc::unpackRPCStatusIgnoringWriteConcernAndWriteErrors(status).addContext( + "Tenant migration recipient command failed"); + }); } CancellationToken TenantMigrationDonorService::Instance::_initAbortMigrationSource( diff --git a/src/mongo/db/repl/tenant_migration_donor_service.h b/src/mongo/db/repl/tenant_migration_donor_service.h index c9f28ddcc07..f8b0db86e17 100644 --- a/src/mongo/db/repl/tenant_migration_donor_service.h +++ b/src/mongo/db/repl/tenant_migration_donor_service.h @@ -255,15 +255,6 @@ public: const CancellationToken& token); /** - * Sends the given command to the recipient replica set. - */ - ExecutorFuture<void> _sendCommandToRecipient( - std::shared_ptr<executor::ScopedTaskExecutor> executor, - std::shared_ptr<RemoteCommandTargeter> recipientTargeterRS, - const BSONObj& cmdObj, - const CancellationToken& token); - - /** * Sends the recipientSyncData command to the recipient replica set. */ ExecutorFuture<void> _sendRecipientSyncDataCommand( diff --git a/src/mongo/db/repl/tenant_migration_recipient_service.cpp b/src/mongo/db/repl/tenant_migration_recipient_service.cpp index 57ec62819a1..c6c4d4d2958 100644 --- a/src/mongo/db/repl/tenant_migration_recipient_service.cpp +++ b/src/mongo/db/repl/tenant_migration_recipient_service.cpp @@ -76,6 +76,7 @@ #include "mongo/db/transaction/transaction_participant.h" #include "mongo/db/vector_clock_mutable.h" #include "mongo/db/write_concern_options.h" +#include "mongo/executor/task_executor.h" #include "mongo/logv2/log.h" #include "mongo/rpc/get_status_from_command_result.h" #include "mongo/util/assert_util.h" @@ -93,6 +94,8 @@ const std::string kTTLIndexName = "TenantMigrationRecipientTTLIndex"; const Backoff kExponentialBackoff(Seconds(1), Milliseconds::max()); constexpr StringData kOplogBufferPrefix = "repl.migration.oplog_"_sd; constexpr int kBackupCursorFileFetcherRetryAttempts = 10; +constexpr int kCheckpointTsBackupCursorErrorCode = 6929900; +constexpr int kCloseCursorBeforeOpenErrorCode = 50886; NamespaceString getOplogBufferNs(const UUID& migrationUUID) { return NamespaceString(NamespaceString::kConfigDb, @@ -979,25 +982,8 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_killBackupCursor() nullptr); request.sslMode = _donorUri.getSSLMode(); - auto scheduleResult = - (_recipientService->getInstanceCleanupExecutor()) - ->scheduleRemoteCommand( - request, [](const executor::TaskExecutor::RemoteCommandCallbackArgs& args) { - if (!args.response.isOK()) { - LOGV2_WARNING(6113005, - "killCursors command task failed", - "error"_attr = redact(args.response.status)); - return; - } - auto status = getStatusFromCommandResult(args.response.data); - if (status.isOK()) { - LOGV2_INFO(6113415, "Killed backup cursor"); - } else { - LOGV2_WARNING(6113006, - "killCursors command failed", - "error"_attr = redact(status)); - } - }); + const auto scheduleResult = _scheduleKillBackupCursorWithLock( + lk, _recipientService->getInstanceCleanupExecutor()); if (!scheduleResult.isOK()) { LOGV2_WARNING(6113004, "Failed to run killCursors command on backup cursor", @@ -1009,13 +995,8 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_killBackupCursor() SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor( const 
CancellationToken& token) { - stdx::lock_guard lk(_mutex); - LOGV2_DEBUG(6113000, - 1, - "Trying to open backup cursor on donor primary", - "migrationId"_attr = _stateDoc.getId(), - "donorConnectionString"_attr = _stateDoc.getDonorConnectionString()); - const auto cmdObj = [] { + + const auto aggregateCommandRequestObj = [] { AggregateCommandRequest aggRequest( NamespaceString::makeCollectionlessAggregateNSS(NamespaceString::kAdminDb), {BSON("$backupCursor" << BSONObj())}); @@ -1024,11 +1005,18 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor( return aggRequest.toBSON(BSONObj()); }(); - auto startMigrationDonorTimestamp = _stateDoc.getStartMigrationDonorTimestamp(); + stdx::lock_guard lk(_mutex); + LOGV2_DEBUG(6113000, + 1, + "Trying to open backup cursor on donor primary", + "migrationId"_attr = _stateDoc.getId(), + "donorConnectionString"_attr = _stateDoc.getDonorConnectionString()); + + const auto startMigrationDonorTimestamp = _stateDoc.getStartMigrationDonorTimestamp(); auto fetchStatus = std::make_shared<boost::optional<Status>>(); auto uniqueMetadataInfo = std::make_unique<boost::optional<shard_merge_utils::MetadataInfo>>(); - auto fetcherCallback = + const auto fetcherCallback = [ this, self = shared_from_this(), @@ -1043,8 +1031,8 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor( uassertStatusOK(dataStatus); uassert(ErrorCodes::CallbackCanceled, "backup cursor interrupted", !token.isCanceled()); - auto uniqueOpCtx = cc().makeOperationContext(); - auto opCtx = uniqueOpCtx.get(); + const auto uniqueOpCtx = cc().makeOperationContext(); + const auto opCtx = uniqueOpCtx.get(); const auto& data = dataStatus.getValue(); for (const BSONObj& doc : data.documents) { @@ -1059,14 +1047,6 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor( "backupCursorId"_attr = data.cursorId, "backupCursorCheckpointTimestamp"_attr = checkpointTimestamp); - // This ensures that the recipient won’t receive any 2 phase index build donor - // oplog entries during the migration. We also have a check in the tenant oplog - // applier to detect such oplog entries. Adding a check here helps us to detect - // the problem earlier. - uassert(6929900, - "backupCursorCheckpointTimestamp should be greater than or equal to " - "startMigrationDonorTimestamp", - checkpointTimestamp >= startMigrationDonorTimestamp); { stdx::lock_guard lk(_mutex); stdx::lock_guard<TenantMigrationSharedData> sharedDatalk(*_sharedData); @@ -1075,6 +1055,15 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor( BackupCursorInfo{data.cursorId, data.nss, checkpointTimestamp}); } + // This ensures that the recipient won’t receive any 2 phase index build donor + // oplog entries during the migration. We also have a check in the tenant oplog + // applier to detect such oplog entries. Adding a check here helps us to detect + // the problem earlier. 
+ uassert(kCheckpointTsBackupCursorErrorCode, + "backupCursorCheckpointTimestamp should be greater than or equal to " + "startMigrationDonorTimestamp", + checkpointTimestamp >= startMigrationDonorTimestamp); + invariant(metadataInfoPtr && !*metadataInfoPtr); (*metadataInfoPtr) = shard_merge_utils::MetadataInfo::constructMetadataInfo( getMigrationUUID(), _client->getServerAddress(), metadata); @@ -1133,10 +1122,10 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor( }; _donorFilenameBackupCursorFileFetcher = std::make_unique<Fetcher>( - (**_scopedExecutor).get(), + _backupCursorExecutor.get(), _client->getServerHostAndPort(), NamespaceString::kAdminDb.toString(), - cmdObj, + aggregateCommandRequestObj, fetcherCallback, ReadPreferenceSetting(ReadPreference::PrimaryPreferred).toContainingBSON(), executor::RemoteCommandRequest::kNoTimeout, /* aggregateTimeout */ @@ -1160,6 +1149,35 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursor( .semi(); } +StatusWith<executor::TaskExecutor::CallbackHandle> +TenantMigrationRecipientService::Instance::_scheduleKillBackupCursorWithLock( + WithLock lk, std::shared_ptr<executor::TaskExecutor> executor) { + auto& donorBackupCursorInfo = _getDonorBackupCursorInfo(lk); + executor::RemoteCommandRequest killCursorsRequest( + _client->getServerHostAndPort(), + donorBackupCursorInfo.nss.db().toString(), + BSON("killCursors" << donorBackupCursorInfo.nss.coll().toString() << "cursors" + << BSON_ARRAY(donorBackupCursorInfo.cursorId)), + nullptr); + killCursorsRequest.sslMode = _donorUri.getSSLMode(); + + return executor->scheduleRemoteCommand( + killCursorsRequest, [](const executor::TaskExecutor::RemoteCommandCallbackArgs& args) { + if (!args.response.isOK()) { + LOGV2_WARNING(6113005, + "killCursors command task failed", + "error"_attr = redact(args.response.status)); + return; + } + auto status = getStatusFromCommandResult(args.response.data); + if (status.isOK()) { + LOGV2_INFO(6113415, "Killed backup cursor"); + } else { + LOGV2_WARNING(6113006, "killCursors command failed", "error"_attr = redact(status)); + } + }); +} + SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursorWithRetry( const CancellationToken& token) { return AsyncTry([this, self = shared_from_this(), token] { return _openBackupCursor(token); }) @@ -1169,8 +1187,20 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::_openBackupCursorWit "Retrying backup cursor creation after transient error", "migrationId"_attr = getMigrationUUID(), "status"_attr = status); - // A checkpoint took place while opening a backup cursor. We - // should retry and *not* cancel migration. 
+ + return false; + } else if (status.code() == kCheckpointTsBackupCursorErrorCode || + status.code() == kCloseCursorBeforeOpenErrorCode) { + LOGV2_INFO(6955100, + "Closing backup cursor and retrying after getting retryable error", + "migrationId"_attr = getMigrationUUID(), + "status"_attr = status); + + stdx::lock_guard lk(_mutex); + const auto scheduleResult = + _scheduleKillBackupCursorWithLock(lk, _backupCursorExecutor); + uassertStatusOK(scheduleResult); + return false; } @@ -1199,7 +1229,7 @@ void TenantMigrationRecipientService::Instance::_keepBackupCursorAlive( auto& donorBackupCursorInfo = _getDonorBackupCursorInfo(lk); _backupCursorKeepAliveFuture = shard_merge_utils::keepBackupCursorAlive(_backupCursorKeepAliveCancellation, - **_scopedExecutor, + _backupCursorExecutor, _client->getServerHostAndPort(), donorBackupCursorInfo.cursorId, donorBackupCursorInfo.nss); @@ -2703,6 +2733,7 @@ SemiFuture<void> TenantMigrationRecipientService::Instance::run( std::shared_ptr<executor::ScopedTaskExecutor> executor, const CancellationToken& token) noexcept { _scopedExecutor = executor; + _backupCursorExecutor = **_scopedExecutor; auto scopedOutstandingMigrationCounter = TenantMigrationStatistics::get(_serviceContext)->getScopedOutstandingReceivingCount(); diff --git a/src/mongo/db/repl/tenant_migration_recipient_service.h b/src/mongo/db/repl/tenant_migration_recipient_service.h index 22de9a00fd1..a71f29139d9 100644 --- a/src/mongo/db/repl/tenant_migration_recipient_service.h +++ b/src/mongo/db/repl/tenant_migration_recipient_service.h @@ -213,6 +213,15 @@ public: private: friend class TenantMigrationRecipientServiceTest; + friend class TenantMigrationRecipientServiceShardMergeTest; + + /** + * Only used for testing. Allows setting a custom task executor for backup cursor fetcher. + */ + void setBackupCursorFetcherExecutor_forTest( + std::shared_ptr<executor::TaskExecutor> taskExecutor) { + _backupCursorExecutor = taskExecutor; + } const NamespaceString _stateDocumentsNS = NamespaceString::kTenantMigrationRecipientsNamespace; @@ -605,6 +614,13 @@ public: SemiFuture<TenantOplogApplier::OpTimePair> _migrateUsingShardMergeProtocol( const CancellationToken& token); + /* + * Send the killBackupCursor command to the remote in order to close the backup cursor + * connection on the donor. + */ + StatusWith<executor::TaskExecutor::CallbackHandle> _scheduleKillBackupCursorWithLock( + WithLock lk, std::shared_ptr<executor::TaskExecutor> executor); + mutable Mutex _mutex = MONGO_MAKE_LATCH("TenantMigrationRecipientService::_mutex"); // All member variables are labeled with one of the following codes indicating the @@ -618,6 +634,7 @@ public: ServiceContext* const _serviceContext; const TenantMigrationRecipientService* const _recipientService; // (R) (not owned) std::shared_ptr<executor::ScopedTaskExecutor> _scopedExecutor; // (M) + std::shared_ptr<executor::TaskExecutor> _backupCursorExecutor; // (M) TenantMigrationRecipientDocument _stateDoc; // (M) // This data is provided in the initial state doc and never changes. We keep copies to diff --git a/src/mongo/db/repl/tenant_migration_recipient_service_shard_merge_test.cpp b/src/mongo/db/repl/tenant_migration_recipient_service_shard_merge_test.cpp new file mode 100644 index 00000000000..be6f0bbc799 --- /dev/null +++ b/src/mongo/db/repl/tenant_migration_recipient_service_shard_merge_test.cpp @@ -0,0 +1,593 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. 
+ * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include <fstream> +#include <memory> + +#include "mongo/client/connpool.h" +#include "mongo/client/replica_set_monitor.h" +#include "mongo/client/replica_set_monitor_protocol_test_util.h" +#include "mongo/client/streamable_replica_set_monitor_for_testing.h" +#include "mongo/config.h" +#include "mongo/db/client.h" +#include "mongo/db/db_raii.h" +#include "mongo/db/dbdirectclient.h" +#include "mongo/db/feature_compatibility_version_document_gen.h" +#include "mongo/db/op_observer/op_observer_impl.h" +#include "mongo/db/op_observer/op_observer_registry.h" +#include "mongo/db/op_observer/oplog_writer_impl.h" +#include "mongo/db/repl/drop_pending_collection_reaper.h" +#include "mongo/db/repl/oplog.h" +#include "mongo/db/repl/oplog_buffer_collection.h" +#include "mongo/db/repl/oplog_fetcher_mock.h" +#include "mongo/db/repl/primary_only_service.h" +#include "mongo/db/repl/primary_only_service_op_observer.h" +#include "mongo/db/repl/replication_coordinator_mock.h" +#include "mongo/db/repl/storage_interface_impl.h" +#include "mongo/db/repl/tenant_migration_recipient_entry_helpers.h" +#include "mongo/db/repl/tenant_migration_recipient_service.h" +#include "mongo/db/repl/tenant_migration_state_machine_gen.h" +#include "mongo/db/repl/wait_for_majority_service.h" +#include "mongo/db/service_context_d_test_fixture.h" +#include "mongo/db/session/session_txn_record_gen.h" +#include "mongo/db/storage/backup_cursor_hooks.h" +#include "mongo/dbtests/mock/mock_conn_registry.h" +#include "mongo/dbtests/mock/mock_replica_set.h" +#include "mongo/executor/mock_network_fixture.h" +#include "mongo/executor/network_interface.h" +#include "mongo/executor/network_interface_mock.h" +#include "mongo/executor/thread_pool_mock.h" +#include "mongo/executor/thread_pool_task_executor.h" +#include "mongo/executor/thread_pool_task_executor_test_fixture.h" +#include "mongo/idl/server_parameter_test_util.h" +#include "mongo/logv2/log.h" +#include "mongo/rpc/metadata/egress_metadata_hook_list.h" +#include "mongo/transport/transport_layer_manager.h" +#include "mongo/transport/transport_layer_mock.h" +#include "mongo/unittest/log_test.h" 
+#include "mongo/unittest/unittest.h" +#include "mongo/util/clock_source_mock.h" +#include "mongo/util/concurrency/thread_pool.h" +#include "mongo/util/fail_point.h" +#include "mongo/util/future.h" +#include "mongo/util/net/ssl_util.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest + + +namespace mongo { +namespace repl { + +namespace { +constexpr std::int32_t stopFailPointErrorCode = 4880402; +const Timestamp kDefaultStartMigrationTimestamp(1, 1); + +OplogEntry makeOplogEntry(OpTime opTime, + OpTypeEnum opType, + NamespaceString nss, + const boost::optional<UUID>& uuid, + BSONObj o, + boost::optional<BSONObj> o2) { + return {DurableOplogEntry(opTime, // optime + opType, // opType + nss, // namespace + uuid, // uuid + boost::none, // fromMigrate + OplogEntry::kOplogVersion, // version + o, // o + o2, // o2 + {}, // sessionInfo + boost::none, // upsert + Date_t(), // wall clock time + {}, // statement ids + boost::none, // optime of previous write within same transaction + boost::none, // pre-image optime + boost::none, // post-image optime + boost::none, // ShardId of resharding recipient + boost::none, // _id + boost::none)}; // needsRetryImage +} + +} // namespace + +class TenantMigrationRecipientServiceShardMergeTest : public ServiceContextMongoDTest { +public: + class stopFailPointEnableBlock : public FailPointEnableBlock { + public: + explicit stopFailPointEnableBlock(StringData failPointName, + std::int32_t error = stopFailPointErrorCode) + : FailPointEnableBlock(failPointName, + BSON("action" + << "stop" + << "stopErrorCode" << error)) {} + }; + + void setUp() override { + ServiceContextMongoDTest::setUp(); + auto serviceContext = getServiceContext(); + + // Fake replSet just for creating consistent URI for monitor + MockReplicaSet replSet("donorSet", 1, true /* hasPrimary */, true /* dollarPrefixHosts */); + _rsmMonitor.setup(replSet.getURI()); + + ConnectionString::setConnectionHook(mongo::MockConnRegistry::get()->getConnStrHook()); + + WaitForMajorityService::get(serviceContext).startup(serviceContext); + + // Automatically mark the state doc garbage collectable after data sync completion. + globalFailPointRegistry() + .find("autoRecipientForgetMigration") + ->setMode(FailPoint::alwaysOn); + + { + auto opCtx = cc().makeOperationContext(); + auto replCoord = std::make_unique<ReplicationCoordinatorMock>(serviceContext); + ReplicationCoordinator::set(serviceContext, std::move(replCoord)); + + repl::createOplog(opCtx.get()); + { + Lock::GlobalWrite lk(opCtx.get()); + OldClientContext ctx(opCtx.get(), NamespaceString::kRsOplogNamespace); + tenant_migration_util::createOplogViewForTenantMigrations(opCtx.get(), ctx.db()); + } + + // Need real (non-mock) storage for the oplog buffer. + StorageInterface::set(serviceContext, std::make_unique<StorageInterfaceImpl>()); + + // The DropPendingCollectionReaper is required to drop the oplog buffer collection. + repl::DropPendingCollectionReaper::set( + serviceContext, + std::make_unique<repl::DropPendingCollectionReaper>( + StorageInterface::get(serviceContext))); + + // Set up OpObserver so that repl::logOp() will store the oplog entry's optime in + // ReplClientInfo. 
+ OpObserverRegistry* opObserverRegistry = + dynamic_cast<OpObserverRegistry*>(serviceContext->getOpObserver()); + opObserverRegistry->addObserver( + std::make_unique<OpObserverImpl>(std::make_unique<OplogWriterImpl>())); + opObserverRegistry->addObserver( + std::make_unique<PrimaryOnlyServiceOpObserver>(serviceContext)); + + _registry = repl::PrimaryOnlyServiceRegistry::get(getServiceContext()); + std::unique_ptr<TenantMigrationRecipientService> service = + std::make_unique<TenantMigrationRecipientService>(getServiceContext()); + _registry->registerService(std::move(service)); + _registry->onStartup(opCtx.get()); + } + stepUp(); + + _service = _registry->lookupServiceByName( + TenantMigrationRecipientService::kTenantMigrationRecipientServiceName); + ASSERT(_service); + + // MockReplicaSet uses custom connection string which does not support auth. + auto authFp = globalFailPointRegistry().find("skipTenantMigrationRecipientAuth"); + authFp->setMode(FailPoint::alwaysOn); + + // Set the sslMode to allowSSL to avoid validation error. + sslGlobalParams.sslMode.store(SSLParams::SSLMode_allowSSL); + // Skipped unless tested explicitly, as we will not receive an FCV document from the donor + // in these unittests without (unsightly) intervention. + auto compFp = globalFailPointRegistry().find("skipComparingRecipientAndDonorFCV"); + compFp->setMode(FailPoint::alwaysOn); + + // Skip fetching retryable writes, as we will test this logic entirely in integration + // tests. + auto fetchRetryableWritesFp = + globalFailPointRegistry().find("skipFetchingRetryableWritesEntriesBeforeStartOpTime"); + fetchRetryableWritesFp->setMode(FailPoint::alwaysOn); + + // Skip fetching committed transactions, as we will test this logic entirely in integration + // tests. + auto fetchCommittedTransactionsFp = + globalFailPointRegistry().find("skipFetchingCommittedTransactions"); + fetchCommittedTransactionsFp->setMode(FailPoint::alwaysOn); + + // setup mock networking that will be use to mock the backup cursor traffic. + auto net = std::make_unique<executor::NetworkInterfaceMock>(); + _net = net.get(); + + executor::ThreadPoolMock::Options dbThreadPoolOptions; + dbThreadPoolOptions.onCreateThread = []() { Client::initThread("FetchMockTaskExecutor"); }; + + auto pool = std::make_unique<executor::ThreadPoolMock>(_net, 1, dbThreadPoolOptions); + _threadpoolTaskExecutor = + std::make_shared<executor::ThreadPoolTaskExecutor>(std::move(pool), std::move(net)); + _threadpoolTaskExecutor->startup(); + } + + void tearDown() override { + _threadpoolTaskExecutor->shutdown(); + _threadpoolTaskExecutor->join(); + + auto authFp = globalFailPointRegistry().find("skipTenantMigrationRecipientAuth"); + authFp->setMode(FailPoint::off); + + // Unset the sslMode. + sslGlobalParams.sslMode.store(SSLParams::SSLMode_disabled); + + WaitForMajorityService::get(getServiceContext()).shutDown(); + + _registry->onShutdown(); + _service = nullptr; + + StorageInterface::set(getServiceContext(), {}); + + // Clearing the connection pool is necessary when doing tests which use the + // ReplicaSetMonitor. See src/mongo/dbtests/mock/mock_replica_set.h for details. 
+ ScopedDbConnection::clearPool(); + ReplicaSetMonitorProtocolTestUtil::resetRSMProtocol(); + ServiceContextMongoDTest::tearDown(); + } + + void stepDown() { + ASSERT_OK(ReplicationCoordinator::get(getServiceContext()) + ->setFollowerMode(MemberState::RS_SECONDARY)); + _registry->onStepDown(); + } + + void stepUp() { + auto opCtx = cc().makeOperationContext(); + auto replCoord = ReplicationCoordinator::get(getServiceContext()); + + // Advance term + _term++; + + ASSERT_OK(replCoord->setFollowerMode(MemberState::RS_PRIMARY)); + ASSERT_OK(replCoord->updateTerm(opCtx.get(), _term)); + replCoord->setMyLastAppliedOpTimeAndWallTime( + OpTimeAndWallTime(OpTime(Timestamp(1, 1), _term), Date_t())); + + _registry->onStepUpComplete(opCtx.get(), _term); + } + +protected: + TenantMigrationRecipientServiceShardMergeTest() + : ServiceContextMongoDTest(Options{}.useMockClock(true)) {} + + PrimaryOnlyServiceRegistry* _registry; + PrimaryOnlyService* _service; + long long _term = 0; + + bool _collCreated = false; + size_t _numSecondaryIndexesCreated{0}; + size_t _numDocsInserted{0}; + + const TenantMigrationPEMPayload kRecipientPEMPayload = [&] { + std::ifstream infile("jstests/libs/client.pem"); + std::string buf((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>()); + + auto swCertificateBlob = + ssl_util::findPEMBlob(buf, "CERTIFICATE"_sd, 0 /* position */, false /* allowEmpty */); + ASSERT_TRUE(swCertificateBlob.isOK()); + + auto swPrivateKeyBlob = + ssl_util::findPEMBlob(buf, "PRIVATE KEY"_sd, 0 /* position */, false /* allowEmpty */); + ASSERT_TRUE(swPrivateKeyBlob.isOK()); + + return TenantMigrationPEMPayload{swCertificateBlob.getValue().toString(), + swPrivateKeyBlob.getValue().toString()}; + }(); + + void checkStateDocPersisted(OperationContext* opCtx, + const TenantMigrationRecipientService::Instance* instance) { + auto memoryStateDoc = getStateDoc(instance); + auto persistedStateDocWithStatus = + tenantMigrationRecipientEntryHelpers::getStateDoc(opCtx, memoryStateDoc.getId()); + ASSERT_OK(persistedStateDocWithStatus.getStatus()); + ASSERT_BSONOBJ_EQ(memoryStateDoc.toBSON(), persistedStateDocWithStatus.getValue().toBSON()); + } + void insertToNodes(MockReplicaSet* replSet, + const std::string& nss, + BSONObj obj, + const std::vector<HostAndPort>& hosts) { + for (const auto& host : hosts) { + replSet->getNode(host.toString())->insert(nss, obj); + } + } + + void clearCollection(MockReplicaSet* replSet, + const std::string& nss, + const std::vector<HostAndPort>& hosts) { + for (const auto& host : hosts) { + replSet->getNode(host.toString())->remove(nss, BSONObj{} /*filter*/); + } + } + + void insertTopOfOplog(MockReplicaSet* replSet, + const OpTime& topOfOplogOpTime, + const std::vector<HostAndPort> hosts = {}) { + const auto targetHosts = hosts.empty() ? replSet->getHosts() : hosts; + // The MockRemoteDBService does not actually implement the database, so to make our + // find work correctly we must make sure there's only one document to find. 
+ clearCollection(replSet, NamespaceString::kRsOplogNamespace.ns(), targetHosts); + insertToNodes(replSet, + NamespaceString::kRsOplogNamespace.ns(), + makeOplogEntry(topOfOplogOpTime, + OpTypeEnum::kNoop, + {} /* namespace */, + boost::none /* uuid */, + BSONObj() /* o */, + boost::none /* o2 */) + .getEntry() + .toBSON(), + targetHosts); + } + + // Accessors to class private members + DBClientConnection* getClient(const TenantMigrationRecipientService::Instance* instance) const { + return instance->_client.get(); + } + + const TenantMigrationRecipientDocument& getStateDoc( + const TenantMigrationRecipientService::Instance* instance) const { + return instance->_stateDoc; + } + + sdam::MockTopologyManager* getTopologyManager() { + return _rsmMonitor.getTopologyManager(); + } + + ClockSource* clock() { + return &_clkSource; + } + + executor::NetworkInterfaceMock* getNet() { + return _net; + } + + executor::NetworkInterfaceMock* _net = nullptr; + std::shared_ptr<executor::TaskExecutor> _threadpoolTaskExecutor; + + void setInstanceBackupCursorFetcherExecutor( + std::shared_ptr<TenantMigrationRecipientService::Instance> instance) { + instance->setBackupCursorFetcherExecutor_forTest(_threadpoolTaskExecutor); + } + +private: + ClockSourceMock _clkSource; + + unittest::MinimumLoggedSeverityGuard _replicationSeverityGuard{ + logv2::LogComponent::kReplication, logv2::LogSeverity::Debug(1)}; + unittest::MinimumLoggedSeverityGuard _tenantMigrationSeverityGuard{ + logv2::LogComponent::kTenantMigration, logv2::LogSeverity::Debug(1)}; + + StreamableReplicaSetMonitorForTesting _rsmMonitor; + RAIIServerParameterControllerForTest _findHostTimeout{"defaultFindReplicaSetHostTimeoutMS", 10}; +}; + +#ifdef MONGO_CONFIG_SSL + +void waitForReadyRequest(executor::NetworkInterfaceMock* net) { + while (!net->hasReadyRequests()) { + net->advanceTime(net->now() + Milliseconds{1}); + } +} + +BSONObj createEmptyCursorResponse(const NamespaceString& nss, CursorId backupCursorId) { + return BSON( + "cursor" << BSON("nextBatch" << BSONArray() << "id" << backupCursorId << "ns" << nss.ns()) + << "ok" << 1.0); +} + +BSONObj createBackupCursorResponse(const Timestamp& checkpointTimestamp, + const NamespaceString& nss, + CursorId backupCursorId) { + const UUID backupId = + UUID(uassertStatusOK(UUID::parse(("2b068e03-5961-4d8e-b47a-d1c8cbd4b835")))); + StringData remoteDbPath = "/data/db/job0/mongorunner/test-1"; + BSONObjBuilder cursor; + BSONArrayBuilder batch(cursor.subarrayStart("firstBatch")); + auto metaData = BSON("backupId" << backupId << "checkpointTimestamp" << checkpointTimestamp + << "dbpath" << remoteDbPath); + batch.append(BSON("metadata" << metaData)); + + batch.done(); + cursor.append("id", backupCursorId); + cursor.append("ns", nss.ns()); + BSONObjBuilder backupCursorReply; + backupCursorReply.append("cursor", cursor.obj()); + backupCursorReply.append("ok", 1.0); + return backupCursorReply.obj(); +} + +void sendReponseToExpectedRequest(const BSONObj& backupCursorResponse, + const std::string& expectedRequestFieldName, + executor::NetworkInterfaceMock* net) { + auto noi = net->getNextReadyRequest(); + auto request = noi->getRequest(); + ASSERT_EQUALS(expectedRequestFieldName, request.cmdObj.firstElementFieldNameStringData()); + net->scheduleSuccessfulResponse( + noi, executor::RemoteCommandResponse(backupCursorResponse, Milliseconds())); + net->runReadyNetworkOperations(); +} + +BSONObj createServerAggregateReply() { + return CursorResponse( + 
NamespaceString::makeCollectionlessAggregateNSS(NamespaceString::kAdminDb), + 0 /* cursorId */, + {BSON("byteOffset" << 0 << "endOfFile" << true << "data" + << BSONBinData(0, 0, BinDataGeneral))}) + .toBSONAsInitialResponse(); +} + +TEST_F(TenantMigrationRecipientServiceShardMergeTest, OpenBackupCursorSuccessfully) { + stopFailPointEnableBlock fp("fpBeforeAdvancingStableTimestamp"); + const UUID migrationUUID = UUID::gen(); + const CursorId backupCursorId = 12345; + const NamespaceString aggregateNs = NamespaceString("admin.$cmd.aggregate"); + + auto taskFp = globalFailPointRegistry().find("hangBeforeTaskCompletion"); + auto initialTimesEntered = taskFp->setMode(FailPoint::alwaysOn); + + MockReplicaSet replSet("donorSet", 3, true /* hasPrimary */, true /* dollarPrefixHosts */); + getTopologyManager()->setTopologyDescription(replSet.getTopologyDescription(clock())); + insertTopOfOplog(&replSet, OpTime(Timestamp(5, 1), 1)); + + // Mock the aggregate response from the donor. + MockRemoteDBServer* const _donorServer = + mongo::MockConnRegistry::get()->getMockRemoteDBServer(replSet.getPrimary()); + _donorServer->setCommandReply("aggregate", createServerAggregateReply()); + + TenantMigrationRecipientDocument initialStateDocument( + migrationUUID, + replSet.getConnectionString(), + "tenantA", + kDefaultStartMigrationTimestamp, + ReadPreferenceSetting(ReadPreference::PrimaryOnly)); + initialStateDocument.setProtocol(MigrationProtocolEnum::kShardMerge); + initialStateDocument.setRecipientCertificateForDonor(kRecipientPEMPayload); + + auto opCtx = makeOperationContext(); + std::shared_ptr<TenantMigrationRecipientService::Instance> instance; + { + auto fp = globalFailPointRegistry().find("pauseBeforeRunTenantMigrationRecipientInstance"); + auto initialTimesEntered = fp->setMode(FailPoint::alwaysOn); + instance = TenantMigrationRecipientService::Instance::getOrCreate( + opCtx.get(), _service, initialStateDocument.toBSON()); + ASSERT(instance.get()); + fp->waitForTimesEntered(initialTimesEntered + 1); + setInstanceBackupCursorFetcherExecutor(instance); + instance->setCreateOplogFetcherFn_forTest(std::make_unique<CreateOplogFetcherMockFn>()); + fp->setMode(FailPoint::off); + } + + { + auto net = getNet(); + executor::NetworkInterfaceMock::InNetworkGuard guard(net); + waitForReadyRequest(net); + // Mocking the aggregate command network response of the backup cursor in order to have + // data to parse. 
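        // (Illustrative sketch, not part of the patch.) The createBackupCursorResponse()
        // helper above produces a reply of roughly this shape, from which the recipient
        // obtains the donor's backupId and checkpointTimestamp (the latter is compared
        // against the start migration timestamp, as exercised by the retry test below):
        //   { cursor: { firstBatch: [ { metadata: { backupId: <UUID>,
        //                                           checkpointTimestamp: <Timestamp>,
        //                                           dbpath: "<donor dbpath>" } } ],
        //               id: <backupCursorId>,
        //               ns: "<nss>" },
        //     ok: 1.0 }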
+ sendReponseToExpectedRequest(createBackupCursorResponse(kDefaultStartMigrationTimestamp, + aggregateNs, + backupCursorId), + "aggregate", + net); + sendReponseToExpectedRequest( + createEmptyCursorResponse(aggregateNs, backupCursorId), "getMore", net); + sendReponseToExpectedRequest( + createEmptyCursorResponse(aggregateNs, backupCursorId), "getMore", net); + } + + taskFp->waitForTimesEntered(initialTimesEntered + 1); + + checkStateDocPersisted(opCtx.get(), instance.get()); + + taskFp->setMode(FailPoint::off); + + ASSERT_EQ(stopFailPointErrorCode, instance->getDataSyncCompletionFuture().getNoThrow().code()); + ASSERT_OK(instance->getForgetMigrationDurableFuture().getNoThrow()); +} + +TEST_F(TenantMigrationRecipientServiceShardMergeTest, OpenBackupCursorAndRetriesDueToTs) { + stopFailPointEnableBlock fp("fpBeforeAdvancingStableTimestamp"); + const UUID migrationUUID = UUID::gen(); + const CursorId backupCursorId = 12345; + const NamespaceString aggregateNs = NamespaceString("admin.$cmd.aggregate"); + + auto taskFp = globalFailPointRegistry().find("hangBeforeTaskCompletion"); + auto initialTimesEntered = taskFp->setMode(FailPoint::alwaysOn); + + MockReplicaSet replSet("donorSet", 3, true /* hasPrimary */, true /* dollarPrefixHosts */); + getTopologyManager()->setTopologyDescription(replSet.getTopologyDescription(clock())); + insertTopOfOplog(&replSet, OpTime(Timestamp(5, 1), 1)); + + // Mock the aggregate response from the donor. + MockRemoteDBServer* const _donorServer = + mongo::MockConnRegistry::get()->getMockRemoteDBServer(replSet.getPrimary()); + _donorServer->setCommandReply("aggregate", createServerAggregateReply()); + + TenantMigrationRecipientDocument initialStateDocument( + migrationUUID, + replSet.getConnectionString(), + "tenantA", + kDefaultStartMigrationTimestamp, + ReadPreferenceSetting(ReadPreference::PrimaryOnly)); + initialStateDocument.setProtocol(MigrationProtocolEnum::kShardMerge); + initialStateDocument.setRecipientCertificateForDonor(kRecipientPEMPayload); + + auto opCtx = makeOperationContext(); + std::shared_ptr<TenantMigrationRecipientService::Instance> instance; + { + auto fp = globalFailPointRegistry().find("pauseBeforeRunTenantMigrationRecipientInstance"); + auto initialTimesEntered = fp->setMode(FailPoint::alwaysOn); + instance = TenantMigrationRecipientService::Instance::getOrCreate( + opCtx.get(), _service, initialStateDocument.toBSON()); + ASSERT(instance.get()); + fp->waitForTimesEntered(initialTimesEntered + 1); + setInstanceBackupCursorFetcherExecutor(instance); + instance->setCreateOplogFetcherFn_forTest(std::make_unique<CreateOplogFetcherMockFn>()); + fp->setMode(FailPoint::off); + } + + { + auto net = getNet(); + executor::NetworkInterfaceMock::InNetworkGuard guard(net); + waitForReadyRequest(net); + + // Mocking the aggregate command network response of the backup cursor in order to have data + // to parse. In this case we pass a timestamp that is inferior to the + // startMigrationTimestamp which will cause a retry. We then provide a correct timestamp in + // the next response and succeed. 
+ sendReponseToExpectedRequest( + createBackupCursorResponse(Timestamp(0, 0), aggregateNs, backupCursorId), + "aggregate", + net); + sendReponseToExpectedRequest(createBackupCursorResponse(kDefaultStartMigrationTimestamp, + aggregateNs, + backupCursorId), + "killCursors", + net); + sendReponseToExpectedRequest( + createEmptyCursorResponse(aggregateNs, backupCursorId), "killCursors", net); + sendReponseToExpectedRequest(createBackupCursorResponse(kDefaultStartMigrationTimestamp, + aggregateNs, + backupCursorId), + "aggregate", + net); + sendReponseToExpectedRequest( + createEmptyCursorResponse(aggregateNs, backupCursorId), "getMore", net); + sendReponseToExpectedRequest( + createEmptyCursorResponse(aggregateNs, backupCursorId), "getMore", net); + } + + taskFp->waitForTimesEntered(initialTimesEntered + 1); + + checkStateDocPersisted(opCtx.get(), instance.get()); + + taskFp->setMode(FailPoint::off); + + ASSERT_EQ(stopFailPointErrorCode, instance->getDataSyncCompletionFuture().getNoThrow().code()); + ASSERT_OK(instance->getForgetMigrationDurableFuture().getNoThrow()); +} + +#endif +} // namespace repl +} // namespace mongo diff --git a/src/mongo/db/s/SConscript b/src/mongo/db/s/SConscript index fa54218910d..3248c1b296d 100644 --- a/src/mongo/db/s/SConscript +++ b/src/mongo/db/s/SConscript @@ -11,7 +11,6 @@ env = env.Clone() env.Library( target='sharding_api_d', source=[ - 'balancer_stats_registry.cpp', 'collection_metadata.cpp', 'collection_sharding_state_factory_standalone.cpp', 'collection_sharding_state.cpp', @@ -36,14 +35,25 @@ env.Library( LIBDEPS_PRIVATE=[ '$BUILD_DIR/mongo/db/catalog/index_catalog', '$BUILD_DIR/mongo/db/concurrency/lock_manager', - '$BUILD_DIR/mongo/db/dbdirectclient', - '$BUILD_DIR/mongo/db/repl/replica_set_aware_service', '$BUILD_DIR/mongo/db/server_base', '$BUILD_DIR/mongo/db/write_block_bypass', ], ) env.Library( + target='balancer_stats_registry', + source=[ + 'balancer_stats_registry.cpp', + ], + LIBDEPS_PRIVATE=[ + '$BUILD_DIR/mongo/db/dbdirectclient', + '$BUILD_DIR/mongo/db/repl/replica_set_aware_service', + '$BUILD_DIR/mongo/db/shard_role', + '$BUILD_DIR/mongo/s/grid', + ], +) + +env.Library( target='sharding_catalog', source=[ 'global_index_ddl_util.cpp', @@ -64,6 +74,22 @@ env.Library( ) env.Library( + target="query_analysis_writer", + source=[ + "query_analysis_writer.cpp", + ], + LIBDEPS_PRIVATE=[ + "$BUILD_DIR/mongo/db/dbdirectclient", + '$BUILD_DIR/mongo/db/ops/write_ops_parsers', + '$BUILD_DIR/mongo/db/server_base', + '$BUILD_DIR/mongo/db/service_context', + '$BUILD_DIR/mongo/db/shard_role', + '$BUILD_DIR/mongo/idl/idl_parser', + '$BUILD_DIR/mongo/s/analyze_shard_key_common', + ], +) + +env.Library( target='sharding_runtime_d', source=[ 'active_migrations_registry.cpp', @@ -228,6 +254,8 @@ env.Library( '$BUILD_DIR/mongo/db/transaction/transaction_operations', '$BUILD_DIR/mongo/s/common_s', '$BUILD_DIR/mongo/util/future_util', + 'balancer_stats_registry', + 'query_analysis_writer', 'sharding_catalog', 'sharding_logging', ], @@ -550,6 +578,7 @@ env.Library( '$BUILD_DIR/mongo/s/sharding_initialization', '$BUILD_DIR/mongo/s/sharding_router_api', '$BUILD_DIR/mongo/s/startup_initialization', + 'balancer_stats_registry', 'forwardable_operation_metadata', 'sharding_catalog', 'sharding_logging', @@ -653,6 +682,7 @@ env.CppUnitTest( 'op_observer_sharding_test.cpp', 'operation_sharding_state_test.cpp', 'persistent_task_queue_test.cpp', + 'query_analysis_writer_test.cpp', 'range_deleter_service_test.cpp', 'range_deleter_service_test_util.cpp', 
'range_deleter_service_op_observer_test.cpp', @@ -734,6 +764,7 @@ env.CppUnitTest( '$BUILD_DIR/mongo/executor/thread_pool_task_executor_test_fixture', '$BUILD_DIR/mongo/s/catalog/sharding_catalog_client_mock', '$BUILD_DIR/mongo/s/sharding_router_test_fixture', + 'query_analysis_writer', 'shard_server_test_fixture', 'sharding_catalog', 'sharding_commands_d', diff --git a/src/mongo/db/s/balancer/balancer.cpp b/src/mongo/db/s/balancer/balancer.cpp index f1c96cc97cf..9103ec7f04c 100644 --- a/src/mongo/db/s/balancer/balancer.cpp +++ b/src/mongo/db/s/balancer/balancer.cpp @@ -281,7 +281,7 @@ Balancer::Balancer() std::make_unique<BalancerChunkSelectionPolicyImpl>(_clusterStats.get(), _random)), _commandScheduler(std::make_unique<BalancerCommandsSchedulerImpl>()), _defragmentationPolicy(std::make_unique<BalancerDefragmentationPolicyImpl>( - _clusterStats.get(), _random, [this]() { _onActionsStreamPolicyStateUpdate(); })), + _clusterStats.get(), [this]() { _onActionsStreamPolicyStateUpdate(); })), _clusterChunksResizePolicy(std::make_unique<ClusterChunksResizePolicyImpl>( [this] { _onActionsStreamPolicyStateUpdate(); })) {} @@ -1085,7 +1085,7 @@ int Balancer::_moveChunks(OperationContext* opCtx, opCtx, migrateInfo.uuid, repl::ReadConcernLevel::kMajorityReadConcern); ShardingCatalogManager::get(opCtx)->splitOrMarkJumbo( - opCtx, collection.getNss(), migrateInfo.minKey); + opCtx, collection.getNss(), migrateInfo.minKey, migrateInfo.getMaxChunkSizeBytes()); continue; } @@ -1151,7 +1151,12 @@ BalancerCollectionStatusResponse Balancer::getBalancerStatusForNs(OperationConte uasserted(ErrorCodes::NamespaceNotSharded, "Collection unsharded or undefined"); } - const auto maxChunkSizeMB = getMaxChunkSizeMB(opCtx, coll); + + const auto maxChunkSizeBytes = getMaxChunkSizeBytes(opCtx, coll); + double maxChunkSizeMB = (double)maxChunkSizeBytes / (1024 * 1024); + // Keep only 2 decimal digits to return a readable value + maxChunkSizeMB = std::ceil(maxChunkSizeMB * 100.0) / 100.0; + BalancerCollectionStatusResponse response(maxChunkSizeMB, true /*balancerCompliant*/); auto setViolationOnResponse = [&response](const StringData& reason, const boost::optional<BSONObj>& details = diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp index 1c5a88a6100..c8d51184bc3 100644 --- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp +++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.cpp @@ -398,7 +398,8 @@ public: std::move(collectionChunks), std::move(shardInfos), std::move(collectionZones), - smallChunkSizeThresholdBytes)); + smallChunkSizeThresholdBytes, + maxChunkSizeBytes)); } DefragmentationPhaseEnum getType() const override { @@ -462,7 +463,7 @@ public: auto smallChunkVersion = getShardVersion(opCtx, nextSmallChunk->shard, _nss); _outstandingMigrations.emplace_back(nextSmallChunk, targetSibling); return _outstandingMigrations.back().asMigrateInfo( - _uuid, _nss, smallChunkVersion.placementVersion()); + _uuid, _nss, smallChunkVersion.placementVersion(), _maxChunkSizeBytes); } return boost::none; @@ -701,7 +702,8 @@ private: MigrateInfo asMigrateInfo(const UUID& collUuid, const NamespaceString& nss, - const ChunkVersion& version) const { + const ChunkVersion& version, + uint64_t maxChunkSizeBytes) const { return MigrateInfo(chunkToMergeWith->shard, chunkToMove->shard, nss, @@ -709,7 +711,8 @@ private: chunkToMove->range.getMin(), chunkToMove->range.getMax(), version, - ForceJumbo::kForceBalancer); + 
ForceJumbo::kForceBalancer, + maxChunkSizeBytes); } ChunkRange asMergedRange() const { @@ -774,6 +777,8 @@ private: const int64_t _smallChunkSizeThresholdBytes; + const uint64_t _maxChunkSizeBytes; + bool _aborted{false}; DefragmentationPhaseEnum _nextPhase{DefragmentationPhaseEnum::kMergeChunks}; @@ -783,7 +788,8 @@ private: std::vector<ChunkType>&& collectionChunks, stdx::unordered_map<ShardId, ShardInfo>&& shardInfos, ZoneInfo&& collectionZones, - uint64_t smallChunkSizeThresholdBytes) + uint64_t smallChunkSizeThresholdBytes, + uint64_t maxChunkSizeBytes) : _nss(nss), _uuid(uuid), _collectionChunks(), @@ -794,7 +800,8 @@ private: _actionableMerges(), _outstandingMerges(), _zoneInfo(std::move(collectionZones)), - _smallChunkSizeThresholdBytes(smallChunkSizeThresholdBytes) { + _smallChunkSizeThresholdBytes(smallChunkSizeThresholdBytes), + _maxChunkSizeBytes(maxChunkSizeBytes) { // Load the collection routing table in a std::list to ease later manipulation for (auto&& chunk : collectionChunks) { diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h index bc41346ca7f..bab4a6c58cf 100644 --- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h +++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_impl.h @@ -72,9 +72,10 @@ class BalancerDefragmentationPolicyImpl : public BalancerDefragmentationPolicy { public: BalancerDefragmentationPolicyImpl(ClusterStatistics* clusterStats, - BalancerRandomSource& random, const std::function<void()>& onStateUpdated) - : _clusterStats(clusterStats), _random(random), _onStateUpdated(onStateUpdated) {} + : _clusterStats(clusterStats), + _random(std::random_device{}()), + _onStateUpdated(onStateUpdated) {} ~BalancerDefragmentationPolicyImpl() {} @@ -145,7 +146,7 @@ private: ClusterStatistics* const _clusterStats; - BalancerRandomSource& _random; + BalancerRandomSource _random; const std::function<void()> _onStateUpdated; diff --git a/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp b/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp index 8646b5d965a..e9370bbb4d3 100644 --- a/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp +++ b/src/mongo/db/s/balancer/balancer_defragmentation_policy_test.cpp @@ -29,7 +29,6 @@ #include "mongo/db/dbdirectclient.h" #include "mongo/db/s/balancer/balancer_defragmentation_policy_impl.h" -#include "mongo/db/s/balancer/balancer_random.h" #include "mongo/db/s/balancer/cluster_statistics_mock.h" #include "mongo/db/s/config/config_server_test_fixture.h" #include "mongo/idl/server_parameter_test_util.h" @@ -75,9 +74,7 @@ protected: const std::function<void()> onDefragmentationStateUpdated = [] {}; BalancerDefragmentationPolicyTest() - : _clusterStats(), - _random(std::random_device{}()), - _defragmentationPolicy(&_clusterStats, _random, onDefragmentationStateUpdated) {} + : _clusterStats(), _defragmentationPolicy(&_clusterStats, onDefragmentationStateUpdated) {} CollectionType setupCollectionWithPhase( const std::vector<ChunkType>& chunkList, @@ -137,7 +134,6 @@ protected: } ClusterStatisticsMock _clusterStats; - BalancerRandomSource _random; BalancerDefragmentationPolicyImpl _defragmentationPolicy; ShardStatistics buildShardStats(ShardId id, diff --git a/src/mongo/db/s/balancer/balancer_policy.cpp b/src/mongo/db/s/balancer/balancer_policy.cpp index 4e0ff4c3ccb..a8ad96ab2c7 100644 --- a/src/mongo/db/s/balancer/balancer_policy.cpp +++ b/src/mongo/db/s/balancer/balancer_policy.cpp @@ 
-36,6 +36,7 @@ #include "mongo/db/s/balancer/type_migration.h" #include "mongo/logv2/log.h" +#include "mongo/s/balancer_configuration.h" #include "mongo/s/catalog/type_shard.h" #include "mongo/s/catalog/type_tags.h" #include "mongo/s/grid.h" @@ -450,7 +451,16 @@ MigrateInfosWithReason BalancerPolicy::balance( } invariant(to != stat.shardId); - migrations.emplace_back(to, distribution.nss(), chunk, ForceJumbo::kForceBalancer); + + auto maxChunkSizeBytes = [&]() -> boost::optional<int64_t> { + if (collDataSizeInfo.has_value()) { + return collDataSizeInfo->maxChunkSizeBytes; + } + return boost::none; + }(); + + migrations.emplace_back( + to, distribution.nss(), chunk, ForceJumbo::kForceBalancer, maxChunkSizeBytes); if (firstReason == MigrationReason::none) { firstReason = MigrationReason::drain; } @@ -513,11 +523,20 @@ MigrateInfosWithReason BalancerPolicy::balance( } invariant(to != stat.shardId); + + auto maxChunkSizeBytes = [&]() -> boost::optional<int64_t> { + if (collDataSizeInfo.has_value()) { + return collDataSizeInfo->maxChunkSizeBytes; + } + return boost::none; + }(); + migrations.emplace_back(to, distribution.nss(), chunk, forceJumbo ? ForceJumbo::kForceBalancer - : ForceJumbo::kDoNotForce); + : ForceJumbo::kDoNotForce, + maxChunkSizeBytes); if (firstReason == MigrationReason::none) { firstReason = MigrationReason::zoneViolation; } @@ -796,7 +815,8 @@ string ZoneRange::toString() const { MigrateInfo::MigrateInfo(const ShardId& a_to, const NamespaceString& a_nss, const ChunkType& a_chunk, - const ForceJumbo a_forceJumbo) + const ForceJumbo a_forceJumbo, + boost::optional<int64_t> maxChunkSizeBytes) : nss(a_nss), uuid(a_chunk.getCollectionUUID()) { invariant(a_to.isValid()); @@ -807,6 +827,7 @@ MigrateInfo::MigrateInfo(const ShardId& a_to, maxKey = a_chunk.getMax(); version = a_chunk.getVersion(); forceJumbo = a_forceJumbo; + optMaxChunkSizeBytes = maxChunkSizeBytes; } MigrateInfo::MigrateInfo(const ShardId& a_to, @@ -858,6 +879,10 @@ string MigrateInfo::toString() const { << ", to " << to; } +boost::optional<int64_t> MigrateInfo::getMaxChunkSizeBytes() const { + return optMaxChunkSizeBytes; +} + SplitInfo::SplitInfo(const ShardId& inShardId, const NamespaceString& inNss, const ChunkVersion& inCollectionVersion, diff --git a/src/mongo/db/s/balancer/balancer_policy.h b/src/mongo/db/s/balancer/balancer_policy.h index bd464ca9499..a9a1f307310 100644 --- a/src/mongo/db/s/balancer/balancer_policy.h +++ b/src/mongo/db/s/balancer/balancer_policy.h @@ -60,7 +60,8 @@ struct MigrateInfo { MigrateInfo(const ShardId& a_to, const NamespaceString& a_nss, const ChunkType& a_chunk, - ForceJumbo a_forceJumbo); + ForceJumbo a_forceJumbo, + boost::optional<int64_t> maxChunkSizeBytes = boost::none); MigrateInfo(const ShardId& a_to, const ShardId& a_from, @@ -78,6 +79,8 @@ struct MigrateInfo { std::string toString() const; + boost::optional<int64_t> getMaxChunkSizeBytes() const; + NamespaceString nss; UUID uuid; ShardId to; diff --git a/src/mongo/db/s/balancer_stats_registry.cpp b/src/mongo/db/s/balancer_stats_registry.cpp index bfd789acb52..0f11a578c23 100644 --- a/src/mongo/db/s/balancer_stats_registry.cpp +++ b/src/mongo/db/s/balancer_stats_registry.cpp @@ -29,6 +29,7 @@ #include "mongo/db/s/balancer_stats_registry.h" +#include "mongo/db/catalog_raii.h" #include "mongo/db/dbdirectclient.h" #include "mongo/db/pipeline/aggregate_command_gen.h" #include "mongo/db/repl/replication_coordinator.h" @@ -58,23 +59,6 @@ ThreadPool::Options makeDefaultThreadPoolOptions() { } } // namespace 
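The balancer_policy.cpp/.h hunks above thread an optional per-collection maximum chunk size through MigrateInfo. A minimal sketch of the caller/consumer pattern, assuming the in-tree headers; collDataSizeInfo, to, distribution, and chunk stand in for the locals of BalancerPolicy::balance():

    // Resolve the per-collection limit only when collection size info is available;
    // otherwise leave it unset so consumers fall back to the global balancer setting.
    boost::optional<int64_t> maxChunkSizeBytes;
    if (collDataSizeInfo.has_value()) {
        maxChunkSizeBytes = collDataSizeInfo->maxChunkSizeBytes;
    }

    // The new trailing constructor argument defaults to boost::none, so existing
    // call sites keep compiling unchanged.
    MigrateInfo migration(to, distribution.nss(), chunk, ForceJumbo::kDoNotForce, maxChunkSizeBytes);

    // Consumers read the limit back via the new accessor, e.g. Balancer::_moveChunks()
    // forwards it to ShardingCatalogManager::splitOrMarkJumbo().
    if (auto maxSize = migration.getMaxChunkSizeBytes()) {
        // prefer *maxSize over BalancerConfiguration::getMaxChunkSizeBytes()
    }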
-ScopedRangeDeleterLock::ScopedRangeDeleterLock(OperationContext* opCtx)
-    // TODO SERVER-62491 Use system tenantId for DBLock
-    : _configLock(opCtx, DatabaseName(boost::none, NamespaceString::kConfigDb), MODE_IX),
-      _rangeDeletionLock(opCtx, NamespaceString::kRangeDeletionNamespace, MODE_X) {}
-
-// Take DB and Collection lock in mode IX as well as collection UUID lock to serialize with
-// operations that take the above version of the ScopedRangeDeleterLock such as FCV downgrade and
-// BalancerStatsRegistry initialization.
-ScopedRangeDeleterLock::ScopedRangeDeleterLock(OperationContext* opCtx, const UUID& collectionUuid)
-    // TODO SERVER-62491 Use system tenantId for DBLock
-    : _configLock(opCtx, DatabaseName(boost::none, NamespaceString::kConfigDb), MODE_IX),
-      _rangeDeletionLock(opCtx, NamespaceString::kRangeDeletionNamespace, MODE_IX),
-      _collectionUuidLock(Lock::ResourceLock(
-          opCtx,
-          ResourceId(RESOURCE_MUTEX, "RangeDeleterCollLock::" + collectionUuid.toString()),
-          MODE_X)) {}
-
 const ReplicaSetAwareServiceRegistry::Registerer<BalancerStatsRegistry>
     balancerStatsRegistryRegisterer("BalancerStatsRegistry");
 
@@ -131,9 +115,13 @@ void BalancerStatsRegistry::initializeAsync(OperationContext* opCtx) {
         LOGV2_DEBUG(6419601, 2, "Initializing BalancerStatsRegistry");
 
         try {
-            // Lock the range deleter to prevent
-            // concurrent modifications of orphans count
-            ScopedRangeDeleterLock rangeDeleterLock(opCtx);
+            // Lock the range deleter to prevent concurrent modifications of the orphans count
+            ScopedRangeDeleterLock rangeDeleterLock(opCtx, LockMode::MODE_S);
+            // The collection lock is needed to serialize with direct writes to
+            // config.rangeDeletions
+            AutoGetCollection rangeDeletionLock(
+                opCtx, NamespaceString::kRangeDeletionNamespace, MODE_S);
+
             // Load the current orphans count from disk
             _loadOrphansCount(opCtx);
             LOGV2_DEBUG(6419602, 2, "Completed BalancerStatsRegistry initialization");
diff --git a/src/mongo/db/s/balancer_stats_registry.h b/src/mongo/db/s/balancer_stats_registry.h
index 473285b627c..fdc68ab1021 100644
--- a/src/mongo/db/s/balancer_stats_registry.h
+++ b/src/mongo/db/s/balancer_stats_registry.h
@@ -38,18 +38,19 @@
 namespace mongo {
 
 /**
- * Acquires the config db lock in IX mode and the collection lock for config.rangeDeletions in X
- * mode.
+ * Scoped lock to synchronize with the execution of range deletions.
+ * The range deleter acquires this lock in IX mode while orphans are being deleted.
+ * Acquiring the lock in MODE_X (or MODE_S) therefore ensures that no orphan counter in
+ * `config.rangeDeletions` entries is updated concurrently.
  */
 class ScopedRangeDeleterLock {
 public:
-    ScopedRangeDeleterLock(OperationContext* opCtx);
-    ScopedRangeDeleterLock(OperationContext* opCtx, const UUID& collectionUuid);
+    ScopedRangeDeleterLock(OperationContext* opCtx, LockMode mode)
+        : _resourceLock(opCtx, _mutex.getRid(), mode) {}
 
 private:
-    Lock::DBLock _configLock;
-    Lock::CollectionLock _rangeDeletionLock;
-    boost::optional<Lock::ResourceLock> _collectionUuidLock;
+    const Lock::ResourceLock _resourceLock;
+    static inline const Lock::ResourceMutex _mutex{"ScopedRangeDeleterLock"};
 };
 
 /**
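The reworked ScopedRangeDeleterLock above replaces the DB/collection lock pair with a single ResourceMutex taken at a caller-chosen mode. A minimal sketch of how the two sides are expected to use it, assuming the in-tree headers; both function names are hypothetical:

    // Range deleter side: IX mode lets concurrent range deletions proceed while they
    // update the orphan counters in config.rangeDeletions.
    void deleteRangeAndUpdateCounterSketch(OperationContext* opCtx) {
        ScopedRangeDeleterLock rangeDeleterLock(opCtx, LockMode::MODE_IX);
        // ... delete orphaned documents and decrement the per-range counter ...
    }

    // Reader side (e.g. BalancerStatsRegistry::initializeAsync above): MODE_S conflicts
    // with MODE_IX, so no counter can change while a consistent snapshot of
    // config.rangeDeletions is being loaded.
    void loadOrphanCountersSketch(OperationContext* opCtx) {
        ScopedRangeDeleterLock rangeDeleterLock(opCtx, LockMode::MODE_S);
        AutoGetCollection rangeDeletionColl(
            opCtx, NamespaceString::kRangeDeletionNamespace, MODE_S);
        // ... scan config.rangeDeletions and aggregate the counters ...
    }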
diff --git a/src/mongo/db/s/cluster_count_cmd_d.cpp b/src/mongo/db/s/cluster_count_cmd_d.cpp
index 27c64148cc9..a593a299b4f 100644
--- a/src/mongo/db/s/cluster_count_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_count_cmd_d.cpp
@@ -62,6 +62,11 @@ struct ClusterCountCmdD {
         // which triggers an invariant, so only shard servers can run this.
         uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
     }
+
+    static void checkCanExplainHere(OperationContext* opCtx) {
+        uasserted(ErrorCodes::CommandNotSupported,
+                  "Cannot explain a cluster count command on a mongod");
+    }
 };
 ClusterCountCmdBase<ClusterCountCmdD> clusterCountCmdD;
 
diff --git a/src/mongo/db/s/cluster_find_cmd_d.cpp b/src/mongo/db/s/cluster_find_cmd_d.cpp
index e97f07b70e7..e98315c06ad 100644
--- a/src/mongo/db/s/cluster_find_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_find_cmd_d.cpp
@@ -61,6 +61,11 @@ struct ClusterFindCmdD {
         // which triggers an invariant, so only shard servers can run this.
         uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
     }
+
+    static void checkCanExplainHere(OperationContext* opCtx) {
+        uasserted(ErrorCodes::CommandNotSupported,
+                  "Cannot explain a cluster find command on a mongod");
+    }
 };
 ClusterFindCmdBase<ClusterFindCmdD> clusterFindCmdD;
 
diff --git a/src/mongo/db/s/cluster_pipeline_cmd_d.cpp b/src/mongo/db/s/cluster_pipeline_cmd_d.cpp
index 437e7418355..1b3eed4242a 100644
--- a/src/mongo/db/s/cluster_pipeline_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_pipeline_cmd_d.cpp
@@ -61,6 +61,11 @@ struct ClusterPipelineCommandD {
         uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
     }
 
+    static void checkCanExplainHere(OperationContext* opCtx) {
+        uasserted(ErrorCodes::CommandNotSupported,
+                  "Cannot explain a cluster aggregate command on a mongod");
+    }
+
     static AggregateCommandRequest parseAggregationRequest(
         OperationContext* opCtx,
         const OpMsgRequest& opMsgRequest,
diff --git a/src/mongo/db/s/cluster_write_cmd_d.cpp b/src/mongo/db/s/cluster_write_cmd_d.cpp
index c8f75d69e15..66167ab2b30 100644
--- a/src/mongo/db/s/cluster_write_cmd_d.cpp
+++ b/src/mongo/db/s/cluster_write_cmd_d.cpp
@@ -57,6 +57,11 @@ struct ClusterInsertCmdD {
         // which triggers an invariant, so only shard servers can run this.
         uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
     }
+
+    static void checkCanExplainHere(OperationContext* opCtx) {
+        uasserted(ErrorCodes::CommandNotSupported,
+                  "Cannot explain a cluster insert command on a mongod");
+    }
 };
 ClusterInsertCmdBase<ClusterInsertCmdD> clusterInsertCmdD;
 
@@ -83,6 +88,10 @@ struct ClusterUpdateCmdD {
         // which triggers an invariant, so only shard servers can run this.
         uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands());
     }
+
+    static void checkCanExplainHere(OperationContext* opCtx) {
+        uasserted(ErrorCodes::CommandNotSupported, "Cannot explain a cluster update command on a mongod");
+    }
 };
 ClusterUpdateCmdBase<ClusterUpdateCmdD> clusterUpdateCmdD;
 
@@ -109,6 +118,11 @@ struct ClusterDeleteCmdD {
         // which triggers an invariant, so only shard servers can run this.
uassertStatusOK(ShardingState::get(opCtx)->canAcceptShardedCommands()); } + + static void checkCanExplainHere(OperationContext* opCtx) { + uasserted(ErrorCodes::CommandNotSupported, + "Cannot explain a cluster delete command on a mongod"); + } }; ClusterDeleteCmdBase<ClusterDeleteCmdD> clusterDeleteCmdD; diff --git a/src/mongo/db/s/collection_sharding_runtime.cpp b/src/mongo/db/s/collection_sharding_runtime.cpp index 89ceb0cfd4c..ee530f6ea04 100644 --- a/src/mongo/db/s/collection_sharding_runtime.cpp +++ b/src/mongo/db/s/collection_sharding_runtime.cpp @@ -486,8 +486,10 @@ void CollectionShardingRuntime::resetShardVersionRecoverRefreshFuture() { _shardVersionInRecoverOrRefresh = boost::none; } -boost::optional<Timestamp> CollectionShardingRuntime::getIndexVersion(OperationContext* opCtx) { - return _globalIndexesInfo ? _globalIndexesInfo->getVersion() : boost::none; +boost::optional<CollectionIndexes> CollectionShardingRuntime::getCollectionIndexes( + OperationContext* opCtx) { + return _globalIndexesInfo ? boost::make_optional(_globalIndexesInfo->getCollectionIndexes()) + : boost::none; } boost::optional<GlobalIndexesCache>& CollectionShardingRuntime::getIndexes( @@ -497,22 +499,22 @@ boost::optional<GlobalIndexesCache>& CollectionShardingRuntime::getIndexes( void CollectionShardingRuntime::addIndex(OperationContext* opCtx, const IndexCatalogType& index, - const Timestamp& indexVersion) { + const CollectionIndexes& collectionIndexes) { if (_globalIndexesInfo) { - _globalIndexesInfo->add(index, indexVersion); + _globalIndexesInfo->add(index, collectionIndexes); } else { IndexCatalogTypeMap indexMap; indexMap.emplace(index.getName(), index); - _globalIndexesInfo.emplace(indexVersion, std::move(indexMap)); + _globalIndexesInfo.emplace(collectionIndexes, std::move(indexMap)); } } void CollectionShardingRuntime::removeIndex(OperationContext* opCtx, const std::string& name, - const Timestamp& indexVersion) { + const CollectionIndexes& collectionIndexes) { tassert( 7019500, "Index information does not exist on CSR", _globalIndexesInfo.is_initialized()); - _globalIndexesInfo->remove(name, indexVersion); + _globalIndexesInfo->remove(name, collectionIndexes); } void CollectionShardingRuntime::clearIndexes(OperationContext* opCtx) { diff --git a/src/mongo/db/s/collection_sharding_runtime.h b/src/mongo/db/s/collection_sharding_runtime.h index 44e46d0bc4a..786425d8e92 100644 --- a/src/mongo/db/s/collection_sharding_runtime.h +++ b/src/mongo/db/s/collection_sharding_runtime.h @@ -229,7 +229,7 @@ public: /** * Gets an index version under a lock. */ - boost::optional<Timestamp> getIndexVersion(OperationContext* opCtx); + boost::optional<CollectionIndexes> getCollectionIndexes(OperationContext* opCtx); /** * Gets the index list under a lock. @@ -241,17 +241,17 @@ public: */ void addIndex(OperationContext* opCtx, const IndexCatalogType& index, - const Timestamp& indexVersion); + const CollectionIndexes& collectionIndexes); /** * Removes an index from the shard-role index info under a lock. */ void removeIndex(OperationContext* opCtx, const std::string& name, - const Timestamp& indexVersion); + const CollectionIndexes& collectionIndexes); /** - * Clears the shard-role index info and set the indexVersion to boost::none. + * Clears the shard-role index info and set the collectionIndexes to boost::none. 
*/ void clearIndexes(OperationContext* opCtx); diff --git a/src/mongo/db/s/collection_sharding_runtime_test.cpp b/src/mongo/db/s/collection_sharding_runtime_test.cpp index 6e792a8db0e..334225663e3 100644 --- a/src/mongo/db/s/collection_sharding_runtime_test.cpp +++ b/src/mongo/db/s/collection_sharding_runtime_test.cpp @@ -377,8 +377,10 @@ public: AutoGetCollection autoColl(operationContext(), kTestNss, MODE_IX); _uuid = autoColl.getCollection()->uuid(); - RangeDeleterService::get(operationContext())->onStepUpComplete(operationContext(), 0L); - RangeDeleterService::get(operationContext())->_waitForRangeDeleterServiceUp_FOR_TESTING(); + auto opCtx = operationContext(); + RangeDeleterService::get(opCtx)->onStartup(opCtx); + RangeDeleterService::get(opCtx)->onStepUpComplete(opCtx, 0L); + RangeDeleterService::get(opCtx)->_waitForRangeDeleterServiceUp_FOR_TESTING(); } void tearDown() override { @@ -611,8 +613,7 @@ TEST_F(CollectionShardingRuntimeWithCatalogTest, TestGlobalIndexesCache) { opCtx, kTestNss, "x_1", BSON("x" << 1), BSONObj(), uuid(), indexVersion, boost::none); ASSERT_EQ(true, csr()->getIndexes(opCtx).is_initialized()); - ASSERT_EQ(indexVersion, *csr()->getIndexes(opCtx)->getVersion()); - ASSERT_EQ(indexVersion, *csr()->getIndexVersion(opCtx)); + ASSERT_EQ(CollectionIndexes(uuid(), indexVersion), *csr()->getCollectionIndexes(opCtx)); } } // namespace } // namespace mongo diff --git a/src/mongo/db/s/config/sharding_catalog_manager.h b/src/mongo/db/s/config/sharding_catalog_manager.h index 4e8a64f59a6..fba732dce7a 100644 --- a/src/mongo/db/s/config/sharding_catalog_manager.h +++ b/src/mongo/db/s/config/sharding_catalog_manager.h @@ -359,7 +359,8 @@ public: */ void splitOrMarkJumbo(OperationContext* opCtx, const NamespaceString& nss, - const BSONObj& minKey); + const BSONObj& minKey, + boost::optional<int64_t> optMaxChunkSizeBytes); /** * In a transaction, sets the 'allowMigrations' to the requested state and bumps the collection diff --git a/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp b/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp index 2b0692acb5d..93ac3b56099 100644 --- a/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp +++ b/src/mongo/db/s/config/sharding_catalog_manager_chunk_operations.cpp @@ -1775,19 +1775,31 @@ void ShardingCatalogManager::bumpMultipleCollectionVersionsAndChangeMetadataInTx void ShardingCatalogManager::splitOrMarkJumbo(OperationContext* opCtx, const NamespaceString& nss, - const BSONObj& minKey) { + const BSONObj& minKey, + boost::optional<int64_t> optMaxChunkSizeBytes) { const auto cm = uassertStatusOK( Grid::get(opCtx)->catalogCache()->getShardedCollectionRoutingInfoWithRefresh(opCtx, nss)); auto chunk = cm.findIntersectingChunkWithSimpleCollation(minKey); try { - const auto splitPoints = uassertStatusOK(shardutil::selectChunkSplitPoints( - opCtx, - chunk.getShardId(), - nss, - cm.getShardKeyPattern(), - ChunkRange(chunk.getMin(), chunk.getMax()), - Grid::get(opCtx)->getBalancerConfiguration()->getMaxChunkSizeBytes())); + const auto maxChunkSizeBytes = [&]() -> int64_t { + if (optMaxChunkSizeBytes.has_value()) { + return *optMaxChunkSizeBytes; + } + + auto coll = Grid::get(opCtx)->catalogClient()->getCollection( + opCtx, nss, repl::ReadConcernLevel::kMajorityReadConcern); + return coll.getMaxChunkSizeBytes().value_or( + Grid::get(opCtx)->getBalancerConfiguration()->getMaxChunkSizeBytes()); + }(); + + const auto splitPoints = uassertStatusOK( + shardutil::selectChunkSplitPoints(opCtx, + 
chunk.getShardId(), + nss, + cm.getShardKeyPattern(), + ChunkRange(chunk.getMin(), chunk.getMax()), + maxChunkSizeBytes)); if (splitPoints.empty()) { LOGV2(21873, diff --git a/src/mongo/db/s/query_analysis_op_observer.cpp b/src/mongo/db/s/query_analysis_op_observer.cpp index 84606c7f4c7..658c50c992d 100644 --- a/src/mongo/db/s/query_analysis_op_observer.cpp +++ b/src/mongo/db/s/query_analysis_op_observer.cpp @@ -31,6 +31,7 @@ #include "mongo/db/s/query_analysis_coordinator.h" #include "mongo/db/s/query_analysis_op_observer.h" +#include "mongo/db/s/query_analysis_writer.h" #include "mongo/logv2/log.h" #include "mongo/s/analyze_shard_key_util.h" #include "mongo/s/catalog/type_mongos.h" @@ -88,6 +89,17 @@ void QueryAnalysisOpObserver::onUpdate(OperationContext* opCtx, const OplogUpdat }); } } + + if (analyze_shard_key::supportsPersistingSampledQueries() && args.updateArgs->sampleId && + args.updateArgs->preImageDoc && opCtx->writesAreReplicated()) { + analyze_shard_key::QueryAnalysisWriter::get(opCtx) + .addDiff(*args.updateArgs->sampleId, + args.coll->ns(), + args.coll->uuid(), + *args.updateArgs->preImageDoc, + args.updateArgs->updatedDoc) + .getAsync([](auto) {}); + } } void QueryAnalysisOpObserver::aboutToDelete(OperationContext* opCtx, diff --git a/src/mongo/db/s/query_analysis_writer.cpp b/src/mongo/db/s/query_analysis_writer.cpp new file mode 100644 index 00000000000..19b942e1775 --- /dev/null +++ b/src/mongo/db/s/query_analysis_writer.cpp @@ -0,0 +1,695 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ +#include "mongo/platform/basic.h" + +#include "mongo/db/s/query_analysis_writer.h" + +#include "mongo/client/connpool.h" +#include "mongo/db/catalog/collection_catalog.h" +#include "mongo/db/dbdirectclient.h" +#include "mongo/db/ops/write_ops.h" +#include "mongo/db/server_options.h" +#include "mongo/db/update/document_diff_calculator.h" +#include "mongo/executor/network_interface_factory.h" +#include "mongo/executor/thread_pool_task_executor.h" +#include "mongo/logv2/log.h" +#include "mongo/s/analyze_shard_key_documents_gen.h" +#include "mongo/s/analyze_shard_key_server_parameters_gen.h" +#include "mongo/s/write_ops/batched_command_response.h" +#include "mongo/util/concurrency/thread_pool.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kDefault + +namespace mongo { +namespace analyze_shard_key { + +namespace { + +MONGO_FAIL_POINT_DEFINE(disableQueryAnalysisWriter); +MONGO_FAIL_POINT_DEFINE(hangQueryAnalysisWriterBeforeWritingLocally); +MONGO_FAIL_POINT_DEFINE(hangQueryAnalysisWriterBeforeWritingRemotely); + +const auto getQueryAnalysisWriter = ServiceContext::declareDecoration<QueryAnalysisWriter>(); + +constexpr int kMaxRetriesOnRetryableErrors = 5; +const WriteConcernOptions kMajorityWriteConcern{WriteConcernOptions::kMajority, + WriteConcernOptions::SyncMode::UNSET, + WriteConcernOptions::kWriteConcernTimeoutSystem}; + +// The size limit for the documents to an insert in a single batch. Leave some padding for other +// fields in the insert command. +constexpr int kMaxBSONObjSizeForInsert = BSONObjMaxUserSize - 500 * 1024; + +/* + * Returns true if this mongod can accept writes to the given collection. Unless the collection is + * in the "local" database, this will only return true if this mongod is a primary (or a + * standalone). + */ +bool canAcceptWrites(OperationContext* opCtx, const NamespaceString& ns) { + ShouldNotConflictWithSecondaryBatchApplicationBlock noPBWMBlock(opCtx->lockState()); + Lock::DBLock lk(opCtx, ns.dbName(), MODE_IS); + Lock::CollectionLock lock(opCtx, ns, MODE_IS); + return mongo::repl::ReplicationCoordinator::get(opCtx)->canAcceptWritesForDatabase(opCtx, + ns.db()); +} + +/* + * Runs the given write command against the given database locally, asserts that the top-level + * command is OK, then asserts the write status using the given 'uassertWriteStatusCb' callback. + * Returns the command response. + */ +BSONObj executeWriteCommandLocal(OperationContext* opCtx, + const std::string dbName, + const BSONObj& cmdObj, + const std::function<void(const BSONObj&)>& uassertWriteStatusCb) { + DBDirectClient client(opCtx); + BSONObj resObj; + + if (!client.runCommand(dbName, cmdObj, resObj)) { + uassertStatusOK(getStatusFromCommandResult(resObj)); + } + uassertWriteStatusCb(resObj); + + return resObj; +} + +/* + * Runs the given write command against the given database on the (remote) primary, asserts that the + * top-level command is OK, then asserts the write status using the given 'uassertWriteStatusCb' + * callback. Throws a PrimarySteppedDown error if no primary is found. Returns the command response. 
+ */ +BSONObj executeWriteCommandRemote(OperationContext* opCtx, + const std::string dbName, + const BSONObj& cmdObj, + const std::function<void(const BSONObj&)>& uassertWriteStatusCb) { + auto hostAndPort = repl::ReplicationCoordinator::get(opCtx)->getCurrentPrimaryHostAndPort(); + + if (hostAndPort.empty()) { + uasserted(ErrorCodes::PrimarySteppedDown, "No primary exists currently"); + } + + auto conn = std::make_unique<ScopedDbConnection>(hostAndPort.toString()); + + if (auth::isInternalAuthSet()) { + uassertStatusOK(conn->get()->authenticateInternalUser()); + } + + DBClientBase* client = conn->get(); + ScopeGuard guard([&] { conn->done(); }); + try { + BSONObj resObj; + + if (!client->runCommand(dbName, cmdObj, resObj)) { + uassertStatusOK(getStatusFromCommandResult(resObj)); + } + uassertWriteStatusCb(resObj); + + return resObj; + } catch (...) { + guard.dismiss(); + conn->kill(); + throw; + } +} + +/* + * Runs the given write command against the given collection. If this mongod is currently the + * primary, runs the write command locally. Otherwise, runs the command on the remote primary. + * Internally asserts that the top-level command is OK, then asserts the write status using the + * given 'uassertWriteStatusCb' callback. Internally retries the write command on retryable errors + * (for kMaxRetriesOnRetryableErrors times) so the writes must be idempotent. Returns the + * command response. + */ +BSONObj executeWriteCommand(OperationContext* opCtx, + const NamespaceString& ns, + const BSONObj& cmdObj, + const std::function<void(const BSONObj&)>& uassertWriteStatusCb) { + const auto dbName = ns.db().toString(); + auto numRetries = 0; + + while (true) { + try { + if (canAcceptWrites(opCtx, ns)) { + // There is a window here where this mongod may step down after check above. In this + // case, a NotWritablePrimary error would be thrown. However, this is preferable to + // running the command while holding locks. + hangQueryAnalysisWriterBeforeWritingLocally.pauseWhileSet(opCtx); + return executeWriteCommandLocal(opCtx, dbName, cmdObj, uassertWriteStatusCb); + } + + hangQueryAnalysisWriterBeforeWritingRemotely.pauseWhileSet(opCtx); + return executeWriteCommandRemote(opCtx, dbName, cmdObj, uassertWriteStatusCb); + } catch (DBException& ex) { + if (ErrorCodes::isRetriableError(ex) && numRetries < kMaxRetriesOnRetryableErrors) { + numRetries++; + continue; + } + throw; + } + } + + return {}; +} + +struct SampledWriteCommandRequest { + UUID sampleId; + NamespaceString nss; + BSONObj cmd; // the BSON for a {Update,Delete,FindAndModify}CommandRequest +}; + +/* + * Returns a sampled update command for the update at 'opIndex' in the given update command. + */ +SampledWriteCommandRequest makeSampledUpdateCommandRequest( + const write_ops::UpdateCommandRequest& originalCmd, int opIndex) { + auto op = originalCmd.getUpdates()[opIndex]; + auto sampleId = op.getSampleId(); + invariant(sampleId); + + write_ops::UpdateCommandRequest sampledCmd(originalCmd.getNamespace(), {std::move(op)}); + sampledCmd.setLet(originalCmd.getLet()); + + return {*sampleId, + sampledCmd.getNamespace(), + sampledCmd.toBSON(BSON("$db" << sampledCmd.getNamespace().db().toString()))}; +} + +/* + * Returns a sampled delete command for the delete at 'opIndex' in the given delete command. 
+ */ +SampledWriteCommandRequest makeSampledDeleteCommandRequest( + const write_ops::DeleteCommandRequest& originalCmd, int opIndex) { + auto op = originalCmd.getDeletes()[opIndex]; + auto sampleId = op.getSampleId(); + invariant(sampleId); + + write_ops::DeleteCommandRequest sampledCmd(originalCmd.getNamespace(), {std::move(op)}); + sampledCmd.setLet(originalCmd.getLet()); + + return {*sampleId, + sampledCmd.getNamespace(), + sampledCmd.toBSON(BSON("$db" << sampledCmd.getNamespace().db().toString()))}; +} + +/* + * Returns a sampled findAndModify command for the given findAndModify command. + */ +SampledWriteCommandRequest makeSampledFindAndModifyCommandRequest( + const write_ops::FindAndModifyCommandRequest& originalCmd) { + invariant(originalCmd.getSampleId()); + + write_ops::FindAndModifyCommandRequest sampledCmd(originalCmd.getNamespace()); + sampledCmd.setQuery(originalCmd.getQuery()); + sampledCmd.setUpdate(originalCmd.getUpdate()); + sampledCmd.setRemove(originalCmd.getRemove()); + sampledCmd.setUpsert(originalCmd.getUpsert()); + sampledCmd.setNew(originalCmd.getNew()); + sampledCmd.setSort(originalCmd.getSort()); + sampledCmd.setCollation(originalCmd.getCollation()); + sampledCmd.setArrayFilters(originalCmd.getArrayFilters()); + sampledCmd.setLet(originalCmd.getLet()); + sampledCmd.setSampleId(originalCmd.getSampleId()); + + return {*sampledCmd.getSampleId(), + sampledCmd.getNamespace(), + sampledCmd.toBSON(BSON("$db" << sampledCmd.getNamespace().db().toString()))}; +} + +} // namespace + +QueryAnalysisWriter& QueryAnalysisWriter::get(OperationContext* opCtx) { + return get(opCtx->getServiceContext()); +} + +QueryAnalysisWriter& QueryAnalysisWriter::get(ServiceContext* serviceContext) { + invariant(analyze_shard_key::isFeatureFlagEnabledIgnoreFCV(), + "Only support analyzing queries when the feature flag is enabled"); + invariant(serverGlobalParams.clusterRole == ClusterRole::ShardServer, + "Only support analyzing queries on a sharded cluster"); + return getQueryAnalysisWriter(serviceContext); +} + +void QueryAnalysisWriter::onStartup() { + auto serviceContext = getQueryAnalysisWriter.owner(this); + auto periodicRunner = serviceContext->getPeriodicRunner(); + invariant(periodicRunner); + + stdx::lock_guard<Latch> lk(_mutex); + + PeriodicRunner::PeriodicJob QueryWriterJob( + "QueryAnalysisQueryWriter", + [this](Client* client) { + if (MONGO_unlikely(disableQueryAnalysisWriter.shouldFail())) { + return; + } + auto opCtx = client->makeOperationContext(); + _flushQueries(opCtx.get()); + }, + Seconds(gQueryAnalysisWriterIntervalSecs)); + _periodicQueryWriter = periodicRunner->makeJob(std::move(QueryWriterJob)); + _periodicQueryWriter.start(); + + PeriodicRunner::PeriodicJob diffWriterJob( + "QueryAnalysisDiffWriter", + [this](Client* client) { + if (MONGO_unlikely(disableQueryAnalysisWriter.shouldFail())) { + return; + } + auto opCtx = client->makeOperationContext(); + _flushDiffs(opCtx.get()); + }, + Seconds(gQueryAnalysisWriterIntervalSecs)); + _periodicDiffWriter = periodicRunner->makeJob(std::move(diffWriterJob)); + _periodicDiffWriter.start(); + + ThreadPool::Options threadPoolOptions; + threadPoolOptions.maxThreads = gQueryAnalysisWriterMaxThreadPoolSize; + threadPoolOptions.minThreads = gQueryAnalysisWriterMinThreadPoolSize; + threadPoolOptions.threadNamePrefix = "QueryAnalysisWriter-"; + threadPoolOptions.poolName = "QueryAnalysisWriterThreadPool"; + threadPoolOptions.onCreateThread = [](const std::string& threadName) { + Client::initThread(threadName.c_str()); + }; + _executor 
= std::make_shared<executor::ThreadPoolTaskExecutor>( + std::make_unique<ThreadPool>(threadPoolOptions), + executor::makeNetworkInterface("QueryAnalysisWriterNetwork")); + _executor->startup(); +} + +void QueryAnalysisWriter::onShutdown() { + if (_executor) { + _executor->shutdown(); + _executor->join(); + } + if (_periodicQueryWriter.isValid()) { + _periodicQueryWriter.stop(); + } + if (_periodicDiffWriter.isValid()) { + _periodicDiffWriter.stop(); + } +} + +void QueryAnalysisWriter::_flushQueries(OperationContext* opCtx) { + try { + _flush(opCtx, NamespaceString::kConfigSampledQueriesNamespace, &_queries); + } catch (DBException& ex) { + LOGV2(7047300, + "Failed to flush queries, will try again at the next interval", + "error"_attr = redact(ex)); + } +} + +void QueryAnalysisWriter::_flushDiffs(OperationContext* opCtx) { + try { + _flush(opCtx, NamespaceString::kConfigSampledQueriesDiffNamespace, &_diffs); + } catch (DBException& ex) { + LOGV2(7075400, + "Failed to flush diffs, will try again at the next interval", + "error"_attr = redact(ex)); + } +} + +void QueryAnalysisWriter::_flush(OperationContext* opCtx, + const NamespaceString& ns, + Buffer* buffer) { + Buffer tmpBuffer; + // The indices of invalid documents, e.g. documents that fail to insert with DuplicateKey errors + // (i.e. duplicates) and BadValue errors. Such documents should not get added back to the buffer + // when the inserts below fail. + std::set<int> invalid; + { + stdx::lock_guard<Latch> lk(_mutex); + if (buffer->isEmpty()) { + return; + } + std::swap(tmpBuffer, *buffer); + } + ScopeGuard backSwapper([&] { + stdx::lock_guard<Latch> lk(_mutex); + for (int i = 0; i < tmpBuffer.getCount(); i++) { + if (invalid.find(i) == invalid.end()) { + buffer->add(tmpBuffer.at(i)); + } + } + }); + + // Insert the documents in batches from the back of the buffer so that we don't need to move the + // documents forward after each batch. + size_t baseIndex = tmpBuffer.getCount() - 1; + size_t maxBatchSize = gQueryAnalysisWriterMaxBatchSize.load(); + + while (!tmpBuffer.isEmpty()) { + std::vector<BSONObj> docsToInsert; + long long objSize = 0; + + size_t lastIndex = tmpBuffer.getCount(); // inclusive + while (lastIndex > 0 && docsToInsert.size() < maxBatchSize) { + // Check if the next document can fit in the batch. + auto doc = tmpBuffer.at(lastIndex - 1); + if (doc.objsize() + objSize >= kMaxBSONObjSizeForInsert) { + break; + } + lastIndex--; + objSize += doc.objsize(); + docsToInsert.push_back(std::move(doc)); + } + // We don't add a document that is above the size limit to the buffer so we should have + // added at least one document to 'docsToInsert'. 
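        // (Illustrative sketch, not part of the patch.) Worked example of the batching
        // loop above with maxBatchSize = 3 and five 1KB documents d0..d4 in tmpBuffer:
        //   pass 1: lastIndex walks 5 -> 2, docsToInsert = {d4, d3, d2}, objSize = 3KB,
        //           and truncate(2, 3KB) leaves {d0, d1} in tmpBuffer;
        //   pass 2: docsToInsert = {d1, d0}, and truncate(0, 2KB) empties tmpBuffer.
        // Each pass issues one insert command, so the five documents go out in two batches.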
+ invariant(!docsToInsert.empty()); + + write_ops::InsertCommandRequest insertCmd(ns); + insertCmd.setDocuments(std::move(docsToInsert)); + insertCmd.setWriteCommandRequestBase([&] { + write_ops::WriteCommandRequestBase wcb; + wcb.setOrdered(false); + wcb.setBypassDocumentValidation(false); + return wcb; + }()); + auto insertCmdBson = insertCmd.toBSON( + {BSON(WriteConcernOptions::kWriteConcernField << kMajorityWriteConcern.toBSON())}); + + executeWriteCommand(opCtx, ns, std::move(insertCmdBson), [&](const BSONObj& resObj) { + BatchedCommandResponse response; + std::string errMsg; + + if (!response.parseBSON(resObj, &errMsg)) { + uasserted(ErrorCodes::FailedToParse, errMsg); + } + + if (response.isErrDetailsSet() && response.sizeErrDetails() > 0) { + boost::optional<write_ops::WriteError> firstWriteErr; + + for (const auto& err : response.getErrDetails()) { + if (err.getStatus() == ErrorCodes::DuplicateKey || + err.getStatus() == ErrorCodes::BadValue) { + LOGV2(7075402, + "Ignoring insert error", + "error"_attr = redact(err.getStatus())); + invalid.insert(baseIndex - err.getIndex()); + continue; + } + if (!firstWriteErr) { + // Save the error for later. Go through the rest of the errors to see if + // there are any invalid documents so they can be discarded from the buffer. + firstWriteErr.emplace(err); + } + } + if (firstWriteErr) { + uassertStatusOK(firstWriteErr->getStatus()); + } + } else { + uassertStatusOK(response.toStatus()); + } + }); + + tmpBuffer.truncate(lastIndex, objSize); + baseIndex -= lastIndex; + } + + backSwapper.dismiss(); +} + +void QueryAnalysisWriter::Buffer::add(BSONObj doc) { + if (doc.objsize() > kMaxBSONObjSizeForInsert) { + return; + } + _docs.push_back(std::move(doc)); + _numBytes += _docs.back().objsize(); +} + +void QueryAnalysisWriter::Buffer::truncate(size_t index, long long numBytes) { + invariant(index >= 0); + invariant(index < _docs.size()); + invariant(numBytes > 0); + invariant(numBytes <= _numBytes); + _docs.resize(index); + _numBytes -= numBytes; +} + +bool QueryAnalysisWriter::_exceedsMaxSizeBytes() { + stdx::lock_guard<Latch> lk(_mutex); + return _queries.getSize() + _diffs.getSize() >= gQueryAnalysisWriterMaxMemoryUsageBytes.load(); +} + +ExecutorFuture<void> QueryAnalysisWriter::addFindQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation) { + return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kFind, filter, collation); +} + +ExecutorFuture<void> QueryAnalysisWriter::addCountQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation) { + return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kCount, filter, collation); +} + +ExecutorFuture<void> QueryAnalysisWriter::addDistinctQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation) { + return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kDistinct, filter, collation); +} + +ExecutorFuture<void> QueryAnalysisWriter::addAggregateQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation) { + return _addReadQuery(sampleId, nss, SampledReadCommandNameEnum::kAggregate, filter, collation); +} + +ExecutorFuture<void> QueryAnalysisWriter::_addReadQuery(const UUID& sampleId, + const NamespaceString& nss, + SampledReadCommandNameEnum cmdName, + const BSONObj& filter, + const BSONObj& collation) { + invariant(_executor); + return ExecutorFuture<void>(_executor) + 
.then([this, + sampleId, + nss, + cmdName, + filter = filter.getOwned(), + collation = collation.getOwned()] { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + + auto collUuid = CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, nss); + + if (!collUuid) { + LOGV2(7047301, "Found a sampled read query for non-existing collection"); + return; + } + + auto cmd = SampledReadCommand{filter.getOwned(), collation.getOwned()}; + auto doc = SampledReadQueryDocument{sampleId, nss, *collUuid, cmdName, cmd.toBSON()}; + stdx::lock_guard<Latch> lk(_mutex); + _queries.add(doc.toBSON()); + }) + .then([this] { + if (_exceedsMaxSizeBytes()) { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + _flushQueries(opCtx); + } + }) + .onError([this, nss](Status status) { + LOGV2(7047302, + "Failed to add read query", + "ns"_attr = nss, + "error"_attr = redact(status)); + }); +} + +ExecutorFuture<void> QueryAnalysisWriter::addUpdateQuery( + const write_ops::UpdateCommandRequest& updateCmd, int opIndex) { + invariant(updateCmd.getUpdates()[opIndex].getSampleId()); + invariant(_executor); + + return ExecutorFuture<void>(_executor) + .then([this, sampledUpdateCmd = makeSampledUpdateCommandRequest(updateCmd, opIndex)]() { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + + auto collUuid = + CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledUpdateCmd.nss); + + if (!collUuid) { + LOGV2_WARNING(7075300, + "Found a sampled update query for a non-existing collection"); + return; + } + + auto doc = SampledWriteQueryDocument{sampledUpdateCmd.sampleId, + sampledUpdateCmd.nss, + *collUuid, + SampledWriteCommandNameEnum::kUpdate, + std::move(sampledUpdateCmd.cmd)}; + stdx::lock_guard<Latch> lk(_mutex); + _queries.add(doc.toBSON()); + }) + .then([this] { + if (_exceedsMaxSizeBytes()) { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + _flushQueries(opCtx); + } + }) + .onError([this, nss = updateCmd.getNamespace()](Status status) { + LOGV2(7075301, + "Failed to add update query", + "ns"_attr = nss, + "error"_attr = redact(status)); + }); +} + +ExecutorFuture<void> QueryAnalysisWriter::addDeleteQuery( + const write_ops::DeleteCommandRequest& deleteCmd, int opIndex) { + invariant(deleteCmd.getDeletes()[opIndex].getSampleId()); + invariant(_executor); + + return ExecutorFuture<void>(_executor) + .then([this, sampledDeleteCmd = makeSampledDeleteCommandRequest(deleteCmd, opIndex)]() { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + + auto collUuid = + CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledDeleteCmd.nss); + + if (!collUuid) { + LOGV2_WARNING(7075302, + "Found a sampled delete query for a non-existing collection"); + return; + } + + auto doc = SampledWriteQueryDocument{sampledDeleteCmd.sampleId, + sampledDeleteCmd.nss, + *collUuid, + SampledWriteCommandNameEnum::kDelete, + std::move(sampledDeleteCmd.cmd)}; + stdx::lock_guard<Latch> lk(_mutex); + _queries.add(doc.toBSON()); + }) + .then([this] { + if (_exceedsMaxSizeBytes()) { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + _flushQueries(opCtx); + } + }) + .onError([this, nss = deleteCmd.getNamespace()](Status status) { + LOGV2(7075303, + "Failed to add delete query", + "ns"_attr = nss, + "error"_attr = redact(status)); + }); +} + +ExecutorFuture<void> QueryAnalysisWriter::addFindAndModifyQuery( + const 
write_ops::FindAndModifyCommandRequest& findAndModifyCmd) { + invariant(findAndModifyCmd.getSampleId()); + invariant(_executor); + + return ExecutorFuture<void>(_executor) + .then([this, + sampledFindAndModifyCmd = + makeSampledFindAndModifyCommandRequest(findAndModifyCmd)]() { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + + auto collUuid = + CollectionCatalog::get(opCtx)->lookupUUIDByNSS(opCtx, sampledFindAndModifyCmd.nss); + + if (!collUuid) { + LOGV2_WARNING(7075304, + "Found a sampled findAndModify query for a non-existing collection"); + return; + } + + auto doc = SampledWriteQueryDocument{sampledFindAndModifyCmd.sampleId, + sampledFindAndModifyCmd.nss, + *collUuid, + SampledWriteCommandNameEnum::kFindAndModify, + std::move(sampledFindAndModifyCmd.cmd)}; + stdx::lock_guard<Latch> lk(_mutex); + _queries.add(doc.toBSON()); + }) + .then([this] { + if (_exceedsMaxSizeBytes()) { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + _flushQueries(opCtx); + } + }) + .onError([this, nss = findAndModifyCmd.getNamespace()](Status status) { + LOGV2(7075305, + "Failed to add findAndModify query", + "ns"_attr = nss, + "error"_attr = redact(status)); + }); +} + +ExecutorFuture<void> QueryAnalysisWriter::addDiff(const UUID& sampleId, + const NamespaceString& nss, + const UUID& collUuid, + const BSONObj& preImage, + const BSONObj& postImage) { + invariant(_executor); + return ExecutorFuture<void>(_executor) + .then([this, + sampleId, + nss, + collUuid, + preImage = preImage.getOwned(), + postImage = postImage.getOwned()]() { + auto diff = doc_diff::computeInlineDiff(preImage, postImage); + + if (!diff || diff->isEmpty()) { + return; + } + + auto doc = SampledQueryDiffDocument{sampleId, nss, collUuid, std::move(*diff)}; + stdx::lock_guard<Latch> lk(_mutex); + _diffs.add(doc.toBSON()); + }) + .then([this] { + if (_exceedsMaxSizeBytes()) { + auto opCtxHolder = cc().makeOperationContext(); + auto opCtx = opCtxHolder.get(); + _flushDiffs(opCtx); + } + }) + .onError([this, nss](Status status) { + LOGV2(7075401, "Failed to add diff", "ns"_attr = nss, "error"_attr = redact(status)); + }); +} + +} // namespace analyze_shard_key +} // namespace mongo diff --git a/src/mongo/db/s/query_analysis_writer.h b/src/mongo/db/s/query_analysis_writer.h new file mode 100644 index 00000000000..508d4903b65 --- /dev/null +++ b/src/mongo/db/s/query_analysis_writer.h @@ -0,0 +1,214 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. 
You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/operation_context.h" +#include "mongo/db/service_context.h" +#include "mongo/executor/task_executor.h" +#include "mongo/s/analyze_shard_key_common_gen.h" +#include "mongo/s/analyze_shard_key_util.h" +#include "mongo/s/write_ops/batched_command_request.h" +#include "mongo/util/periodic_runner.h" + +namespace mongo { +namespace analyze_shard_key { + +/** + * Owns the machinery for persisting sampled queries. That consists of the following: + * - The buffer that stores sampled queries and the periodic background job that inserts those + * queries into the local config.sampledQueries collection. + * - The buffer that stores diffs for sampled update queries and the periodic background job that + * inserts those diffs into the local config.sampledQueriesDiff collection. + * + * Currently, query sampling is only supported on a sharded cluster. So a writer must be a shardsvr + * mongod. If the mongod is a primary, it will execute the insert commands locally. If it is a + * secondary, it will perform the insert commands against the primary. + * + * The memory usage of the buffers is controlled by the 'queryAnalysisWriterMaxMemoryUsageBytes' + * server parameter. Upon adding a query or a diff that causes the total size of buffers to exceed + * the limit, the writer will flush the corresponding buffer immediately instead of waiting for it + * to get flushed later by the periodic job. + */ +class QueryAnalysisWriter final : public std::enable_shared_from_this<QueryAnalysisWriter> { + QueryAnalysisWriter(const QueryAnalysisWriter&) = delete; + QueryAnalysisWriter& operator=(const QueryAnalysisWriter&) = delete; + +public: + /** + * Temporarily stores documents to be written to disk. + */ + struct Buffer { + public: + /** + * Adds the given document to the buffer if its size is below the limit (i.e. + * BSONObjMaxUserSize - some padding) and increments the total number of bytes accordingly. + */ + void add(BSONObj doc); + + /** + * Removes the documents at 'index' onwards from the buffer and decrements the total number + * of the bytes by 'numBytes'. The caller must ensure that that 'numBytes' is indeed the + * total size of the documents being removed. + */ + void truncate(size_t index, long long numBytes); + + bool isEmpty() const { + return _docs.empty(); + } + + int getCount() const { + return _docs.size(); + } + + long long getSize() const { + return _numBytes; + } + + BSONObj at(size_t index) const { + return _docs[index]; + } + + private: + std::vector<BSONObj> _docs; + long long _numBytes = 0; + }; + + QueryAnalysisWriter() = default; + ~QueryAnalysisWriter() = default; + + QueryAnalysisWriter(QueryAnalysisWriter&& source) = delete; + QueryAnalysisWriter& operator=(QueryAnalysisWriter&& other) = delete; + + /** + * Obtains the service-wide QueryAnalysisWriter instance. 
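+ * Must only be called when the analyzeShardKey feature flag is enabled and on a shard server; both conditions are enforced with invariants.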
+ */ + static QueryAnalysisWriter& get(OperationContext* opCtx); + static QueryAnalysisWriter& get(ServiceContext* serviceContext); + + void onStartup(); + + void onShutdown(); + + ExecutorFuture<void> addFindQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation); + + ExecutorFuture<void> addCountQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation); + + ExecutorFuture<void> addDistinctQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation); + + ExecutorFuture<void> addAggregateQuery(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& filter, + const BSONObj& collation); + + ExecutorFuture<void> addUpdateQuery(const write_ops::UpdateCommandRequest& updateCmd, + int opIndex); + ExecutorFuture<void> addDeleteQuery(const write_ops::DeleteCommandRequest& deleteCmd, + int opIndex); + + ExecutorFuture<void> addFindAndModifyQuery( + const write_ops::FindAndModifyCommandRequest& findAndModifyCmd); + + ExecutorFuture<void> addDiff(const UUID& sampleId, + const NamespaceString& nss, + const UUID& collUuid, + const BSONObj& preImage, + const BSONObj& postImage); + + int getQueriesCountForTest() const { + stdx::lock_guard<Latch> lk(_mutex); + return _queries.getCount(); + } + + void flushQueriesForTest(OperationContext* opCtx) { + _flushQueries(opCtx); + } + + int getDiffsCountForTest() const { + stdx::lock_guard<Latch> lk(_mutex); + return _diffs.getCount(); + } + + void flushDiffsForTest(OperationContext* opCtx) { + _flushDiffs(opCtx); + } + +private: + ExecutorFuture<void> _addReadQuery(const UUID& sampleId, + const NamespaceString& nss, + SampledReadCommandNameEnum cmdName, + const BSONObj& filter, + const BSONObj& collation); + + void _flushQueries(OperationContext* opCtx); + void _flushDiffs(OperationContext* opCtx); + + /** + * The helper for '_flushQueries' and '_flushDiffs'. Inserts the documents in 'buffer' into the + * collection 'ns' in batches, and removes all the inserted documents from 'buffer'. Internally + * retries the inserts on retryable errors for a fixed number of times. Ignores DuplicateKey + * errors since they are expected for the following reasons: + * - For the query buffer, a sampled query that is idempotent (e.g. a read or retryable write) + * could get added to the buffer (across nodes) more than once due to retries. + * - For the diff buffer, a sampled multi-update query could end up generating multiple diffs + * and each diff is identified using the sample id of the sampled query that creates it. + * + * Throws an error if the inserts fail with any other error. + */ + void _flush(OperationContext* opCtx, const NamespaceString& nss, Buffer* buffer); + + /** + * Returns true if the total size of the buffered queries and diffs has exceeded the maximum + * amount of memory that the writer is allowed to use. + */ + bool _exceedsMaxSizeBytes(); + + mutable Mutex _mutex = MONGO_MAKE_LATCH("QueryAnalysisWriter::_mutex"); + + PeriodicJobAnchor _periodicQueryWriter; + Buffer _queries; + + PeriodicJobAnchor _periodicDiffWriter; + Buffer _diffs; + + // Initialized on startup and joined on shutdown. 
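+ // Runs the asynchronous buffering work scheduled by the add*() methods above.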
+ std::shared_ptr<executor::TaskExecutor> _executor; +}; + +} // namespace analyze_shard_key +} // namespace mongo diff --git a/src/mongo/db/s/query_analysis_writer_test.cpp b/src/mongo/db/s/query_analysis_writer_test.cpp new file mode 100644 index 00000000000..b17f39c82df --- /dev/null +++ b/src/mongo/db/s/query_analysis_writer_test.cpp @@ -0,0 +1,1281 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/s/query_analysis_writer.h" + +#include "mongo/bson/unordered_fields_bsonobj_comparator.h" +#include "mongo/db/db_raii.h" +#include "mongo/db/dbdirectclient.h" +#include "mongo/db/s/shard_server_test_fixture.h" +#include "mongo/db/update/document_diff_calculator.h" +#include "mongo/idl/server_parameter_test_util.h" +#include "mongo/logv2/log.h" +#include "mongo/s/analyze_shard_key_documents_gen.h" +#include "mongo/unittest/death_test.h" +#include "mongo/util/fail_point.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest + +namespace mongo { +namespace analyze_shard_key { +namespace { + +TEST(QueryAnalysisWriterBufferTest, AddBasic) { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc0 = BSON("a" << 0); + buffer.add(doc0); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc0.objsize()); + + auto doc1 = BSON("a" << BSON_ARRAY(0 << 1 << 2)); + buffer.add(doc1); + ASSERT_EQ(buffer.getCount(), 2); + ASSERT_EQ(buffer.getSize(), doc0.objsize() + doc1.objsize()); + + ASSERT_BSONOBJ_EQ(buffer.at(0), doc0); + ASSERT_BSONOBJ_EQ(buffer.at(1), doc1); +} + +TEST(QueryAnalysisWriterBufferTest, AddTooLarge) { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON(std::string(BSONObjMaxUserSize, 'a') << 1); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 0); + ASSERT_EQ(buffer.getSize(), 0); +} + +TEST(QueryAnalysisWriterBufferTest, TruncateBasic) { + auto testTruncateCommon = [](int oldCount, int newCount) { + auto buffer = QueryAnalysisWriter::Buffer(); + + std::vector<BSONObj> docs; + for (auto i = 0; i < oldCount; i++) { + docs.push_back(BSON("a" << i)); + } + // The documents have the same size. 
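+ // Their common size makes the expected buffer size below equal to the document count times 'docSize'.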
+ auto docSize = docs.back().objsize(); + + for (const auto& doc : docs) { + buffer.add(doc); + } + ASSERT_EQ(buffer.getCount(), oldCount); + ASSERT_EQ(buffer.getSize(), oldCount * docSize); + + buffer.truncate(newCount, (oldCount - newCount) * docSize); + ASSERT_EQ(buffer.getCount(), newCount); + ASSERT_EQ(buffer.getSize(), newCount * docSize); + for (auto i = 0; i < newCount; i++) { + ASSERT_BSONOBJ_EQ(buffer.at(i), docs[i]); + } + }; + + testTruncateCommon(10 /* oldCount */, 6 /* newCount */); + testTruncateCommon(10 /* oldCount */, 0 /* newCount */); // Truncate all. + testTruncateCommon(10 /* oldCount */, 9 /* newCount */); // Truncate one. +} + +DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidIndex_Negative, "invariant") { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON("a" << 0); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc.objsize()); + + buffer.truncate(-1, doc.objsize()); +} + +DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidIndex_Positive, "invariant") { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON("a" << 0); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc.objsize()); + + buffer.truncate(2, doc.objsize()); +} + +DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Negative, "invariant") { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON("a" << 0); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc.objsize()); + + buffer.truncate(0, -doc.objsize()); +} + +DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Zero, "invariant") { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON("a" << 0); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc.objsize()); + + buffer.truncate(0, 0); +} + +DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Positive, "invariant") { + auto buffer = QueryAnalysisWriter::Buffer(); + + auto doc = BSON("a" << 0); + buffer.add(doc); + ASSERT_EQ(buffer.getCount(), 1); + ASSERT_EQ(buffer.getSize(), doc.objsize()); + + buffer.truncate(0, doc.objsize() * 2); +} + +void assertBsonObjEqualUnordered(const BSONObj& lhs, const BSONObj& rhs) { + UnorderedFieldsBSONObjComparator comparator; + ASSERT_EQ(comparator.compare(lhs, rhs), 0); +} + +struct QueryAnalysisWriterTest : public ShardServerTestFixture { +public: + void setUp() { + ShardServerTestFixture::setUp(); + QueryAnalysisWriter::get(operationContext()).onStartup(); + + DBDirectClient client(operationContext()); + client.createCollection(nss0.toString()); + client.createCollection(nss1.toString()); + } + + void tearDown() { + QueryAnalysisWriter::get(operationContext()).onShutdown(); + ShardServerTestFixture::tearDown(); + } + +protected: + UUID getCollectionUUID(const NamespaceString& nss) const { + auto collectionCatalog = CollectionCatalog::get(operationContext()); + return *collectionCatalog->lookupUUIDByNSS(operationContext(), nss); + } + + BSONObj makeNonEmptyFilter() { + return BSON("_id" << UUID::gen()); + } + + BSONObj makeNonEmptyCollation() { + int strength = rand() % 5 + 1; + return BSON("locale" + << "en_US" + << "strength" << strength); + } + + /* + * Makes an UpdateCommandRequest for the collection 'nss' such that the command contains + * 'numOps' updates and the ones whose indices are in 'markForSampling' are marked for sampling. 
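+ * Each op carries a non-empty collation, array filters, and randomized multi/upsert flags.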
+ * Returns the UpdateCommandRequest and the map storing the expected sampled + * UpdateCommandRequests by sample id, if any. + */ + std::pair<write_ops::UpdateCommandRequest, std::map<UUID, write_ops::UpdateCommandRequest>> + makeUpdateCommandRequest(const NamespaceString& nss, + int numOps, + std::set<int> markForSampling, + std::string filterFieldName = "a") { + write_ops::UpdateCommandRequest originalCmd(nss); + std::vector<write_ops::UpdateOpEntry> updateOps; // populated below. + originalCmd.setLet(let); + originalCmd.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation); + + std::map<UUID, write_ops::UpdateCommandRequest> expectedSampledCmds; + + for (auto i = 0; i < numOps; i++) { + auto updateOp = write_ops::UpdateOpEntry( + BSON(filterFieldName << i), + write_ops::UpdateModification(BSON("$set" << BSON("b.$[element]" << i)))); + updateOp.setC(BSON("x" << i)); + updateOp.setArrayFilters(std::vector<BSONObj>{BSON("element" << BSON("$gt" << i))}); + updateOp.setMulti(_getRandomBool()); + updateOp.setUpsert(_getRandomBool()); + updateOp.setUpsertSupplied(_getRandomBool()); + updateOp.setCollation(makeNonEmptyCollation()); + + if (markForSampling.find(i) != markForSampling.end()) { + auto sampleId = UUID::gen(); + updateOp.setSampleId(sampleId); + + write_ops::UpdateCommandRequest expectedSampledCmd = originalCmd; + expectedSampledCmd.setUpdates({updateOp}); + expectedSampledCmd.getWriteCommandRequestBase().setEncryptionInformation( + boost::none); + expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd)); + } + updateOps.push_back(updateOp); + } + originalCmd.setUpdates(updateOps); + + return {originalCmd, expectedSampledCmds}; + } + + /* + * Makes an DeleteCommandRequest for the collection 'nss' such that the command contains + * 'numOps' deletes and the ones whose indices are in 'markForSampling' are marked for sampling. + * Returns the DeleteCommandRequest and the map storing the expected sampled + * DeleteCommandRequests by sample id, if any. + */ + std::pair<write_ops::DeleteCommandRequest, std::map<UUID, write_ops::DeleteCommandRequest>> + makeDeleteCommandRequest(const NamespaceString& nss, + int numOps, + std::set<int> markForSampling, + std::string filterFieldName = "a") { + write_ops::DeleteCommandRequest originalCmd(nss); + std::vector<write_ops::DeleteOpEntry> deleteOps; // populated and set below. + originalCmd.setLet(let); + originalCmd.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation); + + std::map<UUID, write_ops::DeleteCommandRequest> expectedSampledCmds; + + for (auto i = 0; i < numOps; i++) { + auto deleteOp = + write_ops::DeleteOpEntry(BSON(filterFieldName << i), _getRandomBool() /* multi */); + deleteOp.setCollation(makeNonEmptyCollation()); + + if (markForSampling.find(i) != markForSampling.end()) { + auto sampleId = UUID::gen(); + deleteOp.setSampleId(sampleId); + + write_ops::DeleteCommandRequest expectedSampledCmd = originalCmd; + expectedSampledCmd.setDeletes({deleteOp}); + expectedSampledCmd.getWriteCommandRequestBase().setEncryptionInformation( + boost::none); + expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd)); + } + deleteOps.push_back(deleteOp); + } + originalCmd.setDeletes(deleteOps); + + return {originalCmd, expectedSampledCmds}; + } + + /* + * Makes a FindAndModifyCommandRequest for the collection 'nss'. The findAndModify is an update + * if 'isUpdate' is true, and a remove otherwise. If 'markForSampling' is true, it is marked for + * sampling. 
Returns the FindAndModifyCommandRequest and the map storing the expected sampled + * FindAndModifyCommandRequests by sample id, if any. + */ + std::pair<write_ops::FindAndModifyCommandRequest, + std::map<UUID, write_ops::FindAndModifyCommandRequest>> + makeFindAndModifyCommandRequest(const NamespaceString& nss, + bool isUpdate, + bool markForSampling, + std::string filterFieldName = "a") { + write_ops::FindAndModifyCommandRequest originalCmd(nss); + originalCmd.setQuery(BSON(filterFieldName << 0)); + originalCmd.setUpdate( + write_ops::UpdateModification(BSON("$set" << BSON("b.$[element]" << 0)))); + originalCmd.setArrayFilters(std::vector<BSONObj>{BSON("element" << BSON("$gt" << 10))}); + originalCmd.setSort(BSON("_id" << 1)); + if (isUpdate) { + originalCmd.setUpsert(_getRandomBool()); + originalCmd.setNew(_getRandomBool()); + } + originalCmd.setCollation(makeNonEmptyCollation()); + originalCmd.setLet(let); + originalCmd.setEncryptionInformation(encryptionInformation); + + std::map<UUID, write_ops::FindAndModifyCommandRequest> expectedSampledCmds; + if (markForSampling) { + auto sampleId = UUID::gen(); + originalCmd.setSampleId(sampleId); + + auto expectedSampledCmd = originalCmd; + expectedSampledCmd.setEncryptionInformation(boost::none); + expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd)); + } + + return {originalCmd, expectedSampledCmds}; + } + + void deleteSampledQueryDocuments() const { + DBDirectClient client(operationContext()); + client.remove(NamespaceString::kConfigSampledQueriesNamespace.toString(), BSONObj()); + } + + /** + * Returns the number of the documents for the collection 'nss' in the config.sampledQueries + * collection. + */ + int getSampledQueryDocumentsCount(const NamespaceString& nss) { + return _getConfigDocumentsCount(NamespaceString::kConfigSampledQueriesNamespace, nss); + } + + /* + * Asserts that there is a sampled read query document with the given sample id and that it has + * the given fields. + */ + void assertSampledReadQueryDocument(const UUID& sampleId, + const NamespaceString& nss, + SampledReadCommandNameEnum cmdName, + const BSONObj& filter, + const BSONObj& collation) { + auto doc = _getConfigDocument(NamespaceString::kConfigSampledQueriesNamespace, sampleId); + auto parsedQueryDoc = + SampledReadQueryDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc); + + ASSERT_EQ(parsedQueryDoc.getNs(), nss); + ASSERT_EQ(parsedQueryDoc.getCollectionUuid(), getCollectionUUID(nss)); + ASSERT_EQ(parsedQueryDoc.getSampleId(), sampleId); + ASSERT(parsedQueryDoc.getCmdName() == cmdName); + auto parsedCmd = SampledReadCommand::parse(IDLParserContext("QueryAnalysisWriterTest"), + parsedQueryDoc.getCmd()); + ASSERT_BSONOBJ_EQ(parsedCmd.getFilter(), filter); + ASSERT_BSONOBJ_EQ(parsedCmd.getCollation(), collation); + } + + /* + * Asserts that there is a sampled write query document with the given sample id and that it has + * the given fields. 
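+ * The persisted command is parsed back into 'CommandRequestType' and compared against the expected command via toBSON().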
+ */ + template <typename CommandRequestType> + void assertSampledWriteQueryDocument(const UUID& sampleId, + const NamespaceString& nss, + SampledWriteCommandNameEnum cmdName, + const CommandRequestType& expectedCmd) { + auto doc = _getConfigDocument(NamespaceString::kConfigSampledQueriesNamespace, sampleId); + auto parsedQueryDoc = + SampledWriteQueryDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc); + + ASSERT_EQ(parsedQueryDoc.getNs(), nss); + ASSERT_EQ(parsedQueryDoc.getCollectionUuid(), getCollectionUUID(nss)); + ASSERT_EQ(parsedQueryDoc.getSampleId(), sampleId); + ASSERT(parsedQueryDoc.getCmdName() == cmdName); + auto parsedCmd = CommandRequestType::parse(IDLParserContext("QueryAnalysisWriterTest"), + parsedQueryDoc.getCmd()); + ASSERT_BSONOBJ_EQ(parsedCmd.toBSON({}), expectedCmd.toBSON({})); + } + + /* + * Returns the number of the documents for the collection 'nss' in the config.sampledQueriesDiff + * collection. + */ + int getDiffDocumentsCount(const NamespaceString& nss) { + return _getConfigDocumentsCount(NamespaceString::kConfigSampledQueriesDiffNamespace, nss); + } + + /* + * Asserts that there is a sampled diff document with the given sample id and that it has + * the given fields. + */ + void assertDiffDocument(const UUID& sampleId, + const NamespaceString& nss, + const BSONObj& expectedDiff) { + auto doc = + _getConfigDocument(NamespaceString::kConfigSampledQueriesDiffNamespace, sampleId); + auto parsedDiffDoc = + SampledQueryDiffDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc); + + ASSERT_EQ(parsedDiffDoc.getNs(), nss); + ASSERT_EQ(parsedDiffDoc.getCollectionUuid(), getCollectionUUID(nss)); + ASSERT_EQ(parsedDiffDoc.getSampleId(), sampleId); + assertBsonObjEqualUnordered(parsedDiffDoc.getDiff(), expectedDiff); + } + + const NamespaceString nss0{"testDb", "testColl0"}; + const NamespaceString nss1{"testDb", "testColl1"}; + + // Test with both empty and non-empty filter and collation to verify that the + // QueryAnalysisWriter doesn't require filter or collation to be non-empty. + const BSONObj emptyFilter{}; + const BSONObj emptyCollation{}; + + const BSONObj let = BSON("x" << 1); + // Test with EncryptionInformation to verify that QueryAnalysisWriter does not persist the + // WriteCommandRequestBase fields, especially this sensitive field. + const EncryptionInformation encryptionInformation{BSON("foo" + << "bar")}; + +private: + bool _getRandomBool() { + return rand() % 2 == 0; + } + + /** + * Returns the number of the documents for the collection 'collNss' in the config collection + * 'configNss'. + */ + int _getConfigDocumentsCount(const NamespaceString& configNss, + const NamespaceString& collNss) const { + DBDirectClient client(operationContext()); + return client.count(configNss, BSON("ns" << collNss.toString())); + } + + /** + * Returns the document with the given _id in the config collection 'configNss'. 
+ */ + BSONObj _getConfigDocument(const NamespaceString configNss, const UUID& id) const { + DBDirectClient client(operationContext()); + + FindCommandRequest findRequest{configNss}; + findRequest.setFilter(BSON("_id" << id)); + auto cursor = client.find(std::move(findRequest)); + ASSERT(cursor->more()); + return cursor->next(); + } + + RAIIServerParameterControllerForTest _featureFlagController{"featureFlagAnalyzeShardKey", true}; + FailPointEnableBlock _fp{"disableQueryAnalysisWriter"}; +}; + +DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetIfFeatureFlagNotEnabled, "invariant") { + RAIIServerParameterControllerForTest _featureFlagController{"featureFlagAnalyzeShardKey", + false}; + QueryAnalysisWriter::get(operationContext()); +} + +DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetOnConfigServer, "invariant") { + serverGlobalParams.clusterRole = ClusterRole::ConfigServer; + QueryAnalysisWriter::get(operationContext()); +} + +DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetOnNonShardServer, "invariant") { + serverGlobalParams.clusterRole = ClusterRole::None; + QueryAnalysisWriter::get(operationContext()); +} + +TEST_F(QueryAnalysisWriterTest, NoQueries) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + writer.flushQueriesForTest(operationContext()); +} + +TEST_F(QueryAnalysisWriterTest, FindQuery) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto testFindCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) { + auto sampleId = UUID::gen(); + + writer.addFindQuery(sampleId, nss0, filter, collation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kFind, filter, collation); + + deleteSampledQueryDocuments(); + }; + + testFindCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation()); + testFindCmdCommon(makeNonEmptyFilter(), emptyCollation); + testFindCmdCommon(emptyFilter, makeNonEmptyCollation()); + testFindCmdCommon(emptyFilter, emptyCollation); +} + +TEST_F(QueryAnalysisWriterTest, CountQuery) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto testCountCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) { + auto sampleId = UUID::gen(); + + writer.addCountQuery(sampleId, nss0, filter, collation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kCount, filter, collation); + + deleteSampledQueryDocuments(); + }; + + testCountCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation()); + testCountCmdCommon(makeNonEmptyFilter(), emptyCollation); + testCountCmdCommon(emptyFilter, makeNonEmptyCollation()); + testCountCmdCommon(emptyFilter, emptyCollation); +} + +TEST_F(QueryAnalysisWriterTest, DistinctQuery) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto testDistinctCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) { + auto sampleId = UUID::gen(); + + writer.addDistinctQuery(sampleId, nss0, filter, collation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + 
ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kDistinct, filter, collation); + + deleteSampledQueryDocuments(); + }; + + testDistinctCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation()); + testDistinctCmdCommon(makeNonEmptyFilter(), emptyCollation); + testDistinctCmdCommon(emptyFilter, makeNonEmptyCollation()); + testDistinctCmdCommon(emptyFilter, emptyCollation); +} + +TEST_F(QueryAnalysisWriterTest, AggregateQuery) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto testAggregateCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) { + auto sampleId = UUID::gen(); + + writer.addAggregateQuery(sampleId, nss0, filter, collation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation); + + deleteSampledQueryDocuments(); + }; + + testAggregateCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation()); + testAggregateCmdCommon(makeNonEmptyFilter(), emptyCollation); + testAggregateCmdCommon(emptyFilter, makeNonEmptyCollation()); + testAggregateCmdCommon(emptyFilter, emptyCollation); +} + +DEATH_TEST_F(QueryAnalysisWriterTest, UpdateQueryNotMarkedForSampling, "invariant") { + auto& writer = QueryAnalysisWriter::get(operationContext()); + auto [originalCmd, _] = makeUpdateCommandRequest(nss0, 1, {} /* markForSampling */); + writer.addUpdateQuery(originalCmd, 0).get(); +} + +TEST_F(QueryAnalysisWriterTest, UpdateQueriesMarkedForSampling) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto [originalCmd, expectedSampledCmds] = + makeUpdateCommandRequest(nss0, 3, {0, 2} /* markForSampling */); + ASSERT_EQ(expectedSampledCmds.size(), 2U); + + writer.addUpdateQuery(originalCmd, 0).get(); + writer.addUpdateQuery(originalCmd, 2).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 2); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2); + for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) { + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kUpdate, + expectedSampledCmd); + } +} + +DEATH_TEST_F(QueryAnalysisWriterTest, DeleteQueryNotMarkedForSampling, "invariant") { + auto& writer = QueryAnalysisWriter::get(operationContext()); + auto [originalCmd, _] = makeDeleteCommandRequest(nss0, 1, {} /* markForSampling */); + writer.addDeleteQuery(originalCmd, 0).get(); +} + +TEST_F(QueryAnalysisWriterTest, DeleteQueriesMarkedForSampling) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto [originalCmd, expectedSampledCmds] = + makeDeleteCommandRequest(nss0, 3, {1, 2} /* markForSampling */); + ASSERT_EQ(expectedSampledCmds.size(), 2U); + + writer.addDeleteQuery(originalCmd, 1).get(); + writer.addDeleteQuery(originalCmd, 2).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 2); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2); + for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) { + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + 
SampledWriteCommandNameEnum::kDelete, + expectedSampledCmd); + } +} + +DEATH_TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryNotMarkedForSampling, "invariant") { + auto& writer = QueryAnalysisWriter::get(operationContext()); + auto [originalCmd, _] = + makeFindAndModifyCommandRequest(nss0, true /* isUpdate */, false /* markForSampling */); + writer.addFindAndModifyQuery(originalCmd).get(); +} + +TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryUpdateMarkedForSampling) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto [originalCmd, expectedSampledCmds] = + makeFindAndModifyCommandRequest(nss0, true /* isUpdate */, true /* markForSampling */); + ASSERT_EQ(expectedSampledCmds.size(), 1U); + auto [sampleId, expectedSampledCmd] = *expectedSampledCmds.begin(); + + writer.addFindAndModifyQuery(originalCmd).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kFindAndModify, + expectedSampledCmd); +} + +TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryRemoveMarkedForSampling) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto [originalCmd, expectedSampledCmds] = + makeFindAndModifyCommandRequest(nss0, false /* isUpdate */, true /* markForSampling */); + ASSERT_EQ(expectedSampledCmds.size(), 1U); + auto [sampleId, expectedSampledCmd] = *expectedSampledCmds.begin(); + + writer.addFindAndModifyQuery(originalCmd).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kFindAndModify, + expectedSampledCmd); +} + +TEST_F(QueryAnalysisWriterTest, MultipleQueriesAndCollections) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + // Make nss0 have one query. + auto [originalDeleteCmd, expectedSampledDeleteCmds] = + makeDeleteCommandRequest(nss1, 3, {1} /* markForSampling */); + ASSERT_EQ(expectedSampledDeleteCmds.size(), 1U); + auto [deleteSampleId, expectedSampledDeleteCmd] = *expectedSampledDeleteCmds.begin(); + + // Make nss1 have two queries. 
+ auto [originalUpdateCmd, expectedSampledUpdateCmds] = + makeUpdateCommandRequest(nss0, 1, {0} /* markForSampling */); + ASSERT_EQ(expectedSampledUpdateCmds.size(), 1U); + auto [updateSampleId, expectedSampledUpdateCmd] = *expectedSampledUpdateCmds.begin(); + + auto countSampleId = UUID::gen(); + auto originalCountFilter = makeNonEmptyFilter(); + auto originalCountCollation = makeNonEmptyCollation(); + + writer.addDeleteQuery(originalDeleteCmd, 1).get(); + writer.addUpdateQuery(originalUpdateCmd, 0).get(); + writer.addCountQuery(countSampleId, nss1, originalCountFilter, originalCountCollation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 3); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledWriteQueryDocument(deleteSampleId, + expectedSampledDeleteCmd.getNamespace(), + SampledWriteCommandNameEnum::kDelete, + expectedSampledDeleteCmd); + ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 2); + assertSampledWriteQueryDocument(updateSampleId, + expectedSampledUpdateCmd.getNamespace(), + SampledWriteCommandNameEnum::kUpdate, + expectedSampledUpdateCmd); + assertSampledReadQueryDocument(countSampleId, + nss1, + SampledReadCommandNameEnum::kCount, + originalCountFilter, + originalCountCollation); +} + +TEST_F(QueryAnalysisWriterTest, DuplicateQueries) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto findSampleId = UUID::gen(); + auto originalFindFilter = makeNonEmptyFilter(); + auto originalFindCollation = makeNonEmptyCollation(); + + auto [originalUpdateCmd, expectedSampledUpdateCmds] = + makeUpdateCommandRequest(nss0, 1, {0} /* markForSampling */); + ASSERT_EQ(expectedSampledUpdateCmds.size(), 1U); + auto [updateSampleId, expectedSampledUpdateCmd] = *expectedSampledUpdateCmds.begin(); + + auto countSampleId = UUID::gen(); + auto originalCountFilter = makeNonEmptyFilter(); + auto originalCountCollation = makeNonEmptyCollation(); + + writer.addFindQuery(findSampleId, nss0, originalFindFilter, originalFindCollation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument(findSampleId, + nss0, + SampledReadCommandNameEnum::kFind, + originalFindFilter, + originalFindCollation); + + writer.addUpdateQuery(originalUpdateCmd, 0).get(); + writer.addFindQuery(findSampleId, nss0, originalFindFilter, originalFindCollation) + .get(); // This is a duplicate. 
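+ // The duplicate produces a document with the same _id (i.e. sample id), so the flush below ignores it as a DuplicateKey error.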
+ writer.addCountQuery(countSampleId, nss0, originalCountFilter, originalCountCollation).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 3); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 3); + assertSampledWriteQueryDocument(updateSampleId, + expectedSampledUpdateCmd.getNamespace(), + SampledWriteCommandNameEnum::kUpdate, + expectedSampledUpdateCmd); + assertSampledReadQueryDocument(findSampleId, + nss0, + SampledReadCommandNameEnum::kFind, + originalFindFilter, + originalFindCollation); + assertSampledReadQueryDocument(countSampleId, + nss0, + SampledReadCommandNameEnum::kCount, + originalCountFilter, + originalCountCollation); +} + +TEST_F(QueryAnalysisWriterTest, QueriesMultipleBatches_MaxBatchSize) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 2}; + auto numQueries = 5; + + std::vector<std::tuple<UUID, BSONObj, BSONObj>> expectedSampledCmds; + for (auto i = 0; i < numQueries; i++) { + auto sampleId = UUID::gen(); + auto filter = makeNonEmptyFilter(); + auto collation = makeNonEmptyCollation(); + writer.addAggregateQuery(sampleId, nss0, filter, collation).get(); + expectedSampledCmds.push_back({sampleId, filter, collation}); + } + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries); + for (const auto& [sampleId, filter, collation] : expectedSampledCmds) { + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation); + } +} + +TEST_F(QueryAnalysisWriterTest, QueriesMultipleBatches_MaxBSONObjSize) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto numQueries = 3; + std::vector<std::tuple<UUID, BSONObj, BSONObj>> expectedSampledCmds; + for (auto i = 0; i < numQueries; i++) { + auto sampleId = UUID::gen(); + auto filter = BSON(std::string(BSONObjMaxUserSize / 2, 'a') << 1); + auto collation = makeNonEmptyCollation(); + writer.addAggregateQuery(sampleId, nss0, filter, collation).get(); + expectedSampledCmds.push_back({sampleId, filter, collation}); + } + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries); + for (const auto& [sampleId, filter, collation] : expectedSampledCmds) { + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation); + } +} + +TEST_F(QueryAnalysisWriterTest, FlushAfterAddReadIfExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto maxMemoryUsageBytes = 1024; + RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes", + maxMemoryUsageBytes}; + + auto sampleId0 = UUID::gen(); + auto filter0 = BSON(std::string(maxMemoryUsageBytes / 2, 'a') << 1); + auto collation0 = makeNonEmptyCollation(); + + auto sampleId1 = UUID::gen(); + auto filter1 = BSON(std::string(maxMemoryUsageBytes / 2, 'b') << 1); + auto collation1 = makeNonEmptyCollation(); + + writer.addFindQuery(sampleId0, nss0, filter0, collation0).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + // Adding the next query causes the size to exceed the limit. 
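+ // The writer then flushes both buffered queries immediately, so the in-memory count drops to zero.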
+ writer.addAggregateQuery(sampleId1, nss1, filter1, collation1).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledReadQueryDocument( + sampleId0, nss0, SampledReadCommandNameEnum::kFind, filter0, collation0); + ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 1); + assertSampledReadQueryDocument( + sampleId1, nss1, SampledReadCommandNameEnum::kAggregate, filter1, collation1); +} + +TEST_F(QueryAnalysisWriterTest, FlushAfterAddUpdateIfExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto maxMemoryUsageBytes = 1024; + RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes", + maxMemoryUsageBytes}; + auto [originalCmd, expectedSampledCmds] = + makeUpdateCommandRequest(nss0, + 3, + {0, 2} /* markForSampling */, + std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */); + ASSERT_EQ(expectedSampledCmds.size(), 2U); + + writer.addUpdateQuery(originalCmd, 0).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + // Adding the next query causes the size to exceed the limit. + writer.addUpdateQuery(originalCmd, 2).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2); + for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) { + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kUpdate, + expectedSampledCmd); + } +} + +TEST_F(QueryAnalysisWriterTest, FlushAfterAddDeleteIfExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto maxMemoryUsageBytes = 1024; + RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes", + maxMemoryUsageBytes}; + auto [originalCmd, expectedSampledCmds] = + makeDeleteCommandRequest(nss0, + 3, + {0, 1} /* markForSampling */, + std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */); + ASSERT_EQ(expectedSampledCmds.size(), 2U); + + writer.addDeleteQuery(originalCmd, 0).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + // Adding the next query causes the size to exceed the limit. 
+ writer.addDeleteQuery(originalCmd, 1).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2); + for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) { + assertSampledWriteQueryDocument(sampleId, + expectedSampledCmd.getNamespace(), + SampledWriteCommandNameEnum::kDelete, + expectedSampledCmd); + } +} + +TEST_F(QueryAnalysisWriterTest, FlushAfterAddFindAndModifyIfExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto maxMemoryUsageBytes = 1024; + RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes", + maxMemoryUsageBytes}; + + auto [originalCmd0, expectedSampledCmds0] = makeFindAndModifyCommandRequest( + nss0, + true /* isUpdate */, + true /* markForSampling */, + std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */); + ASSERT_EQ(expectedSampledCmds0.size(), 1U); + auto [sampleId0, expectedSampledCmd0] = *expectedSampledCmds0.begin(); + + auto [originalCmd1, expectedSampledCmds1] = makeFindAndModifyCommandRequest( + nss1, + false /* isUpdate */, + true /* markForSampling */, + std::string(maxMemoryUsageBytes / 2, 'b') /* filterFieldName */); + ASSERT_EQ(expectedSampledCmds0.size(), 1U); + auto [sampleId1, expectedSampledCmd1] = *expectedSampledCmds1.begin(); + + writer.addFindAndModifyQuery(originalCmd0).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 1); + // Adding the next query causes the size to exceed the limit. + writer.addFindAndModifyQuery(originalCmd1).get(); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1); + assertSampledWriteQueryDocument(sampleId0, + expectedSampledCmd0.getNamespace(), + SampledWriteCommandNameEnum::kFindAndModify, + expectedSampledCmd0); + ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 1); + assertSampledWriteQueryDocument(sampleId1, + expectedSampledCmd1.getNamespace(), + SampledWriteCommandNameEnum::kFindAndModify, + expectedSampledCmd1); +} + +TEST_F(QueryAnalysisWriterTest, AddQueriesBackAfterWriteError) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto originalFilter = makeNonEmptyFilter(); + auto originalCollation = makeNonEmptyCollation(); + auto numQueries = 8; + + std::vector<UUID> sampleIds0; + for (auto i = 0; i < numQueries; i++) { + sampleIds0.push_back(UUID::gen()); + writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get(); + } + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries); + + // Force the documents to get inserted in three batches of size 3, 3 and 2, respectively. + RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 3}; + + // Hang after inserting the documents in the first batch. + auto hangFp = globalFailPointRegistry().find("hangAfterCollectionInserts"); + auto hangTimesEntered = hangFp->setMode(FailPoint::alwaysOn, 0); + + auto future = stdx::async(stdx::launch::async, [&] { + ThreadClient tc(getServiceContext()); + auto opCtx = makeOperationContext(); + writer.flushQueriesForTest(opCtx.get()); + }); + + hangFp->waitForTimesEntered(hangTimesEntered + 1); + // Force the second batch to fail so that it falls back to inserting one document at a time in + // order, and then force the first and second document in the batch to fail. 
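+ // (Three induced failures in total: the batched insert, then the first two single-document retries.)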
+ auto failFp = globalFailPointRegistry().find("failCollectionInserts"); + failFp->setMode(FailPoint::nTimes, 3); + hangFp->setMode(FailPoint::off, 0); + + future.get(); + // Verify that all the documents other than the ones in the first batch got added back to the + // buffer after the error. That is, the error caused the last document in the second batch to + // get added to buffer also although it was successfully inserted since the writer did not have + // a way to tell if the error caused the entire command to fail early. + ASSERT_EQ(writer.getQueriesCountForTest(), 5); + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 4); + + // Flush that remaining documents. If the documents were not added back correctly, some + // documents would be missing and the checks below would fail. + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries); + for (const auto& sampleId : sampleIds0) { + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation); + } +} + +TEST_F(QueryAnalysisWriterTest, RemoveDuplicatesFromBufferAfterWriteError) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto originalFilter = makeNonEmptyFilter(); + auto originalCollation = makeNonEmptyCollation(); + + auto numQueries0 = 3; + + std::vector<UUID> sampleIds0; + for (auto i = 0; i < numQueries0; i++) { + sampleIds0.push_back(UUID::gen()); + writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get(); + } + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries0); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries0); + for (const auto& sampleId : sampleIds0) { + assertSampledReadQueryDocument( + sampleId, nss0, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation); + } + + auto numQueries1 = 5; + + std::vector<UUID> sampleIds1; + for (auto i = 0; i < numQueries1; i++) { + sampleIds1.push_back(UUID::gen()); + writer.addFindQuery(sampleIds1[i], nss1, originalFilter, originalCollation).get(); + // This is a duplicate. + if (i < numQueries0) { + writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get(); + } + } + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries0 + numQueries1); + + // Force the batch to fail so that it falls back to inserting one document at a time in order. + auto failFp = globalFailPointRegistry().find("failCollectionInserts"); + failFp->setMode(FailPoint::nTimes, 1); + + // Hang after inserting the first non-duplicate document. + auto hangFp = globalFailPointRegistry().find("hangAfterCollectionInserts"); + auto hangTimesEntered = hangFp->setMode(FailPoint::alwaysOn, 0); + + auto future = stdx::async(stdx::launch::async, [&] { + ThreadClient tc(getServiceContext()); + auto opCtx = makeOperationContext(); + writer.flushQueriesForTest(opCtx.get()); + }); + + hangFp->waitForTimesEntered(hangTimesEntered + 1); + // Force the next non-duplicate document to fail to insert. + failFp->setMode(FailPoint::nTimes, 1); + hangFp->setMode(FailPoint::off, 0); + + future.get(); + // Verify that the duplicate documents did not get added back to the buffer after the error. + ASSERT_EQ(writer.getQueriesCountForTest(), numQueries1); + ASSERT_EQ(getSampledQueryDocumentsCount(nss1), numQueries1 - 1); + + // Flush that remaining documents. 
If the documents were not added back correctly, the document + // that previously failed to insert would be missing and the checks below would fail. + failFp->setMode(FailPoint::off, 0); + writer.flushQueriesForTest(operationContext()); + ASSERT_EQ(writer.getQueriesCountForTest(), 0); + + ASSERT_EQ(getSampledQueryDocumentsCount(nss1), numQueries1); + for (const auto& sampleId : sampleIds1) { + assertSampledReadQueryDocument( + sampleId, nss1, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation); + } +} + +TEST_F(QueryAnalysisWriterTest, NoDiffs) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + writer.flushQueriesForTest(operationContext()); +} + +TEST_F(QueryAnalysisWriterTest, DiffsBasic) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto collUuid0 = getCollectionUUID(nss0); + auto sampleId = UUID::gen(); + auto preImage = BSON("a" << 0); + auto postImage = BSON("a" << 1); + + writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 1); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 1); + assertDiffDocument(sampleId, nss0, *doc_diff::computeInlineDiff(preImage, postImage)); +} + +TEST_F(QueryAnalysisWriterTest, DiffsMultipleQueriesAndCollections) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + // Make nss0 have a diff for one query. + auto collUuid0 = getCollectionUUID(nss0); + + auto sampleId0 = UUID::gen(); + auto preImage0 = BSON("a" << 0 << "b" << 0 << "c" << 0); + auto postImage0 = BSON("a" << 0 << "b" << 1 << "d" << 1); + + // Make nss1 have diffs for two queries. + auto collUuid1 = getCollectionUUID(nss1); + + auto sampleId1 = UUID::gen(); + auto preImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1) << "d" << BSON("e" << 1)); + auto postImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1 << 2) << "d" << BSON("e" << 2)); + + auto sampleId2 = UUID::gen(); + auto preImage2 = BSON("a" << BSONObj()); + auto postImage2 = BSON("a" << BSON("b" << 2)); + + writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get(); + writer.addDiff(sampleId1, nss1, collUuid1, preImage1, postImage1).get(); + writer.addDiff(sampleId2, nss1, collUuid1, preImage2, postImage2).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 3); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 1); + assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0)); + + ASSERT_EQ(getDiffDocumentsCount(nss1), 2); + assertDiffDocument(sampleId1, nss1, *doc_diff::computeInlineDiff(preImage1, postImage1)); + assertDiffDocument(sampleId2, nss1, *doc_diff::computeInlineDiff(preImage2, postImage2)); +} + +TEST_F(QueryAnalysisWriterTest, DuplicateDiffs) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto collUuid0 = getCollectionUUID(nss0); + + auto sampleId0 = UUID::gen(); + auto preImage0 = BSON("a" << 0); + auto postImage0 = BSON("a" << 1); + + auto sampleId1 = UUID::gen(); + auto preImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1)); + auto postImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1 << 2)); + + auto sampleId2 = UUID::gen(); + auto preImage2 = BSON("a" << BSONObj()); + auto postImage2 = BSON("a" << BSON("b" << 2)); + + writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 1); + 
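The ...IfExceedsSizeLimit tests above and below all pin down the same buffering policy: sampled entries accumulate in memory, and the add that pushes the buffered byte count past queryAnalysisWriterMaxMemoryUsageBytes flushes everything, including the entry that crossed the limit. A self-contained sketch of that policy in plain C++ (the class name, string payloads, and flush callback are illustrative stand-ins, not the server's types):

#include <cstddef>
#include <functional>
#include <string>
#include <utility>
#include <vector>

// Flushes the whole buffer as soon as an add() pushes it over the byte budget.
class SizeBoundedBuffer {
public:
    SizeBoundedBuffer(std::size_t maxBytes, std::function<void(std::vector<std::string>)> flushFn)
        : _maxBytes(maxBytes), _flushFn(std::move(flushFn)) {}

    void add(std::string entry) {
        _bytes += entry.size();
        _entries.push_back(std::move(entry));
        if (_bytes > _maxBytes) {  // the add that crosses the limit triggers the flush
            flush();
        }
    }

    void flush() {
        if (_entries.empty())
            return;
        _flushFn(std::move(_entries));
        _entries.clear();
        _bytes = 0;
    }

    std::size_t count() const {
        return _entries.size();
    }

private:
    std::size_t _maxBytes;
    std::size_t _bytes = 0;
    std::vector<std::string> _entries;
    std::function<void(std::vector<std::string>)> _flushFn;
};

With a 1024-byte budget and two roughly 512-byte entries, the first add() leaves one buffered entry and the second flushes both, which is the count-goes-1-then-0 sequence these tests assert.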
writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 1); + assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0)); + + writer.addDiff(sampleId1, nss0, collUuid0, preImage1, postImage1).get(); + writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0) + .get(); // This is a duplicate. + writer.addDiff(sampleId2, nss0, collUuid0, preImage2, postImage2).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 3); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 3); + assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0)); + assertDiffDocument(sampleId1, nss0, *doc_diff::computeInlineDiff(preImage1, postImage1)); + assertDiffDocument(sampleId2, nss0, *doc_diff::computeInlineDiff(preImage2, postImage2)); +} + +TEST_F(QueryAnalysisWriterTest, DiffsMultipleBatches_MaxBatchSize) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 2}; + auto numDiffs = 5; + auto collUuid0 = getCollectionUUID(nss0); + + std::vector<std::pair<UUID, BSONObj>> expectedSampledDiffs; + for (auto i = 0; i < numDiffs; i++) { + auto sampleId = UUID::gen(); + auto preImage = BSON("a" << 0); + auto postImage = BSON(("a" + std::to_string(i)) << 1); + writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get(); + expectedSampledDiffs.push_back( + {sampleId, *doc_diff::computeInlineDiff(preImage, postImage)}); + } + ASSERT_EQ(writer.getDiffsCountForTest(), numDiffs); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), numDiffs); + for (const auto& [sampleId, diff] : expectedSampledDiffs) { + assertDiffDocument(sampleId, nss0, diff); + } +} + +TEST_F(QueryAnalysisWriterTest, DiffsMultipleBatches_MaxBSONObjSize) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto numDiffs = 3; + auto collUuid0 = getCollectionUUID(nss0); + + std::vector<std::pair<UUID, BSONObj>> expectedSampledDiffs; + for (auto i = 0; i < numDiffs; i++) { + auto sampleId = UUID::gen(); + auto preImage = BSON("a" << 0); + auto postImage = BSON(std::string(BSONObjMaxUserSize / 2, 'a') << 1); + writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get(); + expectedSampledDiffs.push_back( + {sampleId, *doc_diff::computeInlineDiff(preImage, postImage)}); + } + ASSERT_EQ(writer.getDiffsCountForTest(), numDiffs); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), numDiffs); + for (const auto& [sampleId, diff] : expectedSampledDiffs) { + assertDiffDocument(sampleId, nss0, diff); + } +} + +TEST_F(QueryAnalysisWriterTest, FlushAfterAddDiffIfExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto maxMemoryUsageBytes = 1024; + RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes", + maxMemoryUsageBytes}; + + auto collUuid0 = getCollectionUUID(nss0); + auto sampleId0 = UUID::gen(); + auto preImage0 = BSON("a" << 0); + auto postImage0 = BSON(std::string(maxMemoryUsageBytes / 2, 'a') << 1); + + auto collUuid1 = getCollectionUUID(nss1); + auto sampleId1 = UUID::gen(); + auto preImage1 = BSON("a" << 0); + auto postImage1 = BSON(std::string(maxMemoryUsageBytes / 
2, 'b') << 1); + + writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 1); + // Adding the next diff causes the size to exceed the limit. + writer.addDiff(sampleId1, nss1, collUuid1, preImage1, postImage1).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 1); + assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0)); + ASSERT_EQ(getDiffDocumentsCount(nss1), 1); + assertDiffDocument(sampleId1, nss1, *doc_diff::computeInlineDiff(preImage1, postImage1)); +} + +TEST_F(QueryAnalysisWriterTest, DiffEmpty) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto collUuid0 = getCollectionUUID(nss0); + auto sampleId = UUID::gen(); + auto preImage = BSON("a" << 1); + auto postImage = preImage; + + writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 0); +} + +TEST_F(QueryAnalysisWriterTest, DiffExceedsSizeLimit) { + auto& writer = QueryAnalysisWriter::get(operationContext()); + + auto collUuid0 = getCollectionUUID(nss0); + auto sampleId = UUID::gen(); + auto preImage = BSON(std::string(BSONObjMaxUserSize, 'a') << 1); + auto postImage = BSONObj(); + + writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get(); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + writer.flushDiffsForTest(operationContext()); + ASSERT_EQ(writer.getDiffsCountForTest(), 0); + + ASSERT_EQ(getDiffDocumentsCount(nss0), 0); +} + +} // namespace +} // namespace analyze_shard_key +} // namespace mongo diff --git a/src/mongo/db/s/range_deleter_service.cpp b/src/mongo/db/s/range_deleter_service.cpp index acabcc40bea..ca80b24aa3e 100644 --- a/src/mongo/db/s/range_deleter_service.cpp +++ b/src/mongo/db/s/range_deleter_service.cpp @@ -86,6 +86,50 @@ RangeDeleterService* RangeDeleterService::get(OperationContext* opCtx) { return get(opCtx->getServiceContext()); } +RangeDeleterService::ReadyRangeDeletionsProcessor::ReadyRangeDeletionsProcessor( + OperationContext* opCtx) + : _thread([this] { _runRangeDeletions(); }) {} + +RangeDeleterService::ReadyRangeDeletionsProcessor::~ReadyRangeDeletionsProcessor() { + shutdown(); + invariant(_thread.joinable()); + _thread.join(); + invariant(!_threadOpCtxHolder, + "Thread operation context is still alive after joining main thread"); +} + +void RangeDeleterService::ReadyRangeDeletionsProcessor::shutdown() { + stdx::lock_guard<Latch> lock(_mutex); + if (_state == kStopped) + return; + + _state = kStopped; + + if (_threadOpCtxHolder) { + stdx::lock_guard<Client> scopedClientLock(*_threadOpCtxHolder->getClient()); + _threadOpCtxHolder->markKilled(ErrorCodes::Interrupted); + } +} + +bool RangeDeleterService::ReadyRangeDeletionsProcessor::_stopRequested() const { + stdx::unique_lock<Latch> lock(_mutex); + return _state == kStopped; +} + +void RangeDeleterService::ReadyRangeDeletionsProcessor::emplaceRangeDeletion( + const RangeDeletionTask& rdt) { + stdx::unique_lock<Latch> lock(_mutex); + invariant(_state == kRunning); + _queue.push(rdt); + _condVar.notify_all(); +} + +void RangeDeleterService::ReadyRangeDeletionsProcessor::_completedRangeDeletion() { + stdx::unique_lock<Latch> lock(_mutex); + dassert(!_queue.empty()); + _queue.pop(); +} + void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() { 
Client::initThread(kRangeDeletionThreadName); { @@ -93,21 +137,22 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() { cc().setSystemOperationKillableByStepdown(lk); } - auto opCtx = [&] { - stdx::unique_lock<Latch> lock(_mutex); + { + stdx::lock_guard<Latch> lock(_mutex); + if (_state != kRunning) { + return; + } _threadOpCtxHolder = cc().makeOperationContext(); - _condVar.notify_all(); - return (*_threadOpCtxHolder).get(); - }(); + } - opCtx->setAlwaysInterruptAtStepDownOrUp_UNSAFE(); + auto opCtx = _threadOpCtxHolder.get(); ON_BLOCK_EXIT([this]() { - stdx::unique_lock<Latch> lock(_mutex); + stdx::lock_guard<Latch> lock(_mutex); _threadOpCtxHolder.reset(); }); - while (opCtx->checkForInterruptNoAssert().isOK()) { + while (!_stopRequested()) { { stdx::unique_lock<Latch> lock(_mutex); try { @@ -223,7 +268,7 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() { // Release the thread only in case the operation context has been interrupted, as // interruption only happens on shutdown/stepdown (this is fine because range // deletions will be resumed on the next step up) - if (!opCtx->checkForInterruptNoAssert().isOK()) { + if (_stopRequested()) { break; } @@ -237,6 +282,17 @@ void RangeDeleterService::ReadyRangeDeletionsProcessor::_runRangeDeletions() { } } +void RangeDeleterService::onStartup(OperationContext* opCtx) { + if (disableResumableRangeDeleter.load() || + !feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) { + return; + } + + auto opObserverRegistry = + checked_cast<OpObserverRegistry*>(opCtx->getServiceContext()->getOpObserver()); + opObserverRegistry->addObserver(std::make_unique<RangeDeleterServiceOpObserver>()); +} + void RangeDeleterService::onStepUpComplete(OperationContext* opCtx, long long term) { if (!feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) { return; @@ -249,22 +305,14 @@ void RangeDeleterService::onStepUpComplete(OperationContext* opCtx, long long te return; } + // Wait until all tasks and thread from previous term drain + _joinAndResetState(); + auto lock = _acquireMutexUnconditionally(); dassert(_state == kDown, "Service expected to be down before stepping up"); _state = kInitializing; - if (_executor) { - // Join previously shutted down executor before reinstantiating it - _executor->join(); - _executor.reset(); - } else { - // Initializing the op observer, only executed once at the first step-up - auto opObserverRegistry = - checked_cast<OpObserverRegistry*>(opCtx->getServiceContext()->getOpObserver()); - opObserverRegistry->addObserver(std::make_unique<RangeDeleterServiceOpObserver>()); - } - const std::string kExecName("RangeDeleterServiceExecutor"); auto net = executor::makeNetworkInterface(kExecName); auto pool = std::make_unique<executor::NetworkInterfaceThreadPool>(net.get()); @@ -303,9 +351,14 @@ void RangeDeleterService::_recoverRangeDeletionsOnStepUp(OperationContext* opCtx LOGV2(6834800, "Resubmitting range deletion tasks"); - ScopedRangeDeleterLock rangeDeleterLock(opCtx); - DBDirectClient client(opCtx); + // The Scoped lock is needed to serialize with concurrent range deletions + ScopedRangeDeleterLock rangeDeleterLock(opCtx, MODE_S); + // The collection lock is needed to serialize with donors trying to + // schedule local range deletions by updating the 'pending' field + AutoGetCollection rangeDeletionLock( + opCtx, NamespaceString::kRangeDeletionNamespace, MODE_S); + DBDirectClient client(opCtx); int nRescheduledTasks = 0; // (1) register range deletion tasks marked as 
"processing" @@ -370,47 +423,55 @@ void RangeDeleterService::_recoverRangeDeletionsOnStepUp(OperationContext* opCtx .semi(); } -void RangeDeleterService::_stopService(bool joinExecutor) { - if (!feature_flags::gRangeDeleterService.isEnabledAndIgnoreFCV()) { - return; - } +void RangeDeleterService::_joinAndResetState() { + invariant(_state == kDown); + // Join the thread spawned on step-up to resume range deletions + _stepUpCompletedFuture.getNoThrow().ignore(); - { - auto lock = _acquireMutexUnconditionally(); - _state = kDown; - if (_initOpCtxHolder) { - stdx::lock_guard<Client> lk(*_initOpCtxHolder->getClient()); - _initOpCtxHolder->markKilled(ErrorCodes::Interrupted); - } + // Join and destruct the executor + if (_executor) { + _executor->join(); + _executor.reset(); } - // Join the thread spawned on step-up to resume range deletions - _stepUpCompletedFuture.getNoThrow().ignore(); + // Join and destruct the processor + _readyRangeDeletionsProcessorPtr.reset(); + + // Clear range deletions potentially created during recovery + _rangeDeletionTasks.clear(); +} +void RangeDeleterService::_stopService() { auto lock = _acquireMutexUnconditionally(); + if (_state == kDown) + return; + + _state = kDown; + if (_initOpCtxHolder) { + stdx::lock_guard<Client> lk(*_initOpCtxHolder->getClient()); + _initOpCtxHolder->markKilled(ErrorCodes::Interrupted); + } - // It may happen for the `onStepDown` hook to be invoked on a SECONDARY node transitioning - // to ROLLBACK, hence the executor may have never been initialized if (_executor) { _executor->shutdown(); - if (joinExecutor) { - _executor->join(); - } } - // Destroy the range deletion processor in order to stop range deletions - _readyRangeDeletionsProcessorPtr.reset(); + // Shutdown the range deletion processor to interrupt range deletions + if (_readyRangeDeletionsProcessorPtr) { + _readyRangeDeletionsProcessorPtr->shutdown(); + } // Clear range deletion tasks map in order to notify potential waiters on completion futures _rangeDeletionTasks.clear(); } void RangeDeleterService::onStepDown() { - _stopService(false /* joinExecutor */); + _stopService(); } void RangeDeleterService::onShutdown() { - _stopService(true /* joinExecutor */); + _stopService(); + _joinAndResetState(); } BSONObj RangeDeleterService::dumpState() { @@ -472,10 +533,9 @@ SharedSemiFuture<void> RangeDeleterService::registerTask( .then([this, rdt = rdt]() { // Step 3: schedule the actual range deletion task auto lock = _acquireMutexUnconditionally(); - invariant( - _readyRangeDeletionsProcessorPtr || _state == kDown, - "The range deletions processor must be instantiated if the state != kDown"); if (_state != kDown) { + invariant(_readyRangeDeletionsProcessorPtr, + "The range deletions processor is not initialized"); _readyRangeDeletionsProcessorPtr->emplaceRangeDeletion(rdt); } }); diff --git a/src/mongo/db/s/range_deleter_service.h b/src/mongo/db/s/range_deleter_service.h index 68cca97454c..c816c8b9db2 100644 --- a/src/mongo/db/s/range_deleter_service.h +++ b/src/mongo/db/s/range_deleter_service.h @@ -110,58 +110,37 @@ private: */ class ReadyRangeDeletionsProcessor { public: - ReadyRangeDeletionsProcessor(OperationContext* opCtx) { - _thread = stdx::thread([this] { _runRangeDeletions(); }); - stdx::unique_lock<Latch> lock(_mutex); - opCtx->waitForConditionOrInterrupt( - _condVar, lock, [&] { return _threadOpCtxHolder.is_initialized(); }); - } - - ~ReadyRangeDeletionsProcessor() { - { - stdx::unique_lock<Latch> lock(_mutex); - // The `_threadOpCtxHolder` may have been already 
reset/interrupted in case the - // thread got interrupted due to stepdown - if (_threadOpCtxHolder) { - stdx::lock_guard<Client> scopedClientLock(*(*_threadOpCtxHolder)->getClient()); - if ((*_threadOpCtxHolder)->checkForInterruptNoAssert().isOK()) { - (*_threadOpCtxHolder)->markKilled(ErrorCodes::Interrupted); - } - } - _condVar.notify_all(); - } + ReadyRangeDeletionsProcessor(OperationContext* opCtx); + ~ReadyRangeDeletionsProcessor(); - if (_thread.joinable()) { - _thread.join(); - } - } + /* + * Interrupt ongoing range deletions + */ + void shutdown(); /* * Schedule a range deletion at the end of the queue */ - void emplaceRangeDeletion(const RangeDeletionTask& rdt) { - stdx::unique_lock<Latch> lock(_mutex); - _queue.push(rdt); - _condVar.notify_all(); - } + void emplaceRangeDeletion(const RangeDeletionTask& rdt); private: /* + * Return true if this processor have been shutted down + */ + bool _stopRequested() const; + + /* * Remove a range deletion from the head of the queue. Supposed to be called only once a * range deletion successfully finishes. */ - void _completedRangeDeletion() { - stdx::unique_lock<Latch> lock(_mutex); - dassert(!_queue.empty()); - _queue.pop(); - } + void _completedRangeDeletion(); /* * Code executed by the internal thread */ void _runRangeDeletions(); - Mutex _mutex = MONGO_MAKE_LATCH("ReadyRangeDeletionsProcessor"); + mutable Mutex _mutex = MONGO_MAKE_LATCH("ReadyRangeDeletionsProcessor"); /* * Condition variable notified when: @@ -175,10 +154,13 @@ private: std::queue<RangeDeletionTask> _queue; /* Pointer to the (one and only) operation context used by the thread */ - boost::optional<ServiceContext::UniqueOperationContext> _threadOpCtxHolder; + ServiceContext::UniqueOperationContext _threadOpCtxHolder; /* Thread consuming the range deletions queue */ stdx::thread _thread; + + enum State { kRunning, kStopped }; + State _state{kRunning}; }; // Keeping track of per-collection registered range deletion tasks @@ -253,6 +235,7 @@ public: const ChunkRange& range); /* ReplicaSetAwareServiceShardSvr implemented methods */ + void onStartup(OperationContext* opCtx) override; void onStepUpComplete(OperationContext* opCtx, long long term) override; void onStepDown() override; void onShutdown() override; @@ -276,17 +259,21 @@ public: std::unique_ptr<ReadyRangeDeletionsProcessor> _readyRangeDeletionsProcessorPtr; private: + /* Join all threads and executor and reset the in memory state of the service + * Used for onStartUpBegin and on onShutdown + */ + void _joinAndResetState(); + /* Asynchronously register range deletions on the service. 
To be called on on step-up */ void _recoverRangeDeletionsOnStepUp(OperationContext* opCtx); - /* Called by shutdown/stepdown hooks to reset the service */ - void _stopService(bool joinExecutor); + /* Called by shutdown/stepdown hooks to interrupt the service */ + void _stopService(); /* ReplicaSetAwareServiceShardSvr "empty implemented" methods */ - void onStartup(OperationContext* opCtx) override final{}; void onInitialDataAvailable(OperationContext* opCtx, bool isMajorityDataAvailable) override final {} - void onStepUpBegin(OperationContext* opCtx, long long term) override final {} + void onStepUpBegin(OperationContext* opCtx, long long term) override final{}; void onBecomeArbiter() override final {} }; diff --git a/src/mongo/db/s/range_deleter_service_op_observer.cpp b/src/mongo/db/s/range_deleter_service_op_observer.cpp index 9c1e51f9a6f..10bbbb6a924 100644 --- a/src/mongo/db/s/range_deleter_service_op_observer.cpp +++ b/src/mongo/db/s/range_deleter_service_op_observer.cpp @@ -35,6 +35,9 @@ #include "mongo/db/s/range_deleter_service.h" #include "mongo/db/s/range_deletion_task_gen.h" #include "mongo/db/update/update_oplog_entry_serialization.h" +#include "mongo/logv2/log.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kShardingRangeDeleter namespace mongo { namespace { @@ -54,10 +57,15 @@ void registerTaskWithOngoingQueriesOnOpLogEntryCommit(OperationContext* opCtx, (void)RangeDeleterService::get(opCtx)->registerTask( rdt, std::move(waitForActiveQueriesToComplete)); } catch (const DBException& ex) { - dassert(ex.code() == ErrorCodes::NotYetInitialized, - str::stream() << "No error different from `NotYetInitialized` is expected " - "to be propagated to the range deleter observer. Got error: " - << ex.toStatus()); + if (ex.code() != ErrorCodes::NotYetInitialized && + !ErrorCodes::isA<ErrorCategory::NotPrimaryError>(ex.code())) { + LOGV2_WARNING(7092800, + "No error different from `NotYetInitialized` or `NotPrimaryError` " + "category is expected to be propagated to the range deleter " + "observer. 
Range deletion task not registered.", + "error"_attr = redact(ex), + "task"_attr = rdt); + } } }); } diff --git a/src/mongo/db/s/range_deleter_service_test.cpp b/src/mongo/db/s/range_deleter_service_test.cpp index d2406120483..0807867da2d 100644 --- a/src/mongo/db/s/range_deleter_service_test.cpp +++ b/src/mongo/db/s/range_deleter_service_test.cpp @@ -45,6 +45,7 @@ void RangeDeleterServiceTest::setUp() { ShardServerTestFixture::setUp(); WaitForMajorityService::get(getServiceContext()).startup(getServiceContext()); opCtx = operationContext(); + RangeDeleterService::get(opCtx)->onStartup(opCtx); RangeDeleterService::get(opCtx)->onStepUpComplete(opCtx, 0L); RangeDeleterService::get(opCtx)->_waitForRangeDeleterServiceUp_FOR_TESTING(); @@ -275,12 +276,6 @@ TEST_F(RangeDeleterServiceTest, ScheduledTaskInvalidatedOnStepDown) { // Manually trigger disabling of the service rds->onStepDown(); - ON_BLOCK_EXIT([&] { - // Re-enable the service for clean teardown - rds->onStepUpComplete(opCtx, 0L); - rds->_waitForRangeDeleterServiceUp_FOR_TESTING(); - }); - try { completionFuture.get(opCtx); } catch (const ExceptionForCat<ErrorCategory::Interruption>&) { @@ -293,12 +288,6 @@ TEST_F(RangeDeleterServiceTest, NoActionPossibleIfServiceIsDown) { // Manually trigger disabling of the service rds->onStepDown(); - ON_BLOCK_EXIT([&] { - // Re-enable the service for clean teardown - rds->onStepUpComplete(opCtx, 0L); - rds->_waitForRangeDeleterServiceUp_FOR_TESTING(); - }); - auto taskWithOngoingQueries = createRangeDeletionTaskWithOngoingQueries( uuidCollA, BSON(kShardKey << 0), BSON(kShardKey << 10), CleanWhenEnum::kDelayed); @@ -886,10 +875,6 @@ TEST_F(RangeDeleterServiceTest, WaitForOngoingQueriesInvalidatedOnStepDown) { // Manually trigger disabling of the service rds->onStepDown(); - ON_BLOCK_EXIT([&] { - rds->onStepUpComplete(opCtx, 0L); // Re-enable the service - }); - try { completionFuture.get(opCtx); } catch (const ExceptionForCat<ErrorCategory::Interruption>&) { diff --git a/src/mongo/db/s/range_deletion_util.cpp b/src/mongo/db/s/range_deletion_util.cpp index 3a8b512ddba..707516db64f 100644 --- a/src/mongo/db/s/range_deletion_util.cpp +++ b/src/mongo/db/s/range_deletion_util.cpp @@ -322,7 +322,7 @@ Status deleteRangeInBatches(OperationContext* opCtx, const ChunkRange& range) { suspendRangeDeletion.pauseWhileSet(opCtx); - SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow); + SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow); bool allDocsRemoved = false; // Delete all batches in this range unless a stepdown error occurs. Do not yield the @@ -584,7 +584,7 @@ void persistUpdatedNumOrphans(OperationContext* opCtx, const auto query = getQueryFilterForRangeDeletionTask(collectionUuid, range); try { PersistentTaskStore<RangeDeletionTask> store(NamespaceString::kRangeDeletionNamespace); - ScopedRangeDeleterLock rangeDeleterLock(opCtx, collectionUuid); + ScopedRangeDeleterLock rangeDeleterLock(opCtx, LockMode::MODE_IX); // The DBDirectClient will not retry WriteConflictExceptions internally while holding an X // mode lock, so we need to retry at this level. 
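That comment is the reason for the wrapper that opens immediately below: while an exclusive-mode lock is held, DBDirectClient does not transparently retry a WriteConflictException, so the retry loop has to live at this call site. The helper is used in the usual shape (the operation name and lambda body here are illustrative placeholders, not the elided arguments of the real call):

writeConflictRetry(
    opCtx, "persistUpdatedNumOrphans", NamespaceString::kRangeDeletionNamespace.ns(), [&] {
        // Single update of the matching range deletion task document; re-executed from
        // the top if it throws a WriteConflictException.
    });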
writeConflictRetry( diff --git a/src/mongo/db/s/shard_server_op_observer.cpp b/src/mongo/db/s/shard_server_op_observer.cpp index 744c79a1eca..ae059d19512 100644 --- a/src/mongo/db/s/shard_server_op_observer.cpp +++ b/src/mongo/db/s/shard_server_op_observer.cpp @@ -509,20 +509,24 @@ void ShardServerOpObserver::onModifyShardedCollectionGlobalIndexCatalogEntry( IDLParserContext("onModifyShardedCollectionGlobalIndexCatalogEntry"), indexDoc["entry"].Obj()); auto indexVersion = indexDoc["entry"][IndexCatalogType::kLastmodFieldName].timestamp(); - opCtx->recoveryUnit()->onCommit([opCtx, nss, indexVersion, indexEntry](auto _) { + auto uuid = uassertStatusOK( + UUID::parse(indexDoc["entry"][IndexCatalogType::kCollectionUUIDFieldName])); + opCtx->recoveryUnit()->onCommit([opCtx, nss, indexVersion, indexEntry, uuid](auto _) { AutoGetCollection autoColl(opCtx, nss, MODE_IX); CollectionShardingRuntime::assertCollectionLockedAndAcquire( opCtx, nss, CSRAcquisitionMode::kExclusive) - ->addIndex(opCtx, indexEntry, indexVersion); + ->addIndex(opCtx, indexEntry, {uuid, indexVersion}); }); } else { auto indexName = indexDoc["entry"][IndexCatalogType::kNameFieldName].str(); auto indexVersion = indexDoc["entry"][IndexCatalogType::kLastmodFieldName].timestamp(); - opCtx->recoveryUnit()->onCommit([opCtx, nss, indexName, indexVersion](auto _) { + auto uuid = uassertStatusOK( + UUID::parse(indexDoc["entry"][IndexCatalogType::kCollectionUUIDFieldName])); + opCtx->recoveryUnit()->onCommit([opCtx, nss, indexName, indexVersion, uuid](auto _) { AutoGetCollection autoColl(opCtx, nss, MODE_IX); CollectionShardingRuntime::assertCollectionLockedAndAcquire( opCtx, nss, CSRAcquisitionMode::kExclusive) - ->removeIndex(opCtx, indexName, indexVersion); + ->removeIndex(opCtx, indexName, {uuid, indexVersion}); }); } } diff --git a/src/mongo/db/s/sharding_recovery_service.cpp b/src/mongo/db/s/sharding_recovery_service.cpp index d47e4056534..02dab8f3773 100644 --- a/src/mongo/db/s/sharding_recovery_service.cpp +++ b/src/mongo/db/s/sharding_recovery_service.cpp @@ -532,7 +532,7 @@ void ShardingRecoveryService::recoverIndexesCatalog(OperationContext* opCtx) { AutoGetCollection collLock(opCtx, nss, MODE_X); CollectionShardingRuntime::assertCollectionLockedAndAcquire( opCtx, collLock->ns(), CSRAcquisitionMode::kExclusive) - ->addIndex(opCtx, indexEntry, indexVersion); + ->addIndex(opCtx, indexEntry, {indexEntry.getCollectionUUID(), indexVersion}); } } LOGV2_DEBUG(6686502, 2, "Recovered all index versions"); diff --git a/src/mongo/db/s/sharding_write_router_bm.cpp b/src/mongo/db/s/sharding_write_router_bm.cpp index e92d8292052..89fa3ded04d 100644 --- a/src/mongo/db/s/sharding_write_router_bm.cpp +++ b/src/mongo/db/s/sharding_write_router_bm.cpp @@ -38,6 +38,7 @@ #include "mongo/db/s/collection_metadata.h" #include "mongo/db/s/collection_sharding_runtime.h" #include "mongo/db/s/collection_sharding_state_factory_shard.h" +#include "mongo/db/s/collection_sharding_state_factory_standalone.h" #include "mongo/db/s/operation_sharding_state.h" #include "mongo/db/s/sharding_state.h" #include "mongo/db/s/sharding_write_router.h" @@ -214,6 +215,10 @@ void BM_UnshardedDestinedRecipient(benchmark::State& state) { serviceContext->registerClientObserver(std::make_unique<LockerNoopClientObserver>()); const auto opCtx = client->makeOperationContext(); + CollectionShardingStateFactory::set( + opCtx->getServiceContext(), + std::make_unique<CollectionShardingStateFactoryStandalone>(opCtx->getServiceContext())); + const auto catalogCache = 
CatalogCacheMock::make(); for (auto keepRunning : state) { diff --git a/src/mongo/db/s/transaction_coordinator_util.cpp b/src/mongo/db/s/transaction_coordinator_util.cpp index 736e5fb1cf4..9bd3114a84f 100644 --- a/src/mongo/db/s/transaction_coordinator_util.cpp +++ b/src/mongo/db/s/transaction_coordinator_util.cpp @@ -445,7 +445,7 @@ Future<repl::OpTime> persistDecision(txn::AsyncWorkScheduler& scheduler, // Do not acquire a storage ticket in order to avoid unnecessary serialization // with other prepared transactions that are holding a storage ticket // themselves; see SERVER-60682. - SetTicketAquisitionPriorityForLock setTicketAquisition( + SetAdmissionPriorityForLock setTicketAquisition( opCtx, AdmissionContext::Priority::kImmediate); getTransactionCoordinatorWorkerCurOpRepository()->set( opCtx, lsid, txnNumberAndRetryCounter, CoordinatorAction::kWritingDecision); diff --git a/src/mongo/db/server_feature_flags.idl b/src/mongo/db/server_feature_flags.idl index b391bb5a749..c02079f3f37 100644 --- a/src/mongo/db/server_feature_flags.idl +++ b/src/mongo/db/server_feature_flags.idl @@ -62,4 +62,7 @@ feature_flags: cpp_varname: gFeatureFlagUseNewCompactStructuredEncryptionDataCoordinator default: true version: 6.1 - + featureFlagOIDC: + description: "Feature flag for OIDC support" + cpp_varname: gFeatureFlagOIDC + default: false diff --git a/src/mongo/db/service_context.cpp b/src/mongo/db/service_context.cpp index 013e3209f95..61e20dc82e2 100644 --- a/src/mongo/db/service_context.cpp +++ b/src/mongo/db/service_context.cpp @@ -100,7 +100,7 @@ void setGlobalServiceContext(ServiceContext::UniqueServiceContext&& serviceConte ServiceContext::ServiceContext() : _opIdRegistry(UniqueOperationIdRegistry::create()), - _tickSource(std::make_unique<SystemTickSource>()), + _tickSource(makeSystemTickSource()), _fastClockSource(std::make_unique<SystemClockSource>()), _preciseClockSource(std::make_unique<SystemClockSource>()) {} diff --git a/src/mongo/db/service_entry_point_common.cpp b/src/mongo/db/service_entry_point_common.cpp index 0c4766739e4..84f46d900dd 100644 --- a/src/mongo/db/service_entry_point_common.cpp +++ b/src/mongo/db/service_entry_point_common.cpp @@ -2105,7 +2105,20 @@ DbResponse makeCommandResponse(std::shared_ptr<HandleRequest::ExecutionContext> } } - dbResponse.response = replyBuilder->done(); + try { + dbResponse.response = replyBuilder->done(); + } catch (const ExceptionFor<ErrorCodes::BSONObjectTooLarge>& ex) { + // Create a new reply builder as subsequently calling any methods on a builder after + // 'done()' results in undefined behavior. 
+ auto errorReplyBuilder = execContext->getReplyBuilder(); + BSONObjBuilder metadataBob; + BSONObjBuilder extraFieldsBuilder; + appendClusterAndOperationTime( + opCtx, &extraFieldsBuilder, &metadataBob, LogicalTime::kUninitialized); + generateErrorResponse( + opCtx, errorReplyBuilder, ex.toStatus(), metadataBob.obj(), extraFieldsBuilder.obj()); + dbResponse.response = errorReplyBuilder->done(); + } CurOp::get(opCtx)->debug().responseLength = dbResponse.response.header().dataLen(); return dbResponse; @@ -2347,10 +2360,26 @@ BSONObj ServiceEntryPointCommon::getRedactedCopyForLogging(const Command* comman return bob.obj(); } -void onHandleRequestException(const Status& status) { +void logHandleRequestFailure(const Status& status) { LOGV2_ERROR(4879802, "Failed to handle request", "error"_attr = redact(status)); } +void onHandleRequestException(const HandleRequest& hr, const Status& status) { + auto isMirrorOp = [&] { + const auto& obj = hr.executionContext->getRequest().body; + if (auto e = obj.getField("mirrored"); MONGO_unlikely(e.ok() && e.boolean())) + return true; + return false; + }; + + // TODO SERVER-70510 revert changes introduced by SERVER-60553 that suppresses errors occurred + // during handling of mirroring operations on recovering secondaries. + if (MONGO_unlikely(status == ErrorCodes::NotWritablePrimary && isMirrorOp())) + return; + + logHandleRequestFailure(status); +} + Future<DbResponse> ServiceEntryPointCommon::handleRequest( OperationContext* opCtx, const Message& m, @@ -2362,7 +2391,7 @@ Future<DbResponse> ServiceEntryPointCommon::handleRequest( invariant(opRunner); return opRunner->run() - .then([hr = std::move(hr)](DbResponse response) mutable { + .then([&hr](DbResponse response) mutable { hr.completeOperation(response); auto opCtx = hr.executionContext->getOpCtx(); @@ -2377,10 +2406,10 @@ Future<DbResponse> ServiceEntryPointCommon::handleRequest( return response; }) - .tapError([](Status status) { onHandleRequestException(status); }); + .tapError([hr = std::move(hr)](Status status) { onHandleRequestException(hr, status); }); } catch (const DBException& ex) { auto status = ex.toStatus(); - onHandleRequestException(status); + logHandleRequestFailure(status); return status; } diff --git a/src/mongo/db/stats/SConscript b/src/mongo/db/stats/SConscript index 784019b4532..f7745f59c9a 100644 --- a/src/mongo/db/stats/SConscript +++ b/src/mongo/db/stats/SConscript @@ -128,6 +128,7 @@ env.Library( '$BUILD_DIR/mongo/db/commands/server_status_core', '$BUILD_DIR/mongo/db/index/index_access_method', '$BUILD_DIR/mongo/db/pipeline/document_sources_idl', + '$BUILD_DIR/mongo/db/s/balancer_stats_registry', '$BUILD_DIR/mongo/db/server_base', '$BUILD_DIR/mongo/db/shard_role', '$BUILD_DIR/mongo/db/timeseries/bucket_catalog', diff --git a/src/mongo/db/storage/external_record_store.cpp b/src/mongo/db/storage/external_record_store.cpp index e63286fbf6a..2a4bb9a8796 100644 --- a/src/mongo/db/storage/external_record_store.cpp +++ b/src/mongo/db/storage/external_record_store.cpp @@ -33,8 +33,10 @@ #include "mongo/db/storage/record_store.h" namespace mongo { +// 'ident' is an identifer to WT table and a virtual collection does not have any persistent data +// in WT. So, we set the "dummy" ident for a virtual collection. 
ExternalRecordStore::ExternalRecordStore(StringData ns, const VirtualCollectionOptions& vopts) - : RecordStore(ns, /*identName=*/ns, /*isCapped=*/false), _vopts(vopts) {} + : RecordStore(ns, /*identName=*/"dummy"_sd, /*isCapped=*/false), _vopts(vopts) {} /** * Returns a MultiBsonStreamCursor for this record store. Reverse scans are not currently supported @@ -42,7 +44,6 @@ ExternalRecordStore::ExternalRecordStore(StringData ns, const VirtualCollectionO */ std::unique_ptr<SeekableRecordCursor> ExternalRecordStore::getCursor(OperationContext* opCtx, bool forward) const { - if (forward) { return std::make_unique<MultiBsonStreamCursor>(getOptions()); } diff --git a/src/mongo/db/storage/external_record_store.h b/src/mongo/db/storage/external_record_store.h index 152d766e6cb..ac0c61392c6 100644 --- a/src/mongo/db/storage/external_record_store.h +++ b/src/mongo/db/storage/external_record_store.h @@ -33,16 +33,8 @@ #include "mongo/db/catalog/virtual_collection_options.h" #include "mongo/db/storage/record_store.h" #include "mongo/util/assert_util.h" -#include "mongo/util/errno_util.h" namespace mongo { -namespace { -inline std::string getErrorMessage(StringData op, const std::string& path) { - using namespace fmt::literals; - return "Failed to {} {}: {}"_format(op, path, errorMessage(lastSystemError())); -} -} // namespace - class ExternalRecordStore : public RecordStore { public: ExternalRecordStore(StringData ns, const VirtualCollectionOptions& vopts); diff --git a/src/mongo/db/storage/external_record_store_test.cpp b/src/mongo/db/storage/external_record_store_test.cpp index f4c42fb03ad..e5474f43133 100644 --- a/src/mongo/db/storage/external_record_store_test.cpp +++ b/src/mongo/db/storage/external_record_store_test.cpp @@ -32,11 +32,9 @@ #include <ctime> #include <fmt/format.h> -#include "mongo/db/storage/external_record_store.h" #include "mongo/db/storage/input_stream.h" #include "mongo/db/storage/multi_bson_stream_cursor.h" #include "mongo/db/storage/named_pipe.h" -#include "mongo/logv2/log.h" #include "mongo/stdx/thread.h" #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" @@ -45,21 +43,33 @@ namespace mongo { using namespace fmt::literals; -#ifndef _WIN32 -static const std::string pipePath1 = "/tmp/named_pipe1"; -static const std::string pipePath2 = "/tmp/named_pipe2"; -static const std::string nonExistingPath = "/tmp/non-existing"; -#else -// "//./pipe" is the required path start of all named pipes on Windows, where "//." is the -// abbreviation for the local server name and "/pipe" is a literal. (These also work with -// Windows-native backslashes instead of forward slashes.) 
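The hunk below drops the hard-coded absolute pipe paths: pipe names are now given relative to a default prefix (kDefaultPipePath, added in named_pipe.h further down: "/tmp/" on POSIX, "//./pipe/" on Windows) that NamedPipeOutput and NamedPipeInput prepend themselves. Condensed from the updated tests, the producer/consumer call sites look like this (srcBsonObj and _buffer are the tests' own variables):

// Producer side, on its own thread; "named_pipe1" becomes "/tmp/named_pipe1" on POSIX
// or "//./pipe/named_pipe1" on Windows once the default prefix is applied.
NamedPipeOutput pipeWriter("named_pipe1");
pipeWriter.open();
pipeWriter.write(srcBsonObj.objdata(), srcBsonObj.objsize());
pipeWriter.close();

// Consumer side: the stream is constructed from the same relative name.
auto inputStream = InputStream<NamedPipeInput>("named_pipe1");
int nRead = inputStream.readBytes(srcBsonObj.objsize(), _buffer);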
-static const std::string pipePath1 = R"(//./pipe/named_pipe1)"; -static const std::string pipePath2 = R"(//./pipe/named_pipe2)"; -static const std::string nonExistingPath = R"(//./pipe/non-existing)"; -#endif +static const std::string pipePath1 = "named_pipe1"; +static const std::string pipePath2 = "named_pipe2"; +static const std::string nonExistingPath = "non-existing"; static constexpr int kNumPipes = 2; static const std::string pipePaths[kNumPipes] = {pipePath1, pipePath2}; +class PipeWaiter { +public: + void notify() { + { + stdx::unique_lock lk(m); + pipeCreated = true; + } + cv.notify_one(); + } + + void wait() { + stdx::unique_lock lk(m); + cv.wait(lk, [&] { return pipeCreated; }); + } + +private: + Mutex m; + stdx::condition_variable cv; + bool pipeCreated = false; +}; + class ExternalRecordStoreTest : public unittest::Test { public: // Gets a random string of 'count' length consisting of printable ASCII chars (32-126). @@ -74,11 +84,11 @@ public: return buf; } - static constexpr int kBufferSize = 1024; char _buffer[kBufferSize]; // buffer amply big enough to fit any BSONObj used in this test - static void createNamedPipe(const std::string& pipePath, + static void createNamedPipe(PipeWaiter* pw, + const std::string& pipePath, long numToWrite, const std::vector<BSONObj>& bsonObjs); @@ -88,13 +98,16 @@ public: }; // Creates a named pipe of BSON objects. +// pipeWaiter - synchronization for pipe creation // pipePath - file path for the named pipe // numToWrite - number of bsons to write to the pipe // bsonObjs - vector of bsons to write round-robin to the pipe -void ExternalRecordStoreTest::createNamedPipe(const std::string& pipePath, +void ExternalRecordStoreTest::createNamedPipe(PipeWaiter* pw, + const std::string& pipePath, long numToWrite, const std::vector<BSONObj>& bsonObjs) { - NamedPipeOutput pipeWriter(pipePath.c_str()); + NamedPipeOutput pipeWriter(pipePath); + pw->notify(); pipeWriter.open(); const int numObjs = bsonObjs.size(); @@ -110,8 +123,10 @@ void ExternalRecordStoreTest::createNamedPipe(const std::string& pipePath, TEST_F(ExternalRecordStoreTest, NamedPipeBasicRead) { auto srcBsonObj = BSON("a" << 1); auto count = srcBsonObj.objsize(); + PipeWaiter pw; stdx::thread producer([&] { - NamedPipeOutput pipeWriter(pipePath1.c_str()); + NamedPipeOutput pipeWriter(pipePath1); + pw.notify(); pipeWriter.open(); for (int i = 0; i < 100; ++i) { @@ -123,11 +138,11 @@ TEST_F(ExternalRecordStoreTest, NamedPipeBasicRead) { ON_BLOCK_EXIT([&] { producer.join(); }); // Gives some time to the producer so that it can initialize a named pipe. 
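The fixed one-second sleep removed in the next hunk only papered over the race between pipe creation on the producer thread and the reader's open; the PipeWaiter added above closes that race with a one-shot latch. The same handshake in self-contained standard C++ (std:: primitives standing in for the server's Mutex and stdx wrappers):

#include <condition_variable>
#include <mutex>
#include <thread>

// One-shot latch: the consumer blocks until the producer signals that the pipe exists.
class OneShotWaiter {
public:
    void notify() {
        {
            std::lock_guard<std::mutex> lk(_m);
            _ready = true;
        }
        _cv.notify_one();
    }

    void wait() {
        std::unique_lock<std::mutex> lk(_m);
        _cv.wait(lk, [&] { return _ready; });
    }

private:
    std::mutex _m;
    std::condition_variable _cv;
    bool _ready = false;
};

int main() {
    OneShotWaiter waiter;
    std::thread producer([&] {
        // ... create the named pipe here ...
        waiter.notify();  // signal creation before opening/writing
        // ... open the pipe and write to it ...
    });
    waiter.wait();  // blocks exactly until the pipe exists; no guessed sleep duration
    // ... now safe to open the pipe for reading ...
    producer.join();
    return 0;
}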
- stdx::this_thread::sleep_for(stdx::chrono::seconds(1)); + pw.wait(); - auto inputStream = std::make_unique<InputStream<NamedPipeInput>>(pipePath1.c_str()); + auto inputStream = InputStream<NamedPipeInput>(pipePath1); for (int i = 0; i < 100; ++i) { - int nRead = inputStream->readBytes(count, _buffer); + int nRead = inputStream.readBytes(count, _buffer); ASSERT_EQ(nRead, count) << "Failed to read data up to {} bytes"_format(count); ASSERT_EQ(std::memcmp(srcBsonObj.objdata(), _buffer, count), 0) << "Read data is not same as the source data"; @@ -137,8 +152,10 @@ TEST_F(ExternalRecordStoreTest, NamedPipeBasicRead) { TEST_F(ExternalRecordStoreTest, NamedPipeReadPartialData) { auto srcBsonObj = BSON("a" << 1); auto count = srcBsonObj.objsize(); + PipeWaiter pw; stdx::thread producer([&] { - NamedPipeOutput pipeWriter(pipePath1.c_str()); + NamedPipeOutput pipeWriter(pipePath1); + pw.notify(); pipeWriter.open(); pipeWriter.write(srcBsonObj.objdata(), count); pipeWriter.close(); @@ -146,11 +163,11 @@ TEST_F(ExternalRecordStoreTest, NamedPipeReadPartialData) { ON_BLOCK_EXIT([&] { producer.join(); }); // Gives some time to the producer so that it can initialize a named pipe. - stdx::this_thread::sleep_for(stdx::chrono::seconds(1)); + pw.wait(); - auto inputStream = std::make_unique<InputStream<NamedPipeInput>>(pipePath1.c_str()); - // Request more data than the pipe contains. Should only get the bytes it does contain. - int nRead = inputStream->readBytes(kBufferSize, _buffer); + auto inputStream = InputStream<NamedPipeInput>(pipePath1); + // Requests more data than the pipe contains. Should only get the bytes it does contain. + int nRead = inputStream.readBytes(kBufferSize, _buffer); ASSERT_EQ(nRead, count) << "Expected nRead == {} but got {}"_format(count, nRead); ASSERT_EQ(std::memcmp(srcBsonObj.objdata(), _buffer, count), 0) << "Read data is not same as the source data"; @@ -160,8 +177,10 @@ TEST_F(ExternalRecordStoreTest, NamedPipeReadUntilProducerDone) { auto srcBsonObj = BSON("a" << 1); auto count = srcBsonObj.objsize(); const auto nSent = std::rand() % 100; + PipeWaiter pw; stdx::thread producer([&] { - NamedPipeOutput pipeWriter(pipePath1.c_str()); + NamedPipeOutput pipeWriter(pipePath1); + pw.notify(); pipeWriter.open(); for (int i = 0; i < nSent; ++i) { @@ -173,12 +192,12 @@ TEST_F(ExternalRecordStoreTest, NamedPipeReadUntilProducerDone) { ON_BLOCK_EXIT([&] { producer.join(); }); // Gives some time to the producer so that it can initialize a named pipe. - stdx::this_thread::sleep_for(stdx::chrono::seconds(1)); + pw.wait(); - auto inputStream = std::make_unique<InputStream<NamedPipeInput>>(pipePath1.c_str()); + auto inputStream = InputStream<NamedPipeInput>(pipePath1); auto nReceived = 0; while (true) { - int nRead = inputStream->readBytes(count, _buffer); + int nRead = inputStream.readBytes(count, _buffer); if (nRead != count) { ASSERT_EQ(nRead, 0) << "Expected nRead == 0 for EOF but got something else {}"_format( nRead); @@ -195,7 +214,7 @@ TEST_F(ExternalRecordStoreTest, NamedPipeReadUntilProducerDone) { TEST_F(ExternalRecordStoreTest, NamedPipeOpenNonExisting) { ASSERT_THROWS_CODE( - [] { (void)std::make_unique<InputStream<NamedPipeInput>>(nonExistingPath.c_str()); }(), + [] { (void)std::make_unique<InputStream<NamedPipeInput>>(nonExistingPath); }(), DBException, ErrorCodes::FileNotOpen); } @@ -210,9 +229,10 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes1) { // Create two pipes. The first has only "a" objects and the second has only "zed" objects. 
stdx::thread pipeThreads[kNumPipes]; + PipeWaiter pw[kNumPipes]; for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) { - pipeThreads[pipeIdx] = - stdx::thread(createNamedPipe, pipePaths[pipeIdx], kObjsPerPipe, bsonObjs[pipeIdx]); + pipeThreads[pipeIdx] = stdx::thread( + createNamedPipe, &pw[pipeIdx], pipePaths[pipeIdx], kObjsPerPipe, bsonObjs[pipeIdx]); } ON_BLOCK_EXIT([&] { for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) { @@ -221,7 +241,9 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes1) { }); // Gives some time to the producers so they can initialize the named pipes. - stdx::this_thread::sleep_for(stdx::chrono::seconds(1)); + for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) { + pw[pipeIdx].wait(); + } // Create metadata describing the pipes and a MultiBsonStreamCursor to read them. VirtualCollectionOptions vopts; @@ -302,12 +324,13 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes2) { // so they will cause several wraps. The largish size of the objects makes it highly likely that // some reads will leave a partial object that must be completed on a later next() call. stdx::thread pipeThreads[kNumPipes]; + PipeWaiter pw[kNumPipes]; long numToWrites[] = {(3 * groupsIn32Mb * numObjs), (5 * groupsIn32Mb * numObjs)}; long numToWrite = 0; for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) { - pipeThreads[pipeIdx] = - stdx::thread(createNamedPipe, pipePaths[pipeIdx], numToWrites[pipeIdx], bsonObjs); + pipeThreads[pipeIdx] = stdx::thread( + createNamedPipe, &pw[pipeIdx], pipePaths[pipeIdx], numToWrites[pipeIdx], bsonObjs); numToWrite += numToWrites[pipeIdx]; } ON_BLOCK_EXIT([&] { @@ -317,7 +340,9 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes2) { }); // Gives some time to the producers so they can initialize the named pipes. - stdx::this_thread::sleep_for(stdx::chrono::seconds(1)); + for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) { + pw[pipeIdx].wait(); + } // Create metadata describing the pipes and a MultiBsonStreamCursor to read them. VirtualCollectionOptions vopts; @@ -383,12 +408,13 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes3) { // Create pipes with large bsons. stdx::thread pipeThreads[kNumPipes]; + PipeWaiter pw[kNumPipes]; long numToWrites[] = {19, 17}; long numToWrite = 0; for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) { - pipeThreads[pipeIdx] = - stdx::thread(createNamedPipe, pipePaths[pipeIdx], numToWrites[pipeIdx], bsonObjs); + pipeThreads[pipeIdx] = stdx::thread( + createNamedPipe, &pw[pipeIdx], pipePaths[pipeIdx], numToWrites[pipeIdx], bsonObjs); numToWrite += numToWrites[pipeIdx]; } ON_BLOCK_EXIT([&] { @@ -398,7 +424,9 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes3) { }); // Gives some time to the producers so they can initialize the named pipes. - stdx::this_thread::sleep_for(stdx::chrono::seconds(1)); + for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) { + pw[pipeIdx].wait(); + } // Create metadata describing the pipes and a MultiBsonStreamCursor to read them. 
VirtualCollectionOptions vopts; @@ -450,6 +478,7 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes4) { constexpr int kNumPipes = 100; // shadows the global std::string pipePaths[kNumPipes]; // shadows the global stdx::thread pipeThreads[kNumPipes]; // pipe producer threads + PipeWaiter pw[kNumPipes]; // pipe waiters std::vector<BSONObj> pipeBsonObjs[kNumPipes]; // vector of BSON objects for each pipe size_t objsWritten = 0; // number of objects written to all pipes @@ -465,12 +494,9 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes4) { // Create the pipes. for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) { -#ifndef _WIN32 - pipePaths[pipeIdx] = "/tmp/named_pipe{}"_format(pipeIdx); -#else - pipePaths[pipeIdx] = "//./pipe/named_pipe{}"_format(pipeIdx); -#endif + pipePaths[pipeIdx] = "named_pipe{}"_format(pipeIdx); pipeThreads[pipeIdx] = stdx::thread(createNamedPipe, + &pw[pipeIdx], pipePaths[pipeIdx], pipeBsonObjs[pipeIdx].size(), pipeBsonObjs[pipeIdx]); @@ -482,7 +508,9 @@ TEST_F(ExternalRecordStoreTest, NamedPipeMultiplePipes4) { }); // Gives some time to the producers so they can initialize the named pipes. - stdx::this_thread::sleep_for(stdx::chrono::seconds(1)); + for (int pipeIdx = 0; pipeIdx < kNumPipes; ++pipeIdx) { + pw[pipeIdx].wait(); + } // Create metadata describing the pipes and a MultiBsonStreamCursor to read them. VirtualCollectionOptions vopts; diff --git a/src/mongo/db/storage/input_object.h b/src/mongo/db/storage/input_object.h index 9f420b8288b..5788d5c7b0b 100644 --- a/src/mongo/db/storage/input_object.h +++ b/src/mongo/db/storage/input_object.h @@ -41,7 +41,7 @@ class StreamableInput { public: virtual ~StreamableInput() {} - virtual const char* getPath() const = 0; + virtual const std::string& getAbsolutePath() const = 0; void open() { if (isOpen()) { diff --git a/src/mongo/db/storage/input_stream.h b/src/mongo/db/storage/input_stream.h index 95d0b6d9bd8..45f33522f80 100644 --- a/src/mongo/db/storage/input_stream.h +++ b/src/mongo/db/storage/input_stream.h @@ -32,12 +32,12 @@ #include <fmt/format.h> #include <utility> +#include "mongo/db/storage/io_error_message.h" #include "mongo/logv2/log.h" #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage namespace mongo { - /** * This template class provides a standardized input facility over StreamableInput or SeekableInput. * @@ -57,7 +57,7 @@ public: using namespace fmt::literals; InputT::open(); uassert(ErrorCodes::FileNotOpen, - "error"_format(getErrorMessage("open"_sd, InputT::getPath())), + "error"_format(getErrorMessage("open"_sd, InputT::getAbsolutePath())), InputT::isOpen()); } @@ -95,14 +95,14 @@ public: // If we reach this point, we accumulated fewer than 'count' bytes. if (MONGO_likely(InputT::isEof())) { - LOGV2_INFO(7005001, "Named pipe is closed", "path"_attr = InputT::getPath()); + LOGV2_INFO(7005001, "Named pipe is closed", "path"_attr = InputT::getAbsolutePath()); return nReadTotal; } tassert(7005002, "Expected an error condition but succeeded", InputT::isFailed()); LOGV2_ERROR(7005003, "Failed to read a named pipe", - "error"_attr = getErrorMessage("read", InputT::getPath())); + "error"_attr = getErrorMessage("read", InputT::getAbsolutePath())); return -1; } diff --git a/src/mongo/db/storage/io_error_message.h b/src/mongo/db/storage/io_error_message.h new file mode 100644 index 00000000000..37979358286 --- /dev/null +++ b/src/mongo/db/storage/io_error_message.h @@ -0,0 +1,43 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. 
+ * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <fmt/format.h> + +#include "mongo/util/errno_util.h" + +namespace mongo { +namespace { +inline std::string getErrorMessage(StringData op, const std::string& path) { + using namespace fmt::literals; + return "Failed to {} {}: {}"_format(op, path, errorMessage(lastSystemError())); +} +} // namespace +} // namespace mongo diff --git a/src/mongo/db/storage/multi_bson_stream_cursor.cpp b/src/mongo/db/storage/multi_bson_stream_cursor.cpp index 189131b0208..749a21e2aff 100644 --- a/src/mongo/db/storage/multi_bson_stream_cursor.cpp +++ b/src/mongo/db/storage/multi_bson_stream_cursor.cpp @@ -28,7 +28,10 @@ */ #include "mongo/db/storage/multi_bson_stream_cursor.h" + +#include "mongo/db/catalog/virtual_collection_options.h" #include "mongo/db/storage/record_store.h" + namespace mongo { /** @@ -37,10 +40,10 @@ namespace mongo { * and updates bookkeeping. This can never expand the buffer larger than (2 * BSONObjMaxUserSize). */ void MultiBsonStreamCursor::expandBuffer(int32_t bsonSize) { - tassert(6968308, + uassert(6968308, "bsonSize {} > BSONObjMaxUserSize {}"_format(bsonSize, BSONObjMaxUserSize), (bsonSize <= BSONObjMaxUserSize)); - tassert(6968309, "bsonSize {} < 0"_format(bsonSize), (bsonSize >= 0)); + uassert(6968309, "bsonSize {} < 0"_format(bsonSize), (bsonSize >= 0)); int newSizeTarget = 2 * bsonSize; do { @@ -143,6 +146,23 @@ boost::optional<Record> MultiBsonStreamCursor::nextFromCurrentStream() { } /** + * Returns an input stream for a named pipe mapped from 'url'. + * + * While creating an input stream, it strips off the file protocol part from the 'url'. 
+ */ +std::unique_ptr<InputStream<NamedPipeInput>> MultiBsonStreamCursor::getInputStream( + const std::string& url) { + auto filePathPos = url.find(ExternalDataSourceMetadata::kUrlProtocolFile.toString()); + tassert( + ErrorCodes::BadValue, "Invalid file url: {}"_format(url), filePathPos != std::string::npos); + + auto filePathStr = + url.substr(filePathPos + ExternalDataSourceMetadata::kUrlProtocolFile.size()); + + return std::make_unique<InputStream<NamedPipeInput>>(filePathStr); +} + +/** * Returns the next record from the vector of streams or boost::none if exhausted or error. * '_streamReader' is initialized to the first stream, if there is one, in the constructor. */ @@ -154,8 +174,7 @@ boost::optional<Record> MultiBsonStreamCursor::next() { } ++_streamIdx; if (_streamIdx < _numStreams) { - _streamReader = std::make_unique<InputStream<NamedPipeInput>>( - _vopts.dataSources[_streamIdx].url.c_str()); + _streamReader = getInputStream(_vopts.dataSources[_streamIdx].url); } } return boost::none; diff --git a/src/mongo/db/storage/multi_bson_stream_cursor.h b/src/mongo/db/storage/multi_bson_stream_cursor.h index 1b896f6a6c5..d71e7240326 100644 --- a/src/mongo/db/storage/multi_bson_stream_cursor.h +++ b/src/mongo/db/storage/multi_bson_stream_cursor.h @@ -29,9 +29,10 @@ #pragma once -#include "mongo/db/storage/external_record_store.h" +#include "mongo/db/catalog/virtual_collection_options.h" #include "mongo/db/storage/input_stream.h" #include "mongo/db/storage/named_pipe.h" +#include "mongo/db/storage/record_store.h" namespace mongo { class MultiBsonStreamCursor : public SeekableRecordCursor { @@ -39,8 +40,7 @@ public: MultiBsonStreamCursor(const VirtualCollectionOptions& vopts) : _numStreams(vopts.dataSources.size()), _vopts(vopts) { tassert(6968310, "_numStreams {} <= 0"_format(_numStreams), _numStreams > 0); - _streamReader = - std::make_unique<InputStream<NamedPipeInput>>(_vopts.dataSources[0].url.c_str()); + _streamReader = getInputStream(_vopts.dataSources[_streamIdx].url); } boost::optional<Record> next() override; @@ -71,6 +71,7 @@ public: private: void expandBuffer(int32_t bsonSize); boost::optional<Record> nextFromCurrentStream(); + static std::unique_ptr<InputStream<NamedPipeInput>> getInputStream(const std::string& url); // The size in bytes of a BSON object's "size" prefix. static constexpr int kSizeSize = static_cast<int>(sizeof(int32_t)); diff --git a/src/mongo/db/storage/named_pipe.h b/src/mongo/db/storage/named_pipe.h index 61e6c87ac87..587661e5b6a 100644 --- a/src/mongo/db/storage/named_pipe.h +++ b/src/mongo/db/storage/named_pipe.h @@ -34,20 +34,30 @@ #else #include <windows.h> #endif +#include <string> #include "mongo/db/storage/input_object.h" namespace mongo { +#ifndef _WIN32 +static constexpr auto kDefaultPipePath = "/tmp/"_sd; +#else +// "//./pipe/" is the required path start of all named pipes on Windows, where "//." is the +// abbreviation for the local server name and "/pipe" is a literal. (These also work with +// Windows-native backslashes instead of forward slashes. 
+static constexpr auto kDefaultPipePath = "//./pipe/"_sd; +#endif + class NamedPipeOutput { public: - NamedPipeOutput(const char* pipePath); + NamedPipeOutput(const std::string& pipeRelativePath); ~NamedPipeOutput(); void open(); int write(const char* data, int size); void close(); private: - const char* _pipePath; + std::string _pipeAbsolutePath; #ifndef _WIN32 std::ofstream _ofs; #else @@ -58,10 +68,10 @@ private: class NamedPipeInput : public StreamableInput { public: - NamedPipeInput(const char* pipePath); + NamedPipeInput(const std::string& pipeRelativePath); ~NamedPipeInput() override; - const char* getPath() const override { - return _pipePath; + const std::string& getAbsolutePath() const override { + return _pipeAbsolutePath; } bool isOpen() const override; bool isGood() const override; @@ -74,7 +84,7 @@ protected: void doClose() override; private: - const char* _pipePath; + std::string _pipeAbsolutePath; #ifndef _WIN32 std::ifstream _ifs; #else diff --git a/src/mongo/db/storage/named_pipe_posix.cpp b/src/mongo/db/storage/named_pipe_posix.cpp index c1e04be6f27..c34079fc916 100644 --- a/src/mongo/db/storage/named_pipe_posix.cpp +++ b/src/mongo/db/storage/named_pipe_posix.cpp @@ -31,10 +31,11 @@ #include "named_pipe.h" #include <fmt/format.h> +#include <string> #include <sys/stat.h> #include <sys/types.h> -#include "mongo/db/storage/external_record_store.h" +#include "mongo/db/storage/io_error_message.h" #include "mongo/logv2/log.h" #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage @@ -43,23 +44,26 @@ namespace mongo { using namespace fmt::literals; -NamedPipeOutput::NamedPipeOutput(const char* pipePath) : _pipePath(pipePath), _ofs() { - remove(_pipePath); +NamedPipeOutput::NamedPipeOutput(const std::string& pipeRelativePath) + : _pipeAbsolutePath(kDefaultPipePath + pipeRelativePath), _ofs() { + remove(_pipeAbsolutePath.c_str()); uassert(7005005, - "Failed to create a named pipe, error={}"_format(getErrorMessage("mkfifo", _pipePath)), - mkfifo(_pipePath, 0664) == 0); + "Failed to create a named pipe, error={}"_format( + getErrorMessage("mkfifo", _pipeAbsolutePath)), + mkfifo(_pipeAbsolutePath.c_str(), 0664) == 0); } NamedPipeOutput::~NamedPipeOutput() { close(); + remove(_pipeAbsolutePath.c_str()); } void NamedPipeOutput::open() { - _ofs.open(_pipePath, std::ios::binary | std::ios::app); + _ofs.open(_pipeAbsolutePath.c_str(), std::ios::binary | std::ios::app); if (!_ofs.is_open() || !_ofs.good()) { LOGV2_ERROR(7005009, "Failed to open a named pipe", - "error"_attr = getErrorMessage("open", _pipePath)); + "error"_attr = getErrorMessage("open", _pipeAbsolutePath)); } } @@ -78,14 +82,15 @@ void NamedPipeOutput::close() { } } -NamedPipeInput::NamedPipeInput(const char* pipePath) : _pipePath(pipePath), _ifs() {} +NamedPipeInput::NamedPipeInput(const std::string& pipeRelativePath) + : _pipeAbsolutePath(kDefaultPipePath + pipeRelativePath), _ifs() {} NamedPipeInput::~NamedPipeInput() { close(); } void NamedPipeInput::doOpen() { - _ifs.open(_pipePath, std::ios::binary | std::ios::in); + _ifs.open(_pipeAbsolutePath.c_str(), std::ios::binary | std::ios::in); } int NamedPipeInput::doRead(char* data, int size) { diff --git a/src/mongo/db/storage/named_pipe_windows.cpp b/src/mongo/db/storage/named_pipe_windows.cpp index 8b0aab88987..aa4d530cab9 100644 --- a/src/mongo/db/storage/named_pipe_windows.cpp +++ b/src/mongo/db/storage/named_pipe_windows.cpp @@ -31,9 +31,12 @@ #include "named_pipe.h" #include <fmt/format.h> +#include <string> +#include <system_error> -#include 
"mongo/db/storage/external_record_store.h" +#include "mongo/db/storage/io_error_message.h" #include "mongo/logv2/log.h" +#include "mongo/util/errno_util.h" #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage @@ -41,14 +44,20 @@ namespace mongo { using namespace fmt::literals; -NamedPipeOutput::NamedPipeOutput(const char* pipePath) - : _pipePath(pipePath), - _pipe(CreateNamedPipeA( - pipePath, PIPE_ACCESS_OUTBOUND, (PIPE_TYPE_BYTE | PIPE_WAIT), 1, 0, 0, 0, nullptr)), +NamedPipeOutput::NamedPipeOutput(const std::string& pipeRelativePath) + : _pipeAbsolutePath(kDefaultPipePath + pipeRelativePath), + _pipe(CreateNamedPipeA(_pipeAbsolutePath.c_str(), + PIPE_ACCESS_OUTBOUND, + (PIPE_TYPE_BYTE | PIPE_WAIT), + 1, // nMaxInstances + 0, // nOutBufferSize + 0, // nInBufferSize + 0, // nDefaultTimeOut + nullptr)), // lpSecurityAttributes _isOpen(false) { uassert(7005006, "Failed to create a named pipe, error={}"_format( - getErrorMessage("CreateNamedPipe", _pipePath)), + getErrorMessage("CreateNamedPipe", _pipeAbsolutePath)), _pipe != INVALID_HANDLE_VALUE); } @@ -64,7 +73,7 @@ void NamedPipeOutput::open() { if (!res) { LOGV2_ERROR(7005007, "Failed to connect a named pipe", - "error"_attr = getErrorMessage("ConnectNamedPipe", _pipePath)); + "error"_attr = getErrorMessage("ConnectNamedPipe", _pipeAbsolutePath)); return; } _isOpen = true; @@ -83,7 +92,7 @@ int NamedPipeOutput::write(const char* data, int size) { if (!res || size != nWritten) { LOGV2_ERROR(7005008, "Failed to write to a named pipe", - "error"_attr = getErrorMessage("write", _pipePath)); + "error"_attr = getErrorMessage("write", _pipeAbsolutePath)); return -1; } @@ -102,8 +111,8 @@ void NamedPipeOutput::close() { } } -NamedPipeInput::NamedPipeInput(const char* pipePath) - : _pipePath(pipePath), +NamedPipeInput::NamedPipeInput(const std::string& pipeRelativePath) + : _pipeAbsolutePath(kDefaultPipePath + pipeRelativePath), _pipe(INVALID_HANDLE_VALUE), _isOpen(false), _isGood(false), @@ -114,7 +123,8 @@ NamedPipeInput::~NamedPipeInput() { } void NamedPipeInput::doOpen() { - _pipe = CreateFileA(_pipePath, GENERIC_READ, 0, nullptr, OPEN_EXISTING, 0, nullptr); + _pipe = + CreateFileA(_pipeAbsolutePath.c_str(), GENERIC_READ, 0, nullptr, OPEN_EXISTING, 0, nullptr); if (_pipe == INVALID_HANDLE_VALUE) { return; } diff --git a/src/mongo/db/storage/oplog_cap_maintainer_thread.cpp b/src/mongo/db/storage/oplog_cap_maintainer_thread.cpp index 0627f0e41ce..582df818b5a 100644 --- a/src/mongo/db/storage/oplog_cap_maintainer_thread.cpp +++ b/src/mongo/db/storage/oplog_cap_maintainer_thread.cpp @@ -66,8 +66,7 @@ bool OplogCapMaintainerThread::_deleteExcessDocuments() { // Maintaining the Oplog cap is crucial to the stability of the server so that we don't let the // oplog grow unbounded. We mark the operation as having immediate priority to skip ticket // acquisition and flow control. 
- SetTicketAquisitionPriorityForLock priority(opCtx.get(), - AdmissionContext::Priority::kImmediate); + SetAdmissionPriorityForLock priority(opCtx.get(), AdmissionContext::Priority::kImmediate); try { // A Global IX lock should be good enough to protect the oplog truncation from diff --git a/src/mongo/db/storage/recovery_unit.h b/src/mongo/db/storage/recovery_unit.h index 20c49c637d3..9e173d632e7 100644 --- a/src/mongo/db/storage/recovery_unit.h +++ b/src/mongo/db/storage/recovery_unit.h @@ -37,6 +37,7 @@ #include "mongo/bson/timestamp.h" #include "mongo/db/repl/read_concern_level.h" #include "mongo/db/storage/snapshot.h" +#include "mongo/db/storage/storage_stats.h" #include "mongo/util/decorable.h" namespace mongo { @@ -76,37 +77,6 @@ enum class PrepareConflictBehavior { }; /** - * Storage statistics management class, with interfaces to provide the statistics in the BSON format - * and an operator to add the statistics values. - */ -class StorageStats { - StorageStats(const StorageStats&) = delete; - StorageStats& operator=(const StorageStats&) = delete; - -public: - StorageStats() = default; - - virtual ~StorageStats(){}; - - /** - * Provides the storage statistics in the form of a BSONObj. - */ - virtual BSONObj toBSON() = 0; - - /** - * Add the statistics values. - */ - virtual StorageStats& operator+=(const StorageStats&) = 0; - - /** - * Provides the ability to create an instance of this class outside of the storage integration - * layer. - */ - virtual std::shared_ptr<StorageStats> getCopy() = 0; -}; - - -/** * A RecoveryUnit is responsible for ensuring that data is persisted. * All on-disk information must be mutated through this interface. */ diff --git a/src/mongo/db/storage/storage_engine_impl.cpp b/src/mongo/db/storage/storage_engine_impl.cpp index b5b3b2f0d80..6eb5123cf92 100644 --- a/src/mongo/db/storage/storage_engine_impl.cpp +++ b/src/mongo/db/storage/storage_engine_impl.cpp @@ -971,10 +971,12 @@ Status StorageEngineImpl::_dropCollectionsNoTimestamp(OperationContext* opCtx, audit::logDropCollection(opCtx->getClient(), coll->ns()); - Status result = catalog::dropCollection( - opCtx, coll->ns(), coll->getCatalogId(), coll->getSharedIdent()); - if (!result.isOK() && firstError.isOK()) { - firstError = result; + if (auto sharedIdent = coll->getSharedIdent()) { + Status result = + catalog::dropCollection(opCtx, coll->ns(), coll->getCatalogId(), sharedIdent); + if (!result.isOK() && firstError.isOK()) { + firstError = result; + } } CollectionCatalog::get(opCtx)->dropCollection( diff --git a/src/mongo/db/query/optimizer/cascades/cost_derivation.h b/src/mongo/db/storage/storage_stats.h index b71a020316e..b0afd2ba16d 100644 --- a/src/mongo/db/query/optimizer/cascades/cost_derivation.h +++ b/src/mongo/db/storage/storage_stats.h @@ -29,22 +29,29 @@ #pragma once -#include "mongo/db/query/optimizer/cascades/interfaces.h" -#include "mongo/db/query/optimizer/cascades/memo.h" +#include "mongo/bson/bsonobj.h" -namespace mongo::optimizer::cascades { +namespace mongo { /** - * Default costing for physical nodes with logical delegator (not-yet-optimized) inputs. + * Manages statistics from the storage engine, allowing addition of statistics and serialization to + * BSON. 
*/ -class DefaultCosting : public CostingInterface { +class StorageStats { public: - CostAndCE deriveCost(const Metadata& metadata, - const Memo& memo, - const properties::PhysProps& physProps, - ABT::reference_type physNodeRef, - const ChildPropsType& childProps, - const NodeCEMap& nodeCEMap) const override final; + StorageStats() = default; + + StorageStats(const StorageStats&) = delete; + StorageStats(StorageStats&&) = delete; + StorageStats& operator=(const StorageStats&) = delete; + + virtual ~StorageStats() = default; + + virtual BSONObj toBSON() const = 0; + + virtual std::shared_ptr<StorageStats> clone() const = 0; + + virtual StorageStats& operator+=(const StorageStats&) = 0; }; -} // namespace mongo::optimizer::cascades +} // namespace mongo diff --git a/src/mongo/db/storage/wiredtiger/SConscript b/src/mongo/db/storage/wiredtiger/SConscript index 68a76639ed5..638cf0b933c 100644 --- a/src/mongo/db/storage/wiredtiger/SConscript +++ b/src/mongo/db/storage/wiredtiger/SConscript @@ -37,6 +37,7 @@ wtEnv.Library( 'wiredtiger_index.cpp', 'wiredtiger_index_util.cpp', 'wiredtiger_kv_engine.cpp', + 'wiredtiger_operation_stats.cpp', 'wiredtiger_oplog_manager.cpp', 'wiredtiger_parameters.cpp', 'wiredtiger_prepare_conflict.cpp', @@ -133,6 +134,7 @@ wtEnv.CppUnitTest( 'wiredtiger_init_test.cpp', 'wiredtiger_c_api_test.cpp', 'wiredtiger_kv_engine_test.cpp', + 'wiredtiger_operation_stats_test.cpp', 'wiredtiger_recovery_unit_test.cpp', 'wiredtiger_session_cache_test.cpp', 'wiredtiger_util_test.cpp', diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_column_store.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_column_store.cpp index 1f034c86e26..1a31948a00e 100644 --- a/src/mongo/db/storage/wiredtiger/wiredtiger_column_store.cpp +++ b/src/mongo/db/storage/wiredtiger/wiredtiger_column_store.cpp @@ -39,7 +39,6 @@ #define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage - namespace mongo { StatusWith<std::string> WiredTigerColumnStore::generateCreateString( const std::string& engineName, @@ -218,9 +217,19 @@ void WiredTigerColumnStore::WriteCursor::update(PathView path, RowId rid, CellVi void WiredTigerColumnStore::fullValidate(OperationContext* opCtx, int64_t* numKeysOut, IndexValidateResults* fullResults) const { - // TODO SERVER-65484: Validation for column indexes. - // uasserted(ErrorCodes::NotImplemented, "WiredTigerColumnStore::fullValidate()"); - return; + dassert(opCtx->lockState()->isReadLocked()); + if (!WiredTigerIndexUtil::validateStructure(opCtx, _uri, fullResults)) { + return; + } + auto cursor = newCursor(opCtx); + long long count = 0; + + while (cursor->next()) { + count++; + } + if (numKeysOut) { + *numKeysOut = count; + } } class WiredTigerColumnStore::Cursor final : public ColumnStore::Cursor, diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_index.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_index.cpp index a63022f4ad2..d5297781aaa 100644 --- a/src/mongo/db/storage/wiredtiger/wiredtiger_index.cpp +++ b/src/mongo/db/storage/wiredtiger/wiredtiger_index.cpp @@ -298,32 +298,8 @@ void WiredTigerIndex::fullValidate(OperationContext* opCtx, long long* numKeysOut, IndexValidateResults* fullResults) const { dassert(opCtx->lockState()->isReadLocked()); - if (fullResults && !WiredTigerRecoveryUnit::get(opCtx)->getSessionCache()->isEphemeral()) { - int err = WiredTigerUtil::verifyTable(opCtx, _uri, &(fullResults->errors)); - if (err == EBUSY) { - std::string msg = str::stream() - << "Could not complete validation of " << _uri << ". 
" - << "This is a transient issue as the collection was actively " - "in use by other operations."; - - LOGV2_WARNING(51781, - "Could not complete validation. This is a transient issue as " - "the collection was actively in use by other operations", - "uri"_attr = _uri); - fullResults->warnings.push_back(msg); - } else if (err) { - std::string msg = str::stream() - << "verify() returned " << wiredtiger_strerror(err) << ". " - << "This indicates structural damage. " - << "Not examining individual index entries."; - LOGV2_ERROR(51782, - "verify() returned an error. This indicates structural damage. Not " - "examining individual index entries.", - "error"_attr = wiredtiger_strerror(err)); - fullResults->errors.push_back(msg); - fullResults->valid = false; - return; - } + if (!WiredTigerIndexUtil::validateStructure(opCtx, _uri, fullResults)) { + return; } auto cursor = newCursor(opCtx); diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.cpp index 97278595091..273b692c375 100644 --- a/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.cpp +++ b/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.cpp @@ -30,6 +30,9 @@ #include "mongo/db/storage/wiredtiger/wiredtiger_index_util.h" #include "mongo/db/storage/wiredtiger/wiredtiger_prepare_conflict.h" +#include "mongo/logv2/log.h" + +#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kStorage namespace mongo { @@ -115,4 +118,37 @@ bool WiredTigerIndexUtil::isEmpty(OperationContext* opCtx, return false; } +bool WiredTigerIndexUtil::validateStructure(OperationContext* opCtx, + const std::string& uri, + IndexValidateResults* fullResults) { + if (fullResults && !WiredTigerRecoveryUnit::get(opCtx)->getSessionCache()->isEphemeral()) { + int err = WiredTigerUtil::verifyTable(opCtx, uri, &(fullResults->errors)); + if (err == EBUSY) { + std::string msg = str::stream() + << "Could not complete validation of " << uri << ". " + << "This is a transient issue as the collection was actively " + "in use by other operations."; + + LOGV2_WARNING(51781, + "Could not complete validation. This is a transient issue as " + "the collection was actively in use by other operations", + "uri"_attr = uri); + fullResults->warnings.push_back(msg); + } else if (err) { + std::string msg = str::stream() + << "verify() returned " << wiredtiger_strerror(err) << ". " + << "This indicates structural damage. " + << "Not examining individual index entries."; + LOGV2_ERROR(51782, + "verify() returned an error. This indicates structural damage. 
Not " + "examining individual index entries.", + "error"_attr = wiredtiger_strerror(err)); + fullResults->errors.push_back(msg); + fullResults->valid = false; + return false; + } + } + return true; +} + } // namespace mongo diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.h b/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.h index e98c5a7b1dc..e80067dd6d1 100644 --- a/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.h +++ b/src/mongo/db/storage/wiredtiger/wiredtiger_index_util.h @@ -31,6 +31,7 @@ #include "mongo/base/status.h" #include "mongo/bson/bsonobjbuilder.h" +#include "mongo/db/catalog/validate_results.h" #include "mongo/db/operation_context.h" namespace mongo { @@ -51,6 +52,10 @@ public: static Status compact(OperationContext* opCtx, const std::string& uri); static bool isEmpty(OperationContext* opCtx, const std::string& uri, uint64_t tableId); + + static bool validateStructure(OperationContext* opCtx, + const std::string& uri, + IndexValidateResults* fullResults); }; } // namespace mongo diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.cpp new file mode 100644 index 00000000000..ebf3842e7bc --- /dev/null +++ b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.cpp @@ -0,0 +1,139 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */ + +#include "mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h" + +#include "mongo/bson/bsonobjbuilder.h" +#include "mongo/db/storage/wiredtiger/wiredtiger_util.h" + +namespace mongo { +namespace { + +enum class StatType { kData, kWait }; + +struct StatInfo { + StringData name; + StatType type; +}; + +stdx::unordered_map<int, StatInfo> statInfo = { + {WT_STAT_SESSION_BYTES_READ, {"bytesRead"_sd, StatType::kData}}, + {WT_STAT_SESSION_BYTES_WRITE, {"bytesWritten"_sd, StatType::kData}}, + {WT_STAT_SESSION_LOCK_DHANDLE_WAIT, {"handleLock"_sd, StatType::kWait}}, + {WT_STAT_SESSION_READ_TIME, {"timeReadingMicros"_sd, StatType::kData}}, + {WT_STAT_SESSION_WRITE_TIME, {"timeWritingMicros"_sd, StatType::kData}}, + {WT_STAT_SESSION_LOCK_SCHEMA_WAIT, {"schemaLock"_sd, StatType::kWait}}, + {WT_STAT_SESSION_CACHE_TIME, {"cache"_sd, StatType::kWait}}}; + +} // namespace + +WiredTigerOperationStats::WiredTigerOperationStats(WT_SESSION* session) { + invariant(session); + + WT_CURSOR* c; + uassert(ErrorCodes::CursorNotFound, + "Unable to open statistics cursor", + !session->open_cursor(session, "statistics:session", nullptr, "statistics=(fast)", &c)); + + ScopeGuard guard{[c] { c->close(c); }}; + + int32_t key; + uint64_t value; + while (c->next(c) == 0 && c->get_key(c, &key) == 0) { + fassert(51035, c->get_value(c, nullptr, nullptr, &value) == 0); + _stats[key] = WiredTigerUtil::castStatisticsValue<long long>(value); + } + + // Reset the statistics so that the next fetch gives the recent values. + invariantWTOK(c->reset(c), c->session); +} + +BSONObj WiredTigerOperationStats::toBSON() const { + boost::optional<BSONObjBuilder> dataSection; + boost::optional<BSONObjBuilder> waitSection; + + for (auto&& [stat, value] : _stats) { + if (value == 0) { + continue; + } + + auto it = statInfo.find(stat); + if (it == statInfo.end()) { + continue; + } + auto&& [name, type] = it->second; + + auto appendToSection = [name = name, + value = value](boost::optional<BSONObjBuilder>& section) { + if (!section) { + section.emplace(); + } + section->append(name, value); + }; + + switch (type) { + case StatType::kData: + appendToSection(dataSection); + break; + case StatType::kWait: + appendToSection(waitSection); + break; + } + } + + BSONObjBuilder builder; + if (dataSection) { + builder.append("data", dataSection->obj()); + } + if (waitSection) { + builder.append("timeWaitingMicros", waitSection->obj()); + } + + return builder.obj(); +} + +std::shared_ptr<StorageStats> WiredTigerOperationStats::clone() const { + auto copy = std::make_shared<WiredTigerOperationStats>(); + *copy += *this; + return copy; +} + +WiredTigerOperationStats& WiredTigerOperationStats::operator+=( + const WiredTigerOperationStats& other) { + for (auto&& [stat, value] : other._stats) { + _stats[stat] += value; + } + return *this; +} + +StorageStats& WiredTigerOperationStats::operator+=(const StorageStats& other) { + return *this += checked_cast<const WiredTigerOperationStats&>(other); +} + +} // namespace mongo diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h new file mode 100644 index 00000000000..581303e30b8 --- /dev/null +++ b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h @@ -0,0 +1,55 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <wiredtiger.h> + +#include "mongo/db/storage/storage_stats.h" + +namespace mongo { + +class WiredTigerOperationStats final : public StorageStats { +public: + WiredTigerOperationStats() = default; + WiredTigerOperationStats(WT_SESSION* session); + + BSONObj toBSON() const final; + + std::shared_ptr<StorageStats> clone() const final; + + StorageStats& operator+=(const StorageStats&) final; + + WiredTigerOperationStats& operator+=(const WiredTigerOperationStats&); + +private: + std::map<int, long long> _stats; +}; + +} // namespace mongo diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats_test.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats_test.cpp new file mode 100644 index 00000000000..5151f59cb48 --- /dev/null +++ b/src/mongo/db/storage/wiredtiger/wiredtiger_operation_stats_test.cpp @@ -0,0 +1,215 @@ +/** + * Copyright (C) 2022-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. 
If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h" +#include "mongo/unittest/temp_dir.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { +namespace { + +#define ASSERT_WT_OK(result) ASSERT_EQ(result, 0) << wiredtiger_strerror(result) + +class WiredTigerOperationStatsTest : public unittest::Test { +protected: + void setUp() override { + ASSERT_WT_OK( + wiredtiger_open(_path.path().c_str(), nullptr, "create,statistics=(fast),", &_conn)); + ASSERT_WT_OK(_conn->open_session(_conn, nullptr, "isolation=snapshot", &_session)); + ASSERT_WT_OK(_session->create( + _session, _uri.c_str(), "type=file,key_format=q,value_format=u,log=(enabled=false)")); + } + + void tearDown() override { + ASSERT_EQ(_conn->close(_conn, nullptr), 0); + } + + /** + * Writes the given data using WT. Causes the bytesWritten and timeWritingMicros stats to be + * incremented. + */ + void write(const std::string& data) { + ASSERT_WT_OK(_session->begin_transaction(_session, nullptr)); + + WT_CURSOR* cursor; + ASSERT_WT_OK(_session->open_cursor(_session, _uri.c_str(), nullptr, nullptr, &cursor)); + + cursor->set_key(cursor, _key++); + + WT_ITEM item{data.data(), data.size()}; + cursor->set_value(cursor, &item); + + ASSERT_WT_OK(cursor->insert(cursor)); + ASSERT_WT_OK(cursor->close(cursor)); + ASSERT_WT_OK(_session->commit_transaction(_session, nullptr)); + ASSERT_WT_OK(_session->checkpoint(_session, nullptr)); + } + + /** + * Reads all of the previously written data from WT. Causes the bytesRead and timeReadingMicros + * stats to be incremented. + */ + void read() { + tearDown(); + setUp(); + + ASSERT_WT_OK(_session->begin_transaction(_session, nullptr)); + + WT_CURSOR* cursor; + ASSERT_WT_OK(_session->open_cursor(_session, _uri.c_str(), nullptr, nullptr, &cursor)); + + for (int64_t i = 0; i < _key; ++i) { + cursor->set_key(cursor, i); + ASSERT_WT_OK(cursor->search(cursor)); + + WT_ITEM value; + ASSERT_WT_OK(cursor->get_value(cursor, &value)); + } + + ASSERT_WT_OK(cursor->close(cursor)); + ASSERT_WT_OK(_session->commit_transaction(_session, nullptr)); + } + + unittest::TempDir _path{"wiredtiger_operation_stats_test"}; + std::string _uri{"table:wiredtiger_operation_stats_test"}; + WT_CONNECTION* _conn; + WT_SESSION* _session; + int64_t _key = 0; +}; + +TEST_F(WiredTigerOperationStatsTest, Empty) { + ASSERT_BSONOBJ_EQ(WiredTigerOperationStats{_session}.toBSON(), BSONObj{}); +} + +TEST_F(WiredTigerOperationStatsTest, Write) { + write("a"); + + auto statsObj = WiredTigerOperationStats{_session}.toBSON(); + + auto dataSection = statsObj["data"]; + ASSERT_EQ(dataSection.type(), BSONType::Object) << statsObj; + + ASSERT(dataSection["bytesWritten"]) << statsObj; + for (auto&& [name, value] : dataSection.Obj()) { + ASSERT_EQ(value.type(), BSONType::NumberLong) << statsObj; + ASSERT_GT(value.numberLong(), 0) << statsObj; + } +} + +TEST_F(WiredTigerOperationStatsTest, Read) { + write("a"); + read(); + + auto statsObj = WiredTigerOperationStats{_session}.toBSON(); + + auto dataSection = statsObj["data"]; + ASSERT_EQ(dataSection.type(), BSONType::Object) << statsObj; + + ASSERT(dataSection["bytesRead"]) << statsObj; + for (auto&& [name, value] : dataSection.Obj()) { + ASSERT_EQ(value.type(), BSONType::NumberLong) << statsObj; + ASSERT_GT(value.numberLong(), 0) << statsObj; + } +} + +TEST_F(WiredTigerOperationStatsTest, Large) { + auto remaining = 
static_cast<int64_t>(std::numeric_limits<uint32_t>::max()) + 1; + while (remaining > 0) { + std::string data(1024 * 1024, 'a'); + remaining -= data.size(); + write(data); + } + + auto statsObj = WiredTigerOperationStats{_session}.toBSON(); + ASSERT_GT(statsObj["data"]["bytesWritten"].numberLong(), std::numeric_limits<uint32_t>::max()) + << statsObj; + + read(); + + statsObj = WiredTigerOperationStats{_session}.toBSON(); + ASSERT_GT(statsObj["data"]["bytesRead"].numberLong(), std::numeric_limits<uint32_t>::max()) + << statsObj; +} + +TEST_F(WiredTigerOperationStatsTest, Add) { + std::vector<std::unique_ptr<WiredTigerOperationStats>> stats; + + write("a"); + stats.push_back(std::make_unique<WiredTigerOperationStats>(_session)); + + read(); + stats.push_back(std::make_unique<WiredTigerOperationStats>(_session)); + + write("aa"); + stats.push_back(std::make_unique<WiredTigerOperationStats>(_session)); + + read(); + stats.push_back(std::make_unique<WiredTigerOperationStats>(_session)); + + long long bytesWritten = 0; + long long timeWritingMicros = 0; + long long bytesRead = 0; + long long timeReadingMicros = 0; + + WiredTigerOperationStats combined; + + for (auto&& op : stats) { + auto statsObj = op->toBSON(); + + bytesWritten += statsObj["data"]["bytesWritten"].numberLong(); + timeWritingMicros += statsObj["data"]["timeWritingMicros"].numberLong(); + bytesRead += statsObj["data"]["bytesRead"].numberLong(); + timeReadingMicros += statsObj["data"]["timeReadingMicros"].numberLong(); + + combined += *op; + } + + auto combinedObj = combined.toBSON(); + auto dataSection = combinedObj["data"]; + ASSERT_EQ(dataSection.type(), BSONType::Object) << combinedObj; + ASSERT_EQ(dataSection["bytesWritten"].numberLong(), bytesWritten) << combinedObj; + ASSERT_EQ(dataSection["timeWritingMicros"].numberLong(), timeWritingMicros) << combinedObj; + ASSERT_EQ(dataSection["bytesRead"].numberLong(), bytesRead) << combinedObj; + ASSERT_EQ(dataSection["timeReadingMicros"].numberLong(), timeReadingMicros) << combinedObj; +} + +TEST_F(WiredTigerOperationStatsTest, Clone) { + write("a"); + + WiredTigerOperationStats stats{_session}; + auto clone = stats.clone(); + + ASSERT_BSONOBJ_EQ(stats.toBSON(), clone->toBSON()); + + stats += *clone; + ASSERT_BSONOBJ_NE(stats.toBSON(), clone->toBSON()); +} + +} // namespace +} // namespace mongo diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.cpp b/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.cpp index 1f4d9d567b7..26f91072bbd 100644 --- a/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.cpp +++ b/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.cpp @@ -37,6 +37,7 @@ #include "mongo/db/server_options.h" #include "mongo/db/storage/wiredtiger/wiredtiger_begin_transaction_block.h" #include "mongo/db/storage/wiredtiger/wiredtiger_kv_engine.h" +#include "mongo/db/storage/wiredtiger/wiredtiger_operation_stats.h" #include "mongo/db/storage/wiredtiger/wiredtiger_prepare_conflict.h" #include "mongo/db/storage/wiredtiger/wiredtiger_session_cache.h" #include "mongo/db/storage/wiredtiger/wiredtiger_util.h" @@ -82,107 +83,6 @@ void handleWriteContextForDebugging(WiredTigerRecoveryUnit& ru, Timestamp& ts) { AtomicWord<std::int64_t> snapshotTooOldErrorCount{0}; -using Section = WiredTigerOperationStats::Section; - -std::map<int, std::pair<StringData, Section>> WiredTigerOperationStats::_statNameMap = { - {WT_STAT_SESSION_BYTES_READ, std::make_pair("bytesRead"_sd, Section::DATA)}, - {WT_STAT_SESSION_BYTES_WRITE, std::make_pair("bytesWritten"_sd, 
Section::DATA)}, - {WT_STAT_SESSION_LOCK_DHANDLE_WAIT, std::make_pair("handleLock"_sd, Section::WAIT)}, - {WT_STAT_SESSION_READ_TIME, std::make_pair("timeReadingMicros"_sd, Section::DATA)}, - {WT_STAT_SESSION_WRITE_TIME, std::make_pair("timeWritingMicros"_sd, Section::DATA)}, - {WT_STAT_SESSION_LOCK_SCHEMA_WAIT, std::make_pair("schemaLock"_sd, Section::WAIT)}, - {WT_STAT_SESSION_CACHE_TIME, std::make_pair("cache"_sd, Section::WAIT)}}; - -std::shared_ptr<StorageStats> WiredTigerOperationStats::getCopy() { - std::shared_ptr<WiredTigerOperationStats> copy = std::make_shared<WiredTigerOperationStats>(); - *copy += *this; - return copy; -} - -void WiredTigerOperationStats::fetchStats(WT_SESSION* session, - const std::string& uri, - const std::string& config) { - invariant(session); - - WT_CURSOR* c = nullptr; - const char* cursorConfig = config.empty() ? nullptr : config.c_str(); - int ret = session->open_cursor(session, uri.c_str(), nullptr, cursorConfig, &c); - uassert(ErrorCodes::CursorNotFound, "Unable to open statistics cursor", ret == 0); - - invariant(c); - ON_BLOCK_EXIT([&] { c->close(c); }); - - const char* desc; - uint64_t value; - int32_t key; - while (c->next(c) == 0 && c->get_key(c, &key) == 0) { - fassert(51035, c->get_value(c, &desc, nullptr, &value) == 0); - _stats[key] = WiredTigerUtil::castStatisticsValue<long long>(value); - } - - // Reset the statistics so that the next fetch gives the recent values. - invariantWTOK(c->reset(c), c->session); -} - -BSONObj WiredTigerOperationStats::toBSON() { - BSONObjBuilder bob; - std::unique_ptr<BSONObjBuilder> dataSection; - std::unique_ptr<BSONObjBuilder> waitSection; - - for (auto const& stat : _stats) { - // Find the user consumable name for this statistic. - auto statIt = _statNameMap.find(stat.first); - - // Ignore the session statistic that is not reported for the slow operations. - if (statIt == _statNameMap.end()) - continue; - - auto statName = statIt->second.first; - Section subs = statIt->second.second; - long long val = stat.second; - // Add this statistic only if higher than zero. - if (val > 0) { - // Gather the statistic into its own subsection in the BSONObj. 
- switch (subs) { - case Section::DATA: - if (!dataSection) - dataSection = std::make_unique<BSONObjBuilder>(); - - dataSection->append(statName, val); - break; - case Section::WAIT: - if (!waitSection) - waitSection = std::make_unique<BSONObjBuilder>(); - - waitSection->append(statName, val); - break; - default: - MONGO_UNREACHABLE; - } - } - } - - if (dataSection) - bob.append("data", dataSection->obj()); - if (waitSection) - bob.append("timeWaitingMicros", waitSection->obj()); - - return bob.obj(); -} - -WiredTigerOperationStats& WiredTigerOperationStats::operator+=( - const WiredTigerOperationStats& other) { - for (auto const& otherStat : other._stats) { - _stats[otherStat.first] += otherStat.second; - } - return (*this); -} - -StorageStats& WiredTigerOperationStats::operator+=(const StorageStats& other) { - *this += checked_cast<const WiredTigerOperationStats&>(other); - return (*this); -} - WiredTigerRecoveryUnit::WiredTigerRecoveryUnit(WiredTigerSessionCache* sc) : WiredTigerRecoveryUnit(sc, sc->getKVEngine()->getOplogManager()) {} @@ -996,18 +896,7 @@ void WiredTigerRecoveryUnit::beginIdle() { } std::shared_ptr<StorageStats> WiredTigerRecoveryUnit::getOperationStatistics() const { - std::shared_ptr<WiredTigerOperationStats> statsPtr(nullptr); - - if (!_session) - return statsPtr; - - WT_SESSION* s = _session->getSession(); - invariant(s); - - statsPtr = std::make_shared<WiredTigerOperationStats>(); - statsPtr->fetchStats(s, "statistics:session", "statistics=(fast)"); - - return statsPtr; + return _session ? std::make_shared<WiredTigerOperationStats>(_session->getSession()) : nullptr; } void WiredTigerRecoveryUnit::setCatalogConflictingTimestamp(Timestamp timestamp) { diff --git a/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.h b/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.h index 26102686dbd..2d38888dc90 100644 --- a/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.h +++ b/src/mongo/db/storage/wiredtiger/wiredtiger_recovery_unit.h @@ -56,41 +56,6 @@ extern AtomicWord<std::int64_t> snapshotTooOldErrorCount; class BSONObjBuilder; -class WiredTigerOperationStats final : public StorageStats { -public: - /** - * There are two types of statistics provided by WiredTiger engine - data and wait. - */ - enum class Section { DATA, WAIT }; - - BSONObj toBSON() final; - - StorageStats& operator+=(const StorageStats&) final; - - WiredTigerOperationStats& operator+=(const WiredTigerOperationStats&); - - /** - * Fetches an operation's storage statistics from WiredTiger engine. - */ - void fetchStats(WT_SESSION*, const std::string&, const std::string&); - - std::shared_ptr<StorageStats> getCopy() final; - -private: - /** - * Each statistic in WiredTiger has an integer key, which this map associates with a section - * (either DATA or WAIT) and user-readable name. - */ - static std::map<int, std::pair<StringData, Section>> _statNameMap; - - /** - * Stores the value for each statistic returned by a WiredTiger cursor. Each statistic is - * associated with an integer key, which can be mapped to a name and section using the - * '_statNameMap'. 
- */ - std::map<int, long long> _stats; -}; - class WiredTigerRecoveryUnit final : public RecoveryUnit { public: WiredTigerRecoveryUnit(WiredTigerSessionCache* sc); diff --git a/src/mongo/db/timeseries/bucket_catalog.cpp b/src/mongo/db/timeseries/bucket_catalog.cpp index 31cb39aeab0..854a2805d66 100644 --- a/src/mongo/db/timeseries/bucket_catalog.cpp +++ b/src/mongo/db/timeseries/bucket_catalog.cpp @@ -860,7 +860,11 @@ Status BucketCatalog::prepareCommit(std::shared_ptr<WriteBatch> batch) { _useBucketInState(&stripe, stripeLock, batch->bucket().id, BucketState::kPrepared); if (batch->finished()) { - // Someone may have aborted it while we were waiting. + // Someone may have aborted it while we were waiting. Since we have the prepared batch, we + // should now be able to fully abort the bucket. + if (bucket) { + _abort(&stripe, stripeLock, batch, getBatchStatus()); + } return getBatchStatus(); } else if (!bucket) { _abort(&stripe, stripeLock, batch, getTimeseriesBucketClearedError(batch->bucket().id)); diff --git a/src/mongo/db/timeseries/bucket_catalog_test.cpp b/src/mongo/db/timeseries/bucket_catalog_test.cpp index 6733f4c9d89..7253111cbf2 100644 --- a/src/mongo/db/timeseries/bucket_catalog_test.cpp +++ b/src/mongo/db/timeseries/bucket_catalog_test.cpp @@ -911,6 +911,73 @@ TEST_F(BucketCatalogTest, CannotConcurrentlyCommitBatchesForSameBucket) { _bucketCatalog->finish(batch2, {}); } +TEST_F(BucketCatalogTest, AbortingBatchEnsuresBucketIsEventuallyClosed) { + auto batch1 = _bucketCatalog + ->insert(_opCtx, + _ns1, + _getCollator(_ns1), + _getTimeseriesOptions(_ns1), + BSON(_timeField << Date_t::now()), + BucketCatalog::CombineWithInsertsFromOtherClients::kDisallow) + .getValue() + .batch; + + auto batch2 = _bucketCatalog + ->insert(_makeOperationContext().second.get(), + _ns1, + _getCollator(_ns1), + _getTimeseriesOptions(_ns1), + BSON(_timeField << Date_t::now()), + BucketCatalog::CombineWithInsertsFromOtherClients::kDisallow) + .getValue() + .batch; + auto batch3 = _bucketCatalog + ->insert(_makeOperationContext().second.get(), + _ns1, + _getCollator(_ns1), + _getTimeseriesOptions(_ns1), + BSON(_timeField << Date_t::now()), + BucketCatalog::CombineWithInsertsFromOtherClients::kDisallow) + .getValue() + .batch; + ASSERT_EQ(batch1->bucket().id, batch2->bucket().id); + ASSERT_EQ(batch1->bucket().id, batch3->bucket().id); + + ASSERT(batch1->claimCommitRights()); + ASSERT(batch2->claimCommitRights()); + ASSERT(batch3->claimCommitRights()); + + // Batch 2 will not be able to commit until batch 1 has finished. + ASSERT_OK(_bucketCatalog->prepareCommit(batch1)); + auto task = Task{[&]() { ASSERT_OK(_bucketCatalog->prepareCommit(batch2)); }}; + // Add a little extra wait to make sure prepareCommit actually gets to the blocking point. + stdx::this_thread::sleep_for(stdx::chrono::milliseconds(10)); + ASSERT(task.future().valid()); + ASSERT(stdx::future_status::timeout == task.future().wait_for(stdx::chrono::microseconds(1))) + << "prepareCommit finished before expected"; + + // If we abort the third batch, it should abort the second one too, as it isn't prepared. + // However, since the first batch is prepared, we can't abort it or clean up the bucket. We can + // then finish the first batch, which will allow the second batch to proceed. It should + // recognize it has been aborted and clean up the bucket. 
+ _bucketCatalog->abort(batch3, Status{ErrorCodes::TimeseriesBucketCleared, "cleared"}); + _bucketCatalog->finish(batch1, {}); + task.future().wait(); + ASSERT(batch2->finished()); + + // Make sure a new batch ends up in a new bucket. + auto batch4 = _bucketCatalog + ->insert(_opCtx, + _ns1, + _getCollator(_ns1), + _getTimeseriesOptions(_ns1), + BSON(_timeField << Date_t::now()), + BucketCatalog::CombineWithInsertsFromOtherClients::kDisallow) + .getValue() + .batch; + ASSERT_NE(batch2->bucket().id, batch4->bucket().id); +} + TEST_F(BucketCatalogTest, DuplicateNewFieldNamesAcrossConcurrentBatches) { auto batch1 = _bucketCatalog ->insert(_opCtx, diff --git a/src/mongo/db/transaction/transaction_api.cpp b/src/mongo/db/transaction/transaction_api.cpp index 36635d16694..1150c5b5750 100644 --- a/src/mongo/db/transaction/transaction_api.cpp +++ b/src/mongo/db/transaction/transaction_api.cpp @@ -300,17 +300,32 @@ SemiFuture<BSONObj> SEPTransactionClient::runCommand(StringData dbName, BSONObj BSONObjBuilder cmdBuilder(_behaviors->maybeModifyCommand(std::move(cmdObj))); _hooks->runRequestHook(&cmdBuilder); + auto modifiedCmdObj = cmdBuilder.obj(); + bool isAbortTxnCmd = modifiedCmdObj.firstElementFieldNameStringData() == "abortTransaction"; auto client = _serviceContext->makeClient("SEP-internal-txn-client"); AlternativeClientRegion clientRegion(client); + // Note that _token is only cancelled once the caller of the transaction no longer cares about // its result, so CancelableOperationContexts only being interrupted by ErrorCodes::Interrupted - // shouldn't impact any upstream retry logic. - CancelableOperationContextFactory opCtxFactory(_token, _executor); + // shouldn't impact any upstream retry logic. If a _bestEffortAbort() is invoked, a new + // cancelation token must be used in constructing the opCtx for running abortTransaction. This + // is because an operation that has already been interrupted would cancel the parent cancelation + // token and using that same token to send abortTransaction would fail to send abortTransaction, + // leaving the transaction open longer than necessary. + auto opCtxFactory = isAbortTxnCmd + ? CancelableOperationContextFactory(CancellationToken::uncancelable(), _executor) + : CancelableOperationContextFactory(_token, _executor); + auto cancellableOpCtx = opCtxFactory.makeOperationContext(&cc()); + + // abortTransaction should still be interruptible on stepdown/shutdown. 
+ if (isAbortTxnCmd) { + cancellableOpCtx->setAlwaysInterruptAtStepDownOrUp_UNSAFE(); + } primeInternalClient(&cc()); - auto opMsgRequest = OpMsgRequest::fromDBAndBody(dbName, cmdBuilder.obj()); + auto opMsgRequest = OpMsgRequest::fromDBAndBody(dbName, modifiedCmdObj); auto requestMessage = opMsgRequest.serialize(); return _behaviors->handleRequest(cancellableOpCtx.get(), requestMessage) .then([this](DbResponse dbResponse) { diff --git a/src/mongo/db/transaction/transaction_metrics_observer.cpp b/src/mongo/db/transaction/transaction_metrics_observer.cpp index 0a503e3b135..7a6544a1b3b 100644 --- a/src/mongo/db/transaction/transaction_metrics_observer.cpp +++ b/src/mongo/db/transaction/transaction_metrics_observer.cpp @@ -188,7 +188,7 @@ void TransactionMetricsObserver::onTransactionOperation(OperationContext* opCtx, if (storageStats) { CurOp::get(opCtx)->debug().storageStats = storageStats; if (!_singleTransactionStats.getOpDebug()->storageStats) { - _singleTransactionStats.getOpDebug()->storageStats = storageStats->getCopy(); + _singleTransactionStats.getOpDebug()->storageStats = storageStats->clone(); } else { *_singleTransactionStats.getOpDebug()->storageStats += *storageStats; } diff --git a/src/mongo/db/ttl.cpp b/src/mongo/db/ttl.cpp index 12c02d2c090..f56b3985261 100644 --- a/src/mongo/db/ttl.cpp +++ b/src/mongo/db/ttl.cpp @@ -320,7 +320,7 @@ void TTLMonitor::shutdown() { void TTLMonitor::_doTTLPass() { const ServiceContext::UniqueOperationContext opCtxPtr = cc().makeOperationContext(); OperationContext* opCtx = opCtxPtr.get(); - SetTicketAquisitionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow); + SetAdmissionPriorityForLock priority(opCtx, AdmissionContext::Priority::kLow); hangTTLMonitorBetweenPasses.pauseWhileSet(opCtx); |