diff options
author | Eric Cox <eric.cox@mongodb.com> | 2020-10-02 18:59:55 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2020-12-01 21:03:49 +0000 |
commit | 664e5759c78489a88e66e587c34e5f7127c7eebf (patch) | |
tree | 898dde369a563772009f4030d730e0ae18510bec /src/mongo/db/query | |
parent | a71a2a8bfecf7de0807a28e3eabf9412dddd4258 (diff) | |
download | mongo-664e5759c78489a88e66e587c34e5f7127c7eebf.tar.gz |
SERVER-50712 Handle shard filtering in SBE
Diffstat (limited to 'src/mongo/db/query')
-rw-r--r-- | src/mongo/db/query/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_shard_filter_test.cpp | 221 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder.cpp | 135 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder.h | 11 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_helpers.cpp | 53 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_helpers.h | 30 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_test.cpp | 22 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_test_fixture.cpp | 11 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_test_fixture.h | 20 | ||||
-rw-r--r-- | src/mongo/db/query/shard_filterer_factory_impl.cpp | 44 | ||||
-rw-r--r-- | src/mongo/db/query/shard_filterer_factory_impl.h | 49 | ||||
-rw-r--r-- | src/mongo/db/query/shard_filterer_factory_interface.h | 46 | ||||
-rw-r--r-- | src/mongo/db/query/shard_filterer_factory_mock.cpp | 43 | ||||
-rw-r--r-- | src/mongo/db/query/shard_filterer_factory_mock.h | 52 | ||||
-rw-r--r-- | src/mongo/db/query/stage_builder_util.cpp | 12 |
15 files changed, 719 insertions, 32 deletions
diff --git a/src/mongo/db/query/SConscript b/src/mongo/db/query/SConscript index b4741d22748..54f0300bf7c 100644 --- a/src/mongo/db/query/SConscript +++ b/src/mongo/db/query/SConscript @@ -339,6 +339,8 @@ env.CppUnitTest( "query_solution_test.cpp", "sbe_stage_builder_test_fixture.cpp", "sbe_stage_builder_test.cpp", + "sbe_shard_filter_test.cpp", + "shard_filterer_factory_mock.cpp", "view_response_formatter_test.cpp", ], LIBDEPS=[ diff --git a/src/mongo/db/query/sbe_shard_filter_test.cpp b/src/mongo/db/query/sbe_shard_filter_test.cpp new file mode 100644 index 00000000000..5f92418e9d4 --- /dev/null +++ b/src/mongo/db/query/sbe_shard_filter_test.cpp @@ -0,0 +1,221 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/exec/shard_filterer_mock.h" +#include "mongo/db/query/query_solution.h" +#include "mongo/db/query/sbe_stage_builder_test_fixture.h" +#include "mongo/db/query/shard_filterer_factory_mock.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { + +const NamespaceString kTestNss("TestDB", "TestColl"); + +class SbeShardFilterTest : public SbeStageBuilderTestFixture { +protected: + /** + * Makes a ShardFiltererFactoryInterface that produces a mock ShardFilterer that always passes. + */ + std::unique_ptr<ShardFiltererFactoryInterface> makeAlwaysPassShardFiltererFactory( + const BSONObj& shardKeyPattern) { + return std::make_unique<ShardFiltererFactoryMock>( + std::make_unique<ConstantFilterMock>(true, shardKeyPattern)); + } + + /** + * Makes a ShardFiltererFactoryInterface that produces a mock ShardFilterer that always fails. + */ + std::unique_ptr<ShardFiltererFactoryInterface> makeAlwaysFailShardFiltererFactory( + const BSONObj& shardKeyPattern) { + return std::make_unique<ShardFiltererFactoryMock>( + std::make_unique<ConstantFilterMock>(false, shardKeyPattern)); + } + + /** + * Makes a ShardFiltererFactoryInterface that produces a mock ShardFilterer to filter out docs + * containing all null values along the shard key. + */ + std::unique_ptr<ShardFiltererFactoryInterface> makeAllNullShardKeyFiltererFactory( + const BSONObj& shardKeyPattern) { + return std::make_unique<ShardFiltererFactoryMock>( + std::make_unique<AllNullShardKeyFilterMock>(shardKeyPattern)); + } + + + /** + * Makes a new QuerySolutionNode consisting of a ShardingFilterNode and a child VirtualScanNode. + */ + std::unique_ptr<QuerySolutionNode> makeFilterVirtualScanTree(std::vector<BSONArray> docs) { + auto virtScan = std::make_unique<VirtualScanNode>(docs, false); + auto shardFilter = std::make_unique<ShardingFilterNode>(); + shardFilter->children.push_back(virtScan.release()); + return std::move(shardFilter); + } + + /** + * Runs a unit test with a given shard filterer factory and asserts that the results match the + * expected docs. + */ + void runTest(std::vector<BSONArray> docs, + const BSONArray& expected, + std::unique_ptr<ShardFiltererFactoryInterface> shardFiltererFactory) { + // Construct a QuerySolutionNode consisting of a ShardingFilterNode with a single child + // VirtualScanNode. + auto shardFilter = makeFilterVirtualScanTree(docs); + auto querySolution = makeQuerySolution(std::move(shardFilter)); + + // Translate the QuerySolution to an sbe::PlanStage. + auto [resultSlots, stage, data] = + buildPlanStage(std::move(querySolution), false, std::move(shardFiltererFactory)); + + // Prepare the sbe::PlanStage for execution and collect all results. + auto resultAccessors = prepareTree(&data.ctx, stage.get(), resultSlots); + auto [resultsTag, resultsVal] = getAllResults(stage.get(), resultAccessors[0]); + sbe::value::ValueGuard resultGuard{resultsTag, resultsVal}; + + // Convert the expected results to an sbe value and assert results. + auto [expectedTag, expectedVal] = stage_builder::makeValue(expected); + sbe::value::ValueGuard expectedGuard{expectedTag, expectedVal}; + ASSERT_TRUE(valueEquals(resultsTag, resultsVal, expectedTag, expectedVal)); + } +}; + +TEST_F(SbeShardFilterTest, AlwaysPassFilter) { + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1 << "b" << 2)), + BSON_ARRAY(BSON("a" << 2 << "b" << 2)), + BSON_ARRAY(BSON("a" << 3 << "b" << 2))}; + auto expected = BSON_ARRAY(BSON("a" << 1 << "b" << 2) + << BSON("a" << 2 << "b" << 2) << BSON("a" << 3 << "b" << 2)); + runTest(docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a" << 1))); +} + +TEST_F(SbeShardFilterTest, AlwaysFailFilter) { + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1 << "b" << 2)), + BSON_ARRAY(BSON("a" << 2 << "b" << 2)), + BSON_ARRAY(BSON("a" << 3 << "b" << 2))}; + auto expected = BSONArray(); + runTest(docs, expected, makeAlwaysFailShardFiltererFactory(BSON("a" << 1))); +} + +TEST_F(SbeShardFilterTest, ArrayAlongLeafShardKeyGetsFiltered) { + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1 << "b" << 2)), + BSON_ARRAY(BSON("a" << 2 << "b" << 2)), + BSON_ARRAY(BSON("a" << 3 << "b" << BSON_ARRAY(1 << 2)))}; + + auto expected = BSON_ARRAY(BSON("a" << 1 << "b" << 2) << BSON("a" << 2 << "b" << 2)); + runTest(docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a" << 1 << "b" << 1))); +} + +TEST_F(SbeShardFilterTest, TopLevelArrayShardKeyGetsFiltered) { + auto docs = std::vector<BSONArray>{ + BSON_ARRAY(BSON("a" << BSON("b" << 1))), + BSON_ARRAY(BSON("a" << BSON("b" << 2))), + BSON_ARRAY(BSON("a" << BSON_ARRAY(BSON("b" << 1) << BSON("b" << 2))))}; + + auto expected = BSON_ARRAY(BSON("a" << BSON("b" << 1)) << BSON("a" << BSON("b" << 2))); + runTest(docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a.b" << 1))); +} + +TEST_F(SbeShardFilterTest, ArrayAlongBiggerShardKeyGetsFiltered) { + auto docs = std::vector<BSONArray>{ + BSON_ARRAY(BSON("a" << 1 << "b" << 2 << "c" << 3 << "d" << 4)), + BSON_ARRAY(BSON("a" << 2 << "b" << 2 << "c" << 3 << "d" << 4)), + BSON_ARRAY(BSON("a" << BSON_ARRAY(1 << 2) << "b" << 2 << "c" << 3 << "d" << 4)), + BSON_ARRAY(BSON("a" << 3 << "b" << 2 << "c" << 3 << "d" << 4))}; + + auto expected = BSON_ARRAY(BSON("a" << 1 << "b" << 2 << "c" << 3 << "d" << 4) + << BSON("a" << 2 << "b" << 2 << "c" << 3 << "d" << 4) + << BSON("a" << 3 << "b" << 2 << "c" << 3 << "d" << 4)); + runTest( + docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a" << 1 << "b" << 1 << "c" << 1))); +} + +TEST_F(SbeShardFilterTest, ArrayInDottedPathKeyGetsFiltered) { + auto docs = + std::vector<BSONArray>{BSON_ARRAY(BSON("a" << BSON("b" << 1) << "c" << 2)), + BSON_ARRAY(BSON("a" << BSON("b" << 2) << "c" << 2)), + BSON_ARRAY(BSON("a" << BSON("b" << BSON_ARRAY(1 << 2)) << "c" << 2)), + BSON_ARRAY(BSON("a" << BSON("b" << 3) << "c" << 2))}; + + auto expected = BSON_ARRAY(BSON("a" << BSON("b" << 1) << "c" << 2) + << BSON("a" << BSON("b" << 2) << "c" << 2) + << BSON("a" << BSON("b" << 3) << "c" << 2)); + runTest(docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a.b" << 1))); +} + +TEST_F(SbeShardFilterTest, ArrayAlongDeepDottedPathGetsFiltered) { + auto docs = std::vector<BSONArray>{ + BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << BSON("e" << BSON("f" << 1))))))), + BSON_ARRAY(BSON( + "a" << BSON( + "b" << BSON("c" << BSON("d" << BSON("e" << BSON("f" << BSON_ARRAY(1 << 2))))))))}; + auto expected = + BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << BSON("e" << BSON("f" << 1))))))); + runTest(docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a.b.c.d.e.f" << 1))); +} + +TEST_F(SbeShardFilterTest, MissingFieldsAreFilledCorrectly) { + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1 << "b" << 1 << "c" << 2)), + BSON_ARRAY(BSON("a" << 2 << "b" << 2 << "c" << 2)), + BSON_ARRAY(BSON("c" << 2))}; + + auto expected = BSON_ARRAY(BSON("a" << 1 << "b" << 1 << "c" << 2) + << BSON("a" << 2 << "b" << 2 << "c" << 2)); + runTest(docs, expected, makeAllNullShardKeyFiltererFactory(BSON("a" << 1 << "b" << 1))); +} + +TEST_F(SbeShardFilterTest, MissingFieldsDottedPathFilledCorrectly) { + auto docs = + std::vector<BSONArray>{BSON_ARRAY(BSON("a" << BSON("b" << 1))), + BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1)))))}; + + auto expected = BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1))))); + runTest(docs, expected, makeAllNullShardKeyFiltererFactory(BSON("a.b.c.d" << 1))); +} + +TEST_F(SbeShardFilterTest, MissingFieldsAtTopDottedPathFilledCorrectly) { + auto docs = + std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1)), + BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1)))))}; + + auto expected = BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1))))); + runTest(docs, expected, makeAllNullShardKeyFiltererFactory(BSON("a.b.c.d" << 1))); +} + +TEST_F(SbeShardFilterTest, MissingFieldsAtBottomDottedPathFilledCorrectly) { + auto docs = + std::vector<BSONArray>{BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << 1)))), + BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1)))))}; + + auto expected = BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1))))); + runTest(docs, expected, makeAllNullShardKeyFiltererFactory(BSON("a.b.c.d" << 1))); +} +} // namespace mongo diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp index 71adc9906b1..320e8289c1d 100644 --- a/src/mongo/db/query/sbe_stage_builder.cpp +++ b/src/mongo/db/query/sbe_stage_builder.cpp @@ -46,6 +46,7 @@ #include "mongo/db/exec/sbe/stages/traverse.h" #include "mongo/db/exec/sbe/stages/union.h" #include "mongo/db/exec/sbe/stages/unique.h" +#include "mongo/db/exec/shard_filterer.h" #include "mongo/db/fts/fts_index_format.h" #include "mongo/db/fts/fts_query_impl.h" #include "mongo/db/fts/fts_spec.h" @@ -56,6 +57,7 @@ #include "mongo/db/query/sbe_stage_builder_index_scan.h" #include "mongo/db/query/sbe_stage_builder_projection.h" #include "mongo/db/query/util/make_data_structure.h" +#include "mongo/db/s/collection_sharding_state.h" namespace mongo::stage_builder { std::unique_ptr<sbe::RuntimeEnvironment> makeRuntimeEnvironment( @@ -119,10 +121,12 @@ SlotBasedStageBuilder::SlotBasedStageBuilder(OperationContext* opCtx, const CanonicalQuery& cq, const QuerySolution& solution, PlanYieldPolicySBE* yieldPolicy, - bool needsTrialRunProgressTracker) + bool needsTrialRunProgressTracker, + ShardFiltererFactoryInterface* shardFiltererFactory) : StageBuilder(opCtx, collection, cq, solution), _yieldPolicy(yieldPolicy), - _data(makeRuntimeEnvironment(_opCtx, &_slotIdGenerator)) { + _data(makeRuntimeEnvironment(_opCtx, &_slotIdGenerator)), + _shardFiltererFactory(shardFiltererFactory) { if (needsTrialRunProgressTracker) { const auto maxNumResults{trial_period::getTrialPeriodNumToReturn(_cq)}; @@ -140,6 +144,11 @@ SlotBasedStageBuilder::SlotBasedStageBuilder(OperationContext* opCtx, _data.shouldTrackResumeToken = csn->requestResumeToken; _data.shouldUseTailableScan = csn->tailable; } + + if (auto node = getNodeByType(solution.root(), STAGE_VIRTUAL_SCAN)) { + auto vsn = static_cast<const VirtualScanNode*>(node); + _shouldProduceRecordIdSlot = vsn->hasRecordId; + } } std::unique_ptr<sbe::PlanStage> SlotBasedStageBuilder::build(const QuerySolutionNode* root) { @@ -147,21 +156,23 @@ std::unique_ptr<sbe::PlanStage> SlotBasedStageBuilder::build(const QuerySolution invariant(!_buildHasStarted); _buildHasStarted = true; - // We always produce a 'resultSlot' and a 'recordIdSlot'. If the solution contains a - // CollectionScanNode with the 'shouldTrackLatestOplogTimestamp' flag set to true, then we - // will also produce an 'oplogTsSlot'. + // We always produce a 'resultSlot' and conditionally produce a 'recordIdSlot' based on the + // 'shouldProduceRecordIdSlot'. If the solution contains a CollectionScanNode with the + // 'shouldTrackLatestOplogTimestamp' flag set to true, then we will also produce an + // 'oplogTsSlot'. PlanStageReqs reqs; reqs.set(kResult); - reqs.set(kRecordId); + reqs.setIf(kRecordId, _shouldProduceRecordIdSlot); reqs.setIf(kOplogTs, _data.shouldTrackLatestOplogTimestamp); // Build the SBE plan stage tree. auto [stage, outputs] = build(root, reqs); - // Assert that we produced a 'resultSlot' and a 'recordIdSlot'. Also assert that we produced - // an 'oplogTsSlot' if it's needed. + // Assert that we produced a 'resultSlot' and that we prouced a 'recordIdSlot' if the + // 'shouldProduceRecordIdSlot' flag was set. Also assert that we produced an 'oplogTsSlot' if + // it's needed. invariant(outputs.has(kResult)); - invariant(outputs.has(kRecordId)); + invariant(!_shouldProduceRecordIdSlot || outputs.has(kRecordId)); invariant(!_data.shouldTrackLatestOplogTimestamp || outputs.has(kOplogTs)); _data.outputs = std::move(outputs); @@ -203,6 +214,10 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildVirtualScan( const QuerySolutionNode* root, const PlanStageReqs& reqs) { auto vsn = static_cast<const VirtualScanNode*>(root); + invariant(!reqs.getIndexKeyBitset()); + + // Virtual scans cannot produce an oplogTsSlot, so assert that the caller doesn't need it. + invariant(!reqs.has(kOplogTs)); auto [inputTag, inputVal] = sbe::value::makeNewArray(); sbe::value::ValueGuard inputGuard{inputTag, inputVal}; @@ -221,11 +236,12 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder if (vsn->hasRecordId) { invariant(scanSlots.size() == 2); - outputs.set(PlanStageSlots::kRecordId, scanSlots[0]); - outputs.set(PlanStageSlots::kResult, scanSlots[1]); + outputs.set(kRecordId, scanSlots[0]); + outputs.set(kResult, scanSlots[1]); } else { invariant(scanSlots.size() == 1); - outputs.set(PlanStageSlots::kResult, scanSlots[0]); + invariant(!reqs.has(kRecordId)); + outputs.set(kResult, scanSlots[0]); } return {std::move(scanStage), std::move(outputs)}; @@ -999,6 +1015,98 @@ SlotBasedStageBuilder::makeUnionForTailableCollScan(const QuerySolutionNode* roo return {std::move(unionStage), std::move(outputs)}; } +std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildShardFilter( + const QuerySolutionNode* root, const PlanStageReqs& reqs) { + using namespace std::literals; + + const auto filterNode = static_cast<const ShardingFilterNode*>(root); + + uassert(5071201, + "STAGE_SHARD_FILTER is curently only supported in SBE for collection scan plans", + filterNode->children[0]->getType() == StageType::STAGE_COLLSCAN || + filterNode->children[0]->getType() == StageType::STAGE_VIRTUAL_SCAN); + + auto childReqs = reqs.copy().set(kResult); + auto [stage, outputs] = build(filterNode->children[0], childReqs); + + // If we're sharded make sure that we don't return data that isn't owned by the shard. This + // situation can occur when pending documents from in-progress migrations are inserted and when + // there are orphaned documents from aborted migrations. To check if the document is owned by + // the shard, we need to own a 'ShardFilterer', and extract the document's shard key as a + // BSONObj. + auto shardFilterer = _shardFiltererFactory->makeShardFilterer(_opCtx); + + // Build an expression to extract the shard key from the document based on the shard key + // pattern. To do this, we iterate over the shard key pattern parts and build nested 'getField' + // expressions. This will handle single-element paths, and dotted paths for each shard key part. + sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projections; + sbe::value::SlotVector fieldSlots; + std::vector<std::string> projectFields; + std::unique_ptr<sbe::EExpression> bindShardKeyPart; + + BSONObjIterator keyPatternIter(shardFilterer->getKeyPattern().toBSON()); + while (auto keyPatternElem = keyPatternIter.next()) { + auto fieldRef = FieldRef{keyPatternElem.fieldNameStringData()}; + fieldSlots.push_back(_slotIdGenerator.generate()); + projectFields.push_back(fieldRef.dottedField().toString()); + + auto currentFieldSlot = sbe::makeE<sbe::EVariable>(outputs.get(kResult)); + auto shardKeyBinding = + generateShardKeyBinding(fieldRef, _frameIdGenerator, std::move(currentFieldSlot), 0); + + projections.emplace(fieldSlots.back(), std::move(shardKeyBinding)); + } + + auto shardKeySlot{_slotIdGenerator.generate()}; + + // Build an object which will hold a flattened shard key from the projections above. + auto shardKeyObjStage = sbe::makeS<sbe::MakeObjStage>( + sbe::makeS<sbe::ProjectStage>(std::move(stage), std::move(projections), root->nodeId()), + shardKeySlot, + boost::none, + std::vector<std::string>{}, + projectFields, + fieldSlots, + true, + false, + root->nodeId()); + + // Build a project stage that checks if any of the fieldSlots for the shard key parts are an + // Array which is represented by Nothing. + invariant(fieldSlots.size() > 0); + auto arrayChecks = makeNot(sbe::makeE<sbe::EFunction>( + "exists", sbe::makeEs(sbe::makeE<sbe::EVariable>(fieldSlots[0])))); + for (size_t ind = 1; ind < fieldSlots.size(); ++ind) { + arrayChecks = sbe::makeE<sbe::EPrimBinary>( + sbe::EPrimBinary::Op::logicOr, + std::move(arrayChecks), + makeNot(sbe::makeE<sbe::EFunction>( + "exists", sbe::makeEs(sbe::makeE<sbe::EVariable>(fieldSlots[ind]))))); + } + arrayChecks = sbe::makeE<sbe::EIf>(std::move(arrayChecks), + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Nothing, 0), + sbe::makeE<sbe::EVariable>(shardKeySlot)); + + auto finalShardKeySlot{_slotIdGenerator.generate()}; + + auto finalShardKeyObjStage = makeProjectStage( + std::move(shardKeyObjStage), root->nodeId(), finalShardKeySlot, std::move(arrayChecks)); + + // Build a 'FilterStage' to skip over documents that don't belong to the shard. Shard membership + // of the document is checked by invoking 'shardFilter' with the owned 'ShardFilterer' along + // with the shard key that sits in the 'finalShardKeySlot' of 'MakeObjStage'. + auto shardFilterFn = sbe::makeE<sbe::EFunction>( + "shardFilter"sv, + sbe::makeEs(sbe::makeE<sbe::EConstant>( + sbe::value::TypeTags::shardFilterer, + sbe::value::bitcastFrom<ShardFilterer*>(shardFilterer.release())), + sbe::makeE<sbe::EVariable>(finalShardKeySlot))); + + return {sbe::makeS<sbe::FilterStage<false>>( + std::move(finalShardKeyObjStage), std::move(shardFilterFn), root->nodeId()), + std::move(outputs)}; +} + // Returns a non-null pointer to the root of a plan tree, or a non-OK status if the PlanStage tree // could not be constructed. std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::build( @@ -1024,7 +1132,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder {STAGE_TEXT, &SlotBasedStageBuilder::buildText}, {STAGE_RETURN_KEY, &SlotBasedStageBuilder::buildReturnKey}, {STAGE_EOF, &SlotBasedStageBuilder::buildEof}, - {STAGE_SORT_MERGE, &SlotBasedStageBuilder::buildSortMerge}}; + {STAGE_SORT_MERGE, &SlotBasedStageBuilder::buildSortMerge}, + {STAGE_SHARDING_FILTER, &SlotBasedStageBuilder::buildShardFilter}}; uassert(4822884, str::stream() << "Can't build exec tree for node: " << root->toString(), diff --git a/src/mongo/db/query/sbe_stage_builder.h b/src/mongo/db/query/sbe_stage_builder.h index d4b618dba34..ee13f494280 100644 --- a/src/mongo/db/query/sbe_stage_builder.h +++ b/src/mongo/db/query/sbe_stage_builder.h @@ -35,6 +35,7 @@ #include "mongo/db/exec/trial_period_utils.h" #include "mongo/db/exec/trial_run_progress_tracker.h" #include "mongo/db/query/plan_yield_policy_sbe.h" +#include "mongo/db/query/shard_filterer_factory_interface.h" #include "mongo/db/query/stage_builder.h" namespace mongo::stage_builder { @@ -255,7 +256,8 @@ public: const CanonicalQuery& cq, const QuerySolution& solution, PlanYieldPolicySBE* yieldPolicy, - bool needsTrialRunProgressTracker); + bool needsTrialRunProgressTracker, + ShardFiltererFactoryInterface* shardFilterer); std::unique_ptr<sbe::PlanStage> build(const QuerySolutionNode* root) final; @@ -324,6 +326,9 @@ private: std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> makeUnionForTailableCollScan( const QuerySolutionNode* root, const PlanStageReqs& reqs); + std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildShardFilter( + const QuerySolutionNode* root, const PlanStageReqs& reqs); + sbe::value::SlotIdGenerator _slotIdGenerator; sbe::value::FrameIdGenerator _frameIdGenerator; sbe::value::SpoolIdGenerator _spoolIdGenerator; @@ -335,5 +340,9 @@ private: PlanStageData _data; bool _buildHasStarted{false}; + bool _shouldProduceRecordIdSlot{true}; + + // A factory to construct shard filters. + ShardFiltererFactoryInterface* _shardFiltererFactory; }; } // namespace mongo::stage_builder diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp index 7e7c45ded2a..acac413ebd9 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp +++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp @@ -151,6 +151,59 @@ std::unique_ptr<sbe::EExpression> makeFillEmptyFalse(std::unique_ptr<sbe::EExpre sbe::value::bitcastFrom<bool>(false)))); } +std::unique_ptr<sbe::EExpression> makeFillEmptyNull(std::unique_ptr<sbe::EExpression> e) { + using namespace std::literals; + return sbe::makeE<sbe::EFunction>( + "fillEmpty"sv, + sbe::makeEs(std::move(e), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0))); +} + +std::unique_ptr<sbe::EExpression> makeNothingArrayCheck( + std::unique_ptr<sbe::EExpression> isArrayInput, std::unique_ptr<sbe::EExpression> otherwise) { + using namespace std::literals; + return sbe::makeE<sbe::EIf>( + sbe::makeE<sbe::EFunction>("isArray"sv, sbe::makeEs(std::move(isArrayInput))), + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Nothing, 0), + std::move(otherwise)); +} + +std::unique_ptr<sbe::EExpression> generateShardKeyBinding( + const FieldRef& keyPatternField, + sbe::value::FrameIdGenerator& frameIdGenerator, + std::unique_ptr<sbe::EExpression> inputExpr, + int level) { + using namespace std::literals; + invariant(level >= 0); + + auto makeGetFieldKeyPattern = [&](std::unique_ptr<sbe::EExpression> slot) { + return makeFillEmptyNull(sbe::makeE<sbe::EFunction>( + "getField"sv, + sbe::makeEs(std::move(slot), sbe::makeE<sbe::EConstant>([&]() { + const auto fieldName = keyPatternField[level]; + return std::string_view{fieldName.rawData(), fieldName.size()}; + }())))); + }; + + if (level == keyPatternField.numParts() - 1) { + auto frameId = frameIdGenerator.generate(); + auto bindSlot = sbe::makeE<sbe::EVariable>(frameId, 0); + return sbe::makeE<sbe::ELocalBind>( + frameId, + sbe::makeEs(makeGetFieldKeyPattern(std::move(inputExpr))), + makeNothingArrayCheck(bindSlot->clone(), bindSlot->clone())); + } + + auto frameId = frameIdGenerator.generate(); + auto nextSlot = sbe::makeE<sbe::EVariable>(frameId, 0); + auto shardKeyBinding = + generateShardKeyBinding(keyPatternField, frameIdGenerator, nextSlot->clone(), level + 1); + + return sbe::makeE<sbe::ELocalBind>( + frameId, + sbe::makeEs(makeGetFieldKeyPattern(inputExpr->clone())), + makeNothingArrayCheck(nextSlot->clone(), std::move(shardKeyBinding))); +} + EvalStage makeLimitCoScanStage(PlanNodeId planNodeId, long long limit) { return {makeLimitCoScanTree(planNodeId, limit), sbe::makeSV()}; } diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h index 454665e65a7..91e201f5bc7 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.h +++ b/src/mongo/db/query/sbe_stage_builder_helpers.h @@ -165,6 +165,36 @@ inline std::unique_ptr<sbe::EExpression> makeFunction(std::string_view name, Arg } /** + * Check if expression returns Nothing and return null if so. Otherwise, return the + * expression. + */ +std::unique_ptr<sbe::EExpression> makeFillEmptyNull(std::unique_ptr<sbe::EExpression> e); + +/** + * Check if expression returns an array and return Nothing if so. Otherwise, return the expression. + */ +std::unique_ptr<sbe::EExpression> makeNothingArrayCheck( + std::unique_ptr<sbe::EExpression> isArrayInput, std::unique_ptr<sbe::EExpression> otherwise); + +/** + * Creates an expression to extract a shard key part from inputExpr. The generated expression is a + * let binding that binds a getField expression to extract the shard key part value from the + * inputExpr. The entire let binding evaluates to a constant expression carrying the Nothing value + * if the binding is an array. Otherwise, it evaluates to a fillEmpty null expression. Here is an + * example expression generated from this function for a shard key pattern {'a.b': 1}: + * + * let [l1.0 = getField (s1, "a")] + * if (isArray (l1.0), NOTHING, + * let [l2.0 = getField (l1.0, "b")] + * if (isArray (l2.0), NOTHING, fillEmpty (l2.0, null))) + */ +std::unique_ptr<sbe::EExpression> generateShardKeyBinding( + const FieldRef& keyPatternField, + sbe::value::FrameIdGenerator& frameIdGenerator, + std::unique_ptr<sbe::EExpression> inputExpr, + int level); + +/** * If given 'EvalExpr' already contains a slot, simply returns it. Otherwise, allocates a new slot * and creates project stage to assign expression to this new slot. After that, new slot and project * stage are returned. diff --git a/src/mongo/db/query/sbe_stage_builder_test.cpp b/src/mongo/db/query/sbe_stage_builder_test.cpp index c830f33bdcb..a35066c6a28 100644 --- a/src/mongo/db/query/sbe_stage_builder_test.cpp +++ b/src/mongo/db/query/sbe_stage_builder_test.cpp @@ -29,15 +29,23 @@ #include "mongo/platform/basic.h" +#include "mongo/db/exec/shard_filterer_mock.h" #include "mongo/db/query/query_solution.h" #include "mongo/db/query/sbe_stage_builder_test_fixture.h" +#include "mongo/db/query/shard_filterer_factory_mock.h" #include "mongo/unittest/unittest.h" namespace mongo { -using SBEStageBuilderTest = SBEStageBuilderTestFixture; +class SbeStageBuilderTest : public SbeStageBuilderTestFixture { +protected: + std::unique_ptr<ShardFiltererFactoryInterface> makeAlwaysPassShardFiltererInterface() { + return std::make_unique<ShardFiltererFactoryMock>( + std::make_unique<ConstantFilterMock>(true, BSONObj{})); + } +}; -TEST_F(SBEStageBuilderTest, TestVirtualScan) { +TEST_F(SbeStageBuilderTest, TestVirtualScan) { auto docs = std::vector<BSONArray>{BSON_ARRAY(int64_t{0} << BSON("a" << 1 << "b" << 2)), BSON_ARRAY(int64_t{1} << BSON("a" << 2 << "b" << 2)), BSON_ARRAY(int64_t{2} << BSON("a" << 3 << "b" << 2))}; @@ -51,7 +59,9 @@ TEST_F(SBEStageBuilderTest, TestVirtualScan) { ASSERT_EQ(querySolution->root()->nodeId(), 1); // Translate the QuerySolution tree to an sbe::PlanStage. - auto [resultSlots, stage, data] = buildPlanStage(std::move(querySolution), true); + auto shardFiltererInterface = makeAlwaysPassShardFiltererInterface(); + auto [resultSlots, stage, data] = + buildPlanStage(std::move(querySolution), true, std::move(shardFiltererInterface)); auto resultAccessors = prepareTree(&data.ctx, stage.get(), resultSlots); int64_t index = 0; @@ -70,7 +80,7 @@ TEST_F(SBEStageBuilderTest, TestVirtualScan) { ASSERT_EQ(index, 3); } -TEST_F(SBEStageBuilderTest, TestLimitOneVirtualScan) { +TEST_F(SbeStageBuilderTest, TestLimitOneVirtualScan) { auto docs = std::vector<BSONArray>{BSON_ARRAY(int64_t{0} << BSON("a" << 1 << "b" << 2)), BSON_ARRAY(int64_t{1} << BSON("a" << 2 << "b" << 2)), BSON_ARRAY(int64_t{2} << BSON("a" << 3 << "b" << 2))}; @@ -86,7 +96,9 @@ TEST_F(SBEStageBuilderTest, TestLimitOneVirtualScan) { auto querySolution = makeQuerySolution(std::move(limitNode)); // Translate the QuerySolution tree to an sbe::PlanStage. - auto [resultSlots, stage, data] = buildPlanStage(std::move(querySolution), true); + auto shardFiltererInterface = makeAlwaysPassShardFiltererInterface(); + auto [resultSlots, stage, data] = + buildPlanStage(std::move(querySolution), true, std::move(shardFiltererInterface)); // Prepare the sbe::PlanStage for execution. auto resultAccessors = prepareTree(&data.ctx, stage.get(), resultSlots); diff --git a/src/mongo/db/query/sbe_stage_builder_test_fixture.cpp b/src/mongo/db/query/sbe_stage_builder_test_fixture.cpp index 31b8b6005d7..97dc4ffcd79 100644 --- a/src/mongo/db/query/sbe_stage_builder_test_fixture.cpp +++ b/src/mongo/db/query/sbe_stage_builder_test_fixture.cpp @@ -37,7 +37,7 @@ #include "mongo/db/query/sbe_stage_builder_test_fixture.h" namespace mongo { -std::unique_ptr<QuerySolution> SBEStageBuilderTestFixture::makeQuerySolution( +std::unique_ptr<QuerySolution> SbeStageBuilderTestFixture::makeQuerySolution( std::unique_ptr<QuerySolutionNode> root) { auto querySoln = std::make_unique<QuerySolution>(); querySoln->setRoot(std::move(root)); @@ -45,8 +45,10 @@ std::unique_ptr<QuerySolution> SBEStageBuilderTestFixture::makeQuerySolution( } std::tuple<sbe::value::SlotVector, std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData> -SBEStageBuilderTestFixture::buildPlanStage(std::unique_ptr<QuerySolution> querySolution, - bool hasRecordId) { +SbeStageBuilderTestFixture::buildPlanStage( + std::unique_ptr<QuerySolution> querySolution, + bool hasRecordId, + std::unique_ptr<ShardFiltererFactoryInterface> shardFiltererInterface) { auto qr = std::make_unique<QueryRequest>(_nss); const boost::intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest(_nss)); auto statusWithCQ = CanonicalQuery::canonicalize(opCtx(), std::move(qr), expCtx); @@ -57,7 +59,8 @@ SBEStageBuilderTestFixture::buildPlanStage(std::unique_ptr<QuerySolution> queryS *statusWithCQ.getValue(), *querySolution, nullptr /* YieldPolicy */, - false}; + false, + shardFiltererInterface.get()}; auto stage = builder.build(querySolution->root()); auto data = builder.getPlanStageData(); diff --git a/src/mongo/db/query/sbe_stage_builder_test_fixture.h b/src/mongo/db/query/sbe_stage_builder_test_fixture.h index d38382e8c71..840ab113a4b 100644 --- a/src/mongo/db/query/sbe_stage_builder_test_fixture.h +++ b/src/mongo/db/query/sbe_stage_builder_test_fixture.h @@ -34,11 +34,12 @@ #include "mongo/db/exec/sbe/values/slot.h" #include "mongo/db/exec/sbe/values/value.h" #include "mongo/db/query/query_solution.h" +#include "mongo/db/query/shard_filterer_factory_interface.h" #include "mongo/unittest/unittest.h" namespace mongo { /** - * SBEStageBuilderTestFixture is a unittest fixture that can be used to facilitate testing the + * SbeStageBuilderTestFixture is a unittest fixture that can be used to facilitate testing the * translation of a QuerySolution tree to an sbe PlanStage tree. * * The main mechanism that enables the whole sbe::PlanStage tree to be exercised under unittests is @@ -51,9 +52,9 @@ namespace mongo { * prepare the sbe::PlanStage tree. The sbe::PlanStageData returned from buildPlanStage() must be * kept alive across buildPlanStage(), prepareTree() and execution of the plan. */ -class SBEStageBuilderTestFixture : public sbe::PlanStageTestFixture { +class SbeStageBuilderTestFixture : public sbe::PlanStageTestFixture { public: - SBEStageBuilderTestFixture() = default; + SbeStageBuilderTestFixture() = default; /** * Makes a QuerySolution from a QuerySolutionNode. @@ -66,11 +67,16 @@ public: * SlotVector. If hasRecordId is 'true' then the returned SlotVector will carry a SlotId in the * 0th position for the RecordId and a SlotId for the BSONObj representation of the document in * the 1st position. Otherwise, if hasRecordId is 'false', the SlotVector will contain a single - * SlotId for the BSONObj representation of the document. + * SlotId for the BSONObj representation of the document. A real or mock + * ShardFiltererFactoryInterface must be provided so the sbe SlotBasedStageBuilder can build and + * utilize a ShardFilterer instance during translation of a ShardingFilterNode. */ - std:: - tuple<sbe::value::SlotVector, std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData> - buildPlanStage(std::unique_ptr<QuerySolution> querySolution, bool hasRecordId); + std::tuple<sbe::value::SlotVector, + std::unique_ptr<sbe::PlanStage>, + stage_builder::PlanStageData> + buildPlanStage(std::unique_ptr<QuerySolution> querySolution, + bool hasRecordId, + std::unique_ptr<ShardFiltererFactoryInterface> shardFiltererFactoryInterface); private: const NamespaceString _nss = NamespaceString{"testdb.sbe_stage_builder"}; diff --git a/src/mongo/db/query/shard_filterer_factory_impl.cpp b/src/mongo/db/query/shard_filterer_factory_impl.cpp new file mode 100644 index 00000000000..0db775c411e --- /dev/null +++ b/src/mongo/db/query/shard_filterer_factory_impl.cpp @@ -0,0 +1,44 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/exec/shard_filterer_impl.h" +#include "mongo/db/query/shard_filterer_factory_impl.h" +#include "mongo/db/s/collection_sharding_state.h" + +namespace mongo { + +std::unique_ptr<ShardFilterer> ShardFiltererFactoryImpl::makeShardFilterer( + OperationContext* opCtx) const { + auto css = CollectionShardingState::get(opCtx, _collection->ns()); + return std::make_unique<ShardFiltererImpl>(css->getOwnershipFilter( + opCtx, CollectionShardingState::OrphanCleanupPolicy::kDisallowOrphanCleanup)); +} +} // namespace mongo diff --git a/src/mongo/db/query/shard_filterer_factory_impl.h b/src/mongo/db/query/shard_filterer_factory_impl.h new file mode 100644 index 00000000000..957b8154dba --- /dev/null +++ b/src/mongo/db/query/shard_filterer_factory_impl.h @@ -0,0 +1,49 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/catalog/collection.h" +#include "mongo/db/query/shard_filterer_factory_interface.h" + +namespace mongo { + +/** + * An implementation of ShardFiltererFactoryInterface. + */ +class ShardFiltererFactoryImpl : public ShardFiltererFactoryInterface { +public: + ShardFiltererFactoryImpl(const CollectionPtr& collection) : _collection(collection) {} + + std::unique_ptr<ShardFilterer> makeShardFilterer(OperationContext* opCtx) const override; + +private: + const CollectionPtr& _collection; +}; +} // namespace mongo diff --git a/src/mongo/db/query/shard_filterer_factory_interface.h b/src/mongo/db/query/shard_filterer_factory_interface.h new file mode 100644 index 00000000000..acb9e8e81fa --- /dev/null +++ b/src/mongo/db/query/shard_filterer_factory_interface.h @@ -0,0 +1,46 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/exec/shard_filterer.h" +#include "mongo/db/operation_context.h" + +namespace mongo { + +/** + * An interface that can be used to construct a ShardFilterer. + */ +class ShardFiltererFactoryInterface { +public: + virtual ~ShardFiltererFactoryInterface() = default; + + virtual std::unique_ptr<ShardFilterer> makeShardFilterer(OperationContext* opCtx) const = 0; +}; +} // namespace mongo diff --git a/src/mongo/db/query/shard_filterer_factory_mock.cpp b/src/mongo/db/query/shard_filterer_factory_mock.cpp new file mode 100644 index 00000000000..2e28150d1ec --- /dev/null +++ b/src/mongo/db/query/shard_filterer_factory_mock.cpp @@ -0,0 +1,43 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/query/shard_filterer_factory_mock.h" + +namespace mongo { + +ShardFiltererFactoryMock::ShardFiltererFactoryMock(std::unique_ptr<ShardFilterer> shardFilterer) + : _shardFilterer(std::move(shardFilterer)) {} + +std::unique_ptr<ShardFilterer> ShardFiltererFactoryMock::makeShardFilterer( + OperationContext* opCtx) const { + return _shardFilterer->clone(); +} +} // namespace mongo diff --git a/src/mongo/db/query/shard_filterer_factory_mock.h b/src/mongo/db/query/shard_filterer_factory_mock.h new file mode 100644 index 00000000000..356c1cc9774 --- /dev/null +++ b/src/mongo/db/query/shard_filterer_factory_mock.h @@ -0,0 +1,52 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/exec/shard_filterer.h" +#include "mongo/db/query/shard_filterer_factory_interface.h" + +namespace mongo { + +/** + * An implementation of ShardFiltererFactoryInterface for unit testing. + */ +class ShardFiltererFactoryMock : public ShardFiltererFactoryInterface { +public: + ShardFiltererFactoryMock(std::unique_ptr<ShardFilterer> shardFilter); + + /* + * Makes a new mock ShardFilterer. + */ + std::unique_ptr<ShardFilterer> makeShardFilterer(OperationContext* opCtx) const override; + +private: + std::unique_ptr<ShardFilterer> _shardFilterer; +}; +} // namespace mongo diff --git a/src/mongo/db/query/stage_builder_util.cpp b/src/mongo/db/query/stage_builder_util.cpp index 58bda2900ca..f318bb127eb 100644 --- a/src/mongo/db/query/stage_builder_util.cpp +++ b/src/mongo/db/query/stage_builder_util.cpp @@ -34,6 +34,7 @@ #include "mongo/db/query/classic_stage_builder.h" #include "mongo/db/query/plan_yield_policy.h" #include "mongo/db/query/sbe_stage_builder.h" +#include "mongo/db/query/shard_filterer_factory_impl.h" namespace mongo::stage_builder { std::unique_ptr<PlanStage> buildClassicExecutableTree(OperationContext* opCtx, @@ -69,8 +70,15 @@ buildSlotBasedExecutableTree(OperationContext* opCtx, auto sbeYieldPolicy = dynamic_cast<PlanYieldPolicySBE*>(yieldPolicy); invariant(sbeYieldPolicy); - auto builder = std::make_unique<SlotBasedStageBuilder>( - opCtx, collection, cq, solution, sbeYieldPolicy, needsTrialRunProgressTracker); + auto shardFilterer = std::make_unique<ShardFiltererFactoryImpl>(collection); + + auto builder = std::make_unique<SlotBasedStageBuilder>(opCtx, + collection, + cq, + solution, + sbeYieldPolicy, + needsTrialRunProgressTracker, + shardFilterer.get()); auto root = builder->build(solution.root()); auto data = builder->getPlanStageData(); |