summaryrefslogtreecommitdiff
path: root/src/mongo/db/query
diff options
context:
space:
mode:
authorEric Cox <eric.cox@mongodb.com>2020-10-02 18:59:55 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2020-12-01 21:03:49 +0000
commit664e5759c78489a88e66e587c34e5f7127c7eebf (patch)
tree898dde369a563772009f4030d730e0ae18510bec /src/mongo/db/query
parenta71a2a8bfecf7de0807a28e3eabf9412dddd4258 (diff)
downloadmongo-664e5759c78489a88e66e587c34e5f7127c7eebf.tar.gz
SERVER-50712 Handle shard filtering in SBE
Diffstat (limited to 'src/mongo/db/query')
-rw-r--r--src/mongo/db/query/SConscript2
-rw-r--r--src/mongo/db/query/sbe_shard_filter_test.cpp221
-rw-r--r--src/mongo/db/query/sbe_stage_builder.cpp135
-rw-r--r--src/mongo/db/query/sbe_stage_builder.h11
-rw-r--r--src/mongo/db/query/sbe_stage_builder_helpers.cpp53
-rw-r--r--src/mongo/db/query/sbe_stage_builder_helpers.h30
-rw-r--r--src/mongo/db/query/sbe_stage_builder_test.cpp22
-rw-r--r--src/mongo/db/query/sbe_stage_builder_test_fixture.cpp11
-rw-r--r--src/mongo/db/query/sbe_stage_builder_test_fixture.h20
-rw-r--r--src/mongo/db/query/shard_filterer_factory_impl.cpp44
-rw-r--r--src/mongo/db/query/shard_filterer_factory_impl.h49
-rw-r--r--src/mongo/db/query/shard_filterer_factory_interface.h46
-rw-r--r--src/mongo/db/query/shard_filterer_factory_mock.cpp43
-rw-r--r--src/mongo/db/query/shard_filterer_factory_mock.h52
-rw-r--r--src/mongo/db/query/stage_builder_util.cpp12
15 files changed, 719 insertions, 32 deletions
diff --git a/src/mongo/db/query/SConscript b/src/mongo/db/query/SConscript
index b4741d22748..54f0300bf7c 100644
--- a/src/mongo/db/query/SConscript
+++ b/src/mongo/db/query/SConscript
@@ -339,6 +339,8 @@ env.CppUnitTest(
"query_solution_test.cpp",
"sbe_stage_builder_test_fixture.cpp",
"sbe_stage_builder_test.cpp",
+ "sbe_shard_filter_test.cpp",
+ "shard_filterer_factory_mock.cpp",
"view_response_formatter_test.cpp",
],
LIBDEPS=[
diff --git a/src/mongo/db/query/sbe_shard_filter_test.cpp b/src/mongo/db/query/sbe_shard_filter_test.cpp
new file mode 100644
index 00000000000..5f92418e9d4
--- /dev/null
+++ b/src/mongo/db/query/sbe_shard_filter_test.cpp
@@ -0,0 +1,221 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/exec/shard_filterer_mock.h"
+#include "mongo/db/query/query_solution.h"
+#include "mongo/db/query/sbe_stage_builder_test_fixture.h"
+#include "mongo/db/query/shard_filterer_factory_mock.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+
+const NamespaceString kTestNss("TestDB", "TestColl");
+
+class SbeShardFilterTest : public SbeStageBuilderTestFixture {
+protected:
+ /**
+ * Makes a ShardFiltererFactoryInterface that produces a mock ShardFilterer that always passes.
+ */
+ std::unique_ptr<ShardFiltererFactoryInterface> makeAlwaysPassShardFiltererFactory(
+ const BSONObj& shardKeyPattern) {
+ return std::make_unique<ShardFiltererFactoryMock>(
+ std::make_unique<ConstantFilterMock>(true, shardKeyPattern));
+ }
+
+ /**
+ * Makes a ShardFiltererFactoryInterface that produces a mock ShardFilterer that always fails.
+ */
+ std::unique_ptr<ShardFiltererFactoryInterface> makeAlwaysFailShardFiltererFactory(
+ const BSONObj& shardKeyPattern) {
+ return std::make_unique<ShardFiltererFactoryMock>(
+ std::make_unique<ConstantFilterMock>(false, shardKeyPattern));
+ }
+
+ /**
+ * Makes a ShardFiltererFactoryInterface that produces a mock ShardFilterer to filter out docs
+ * containing all null values along the shard key.
+ */
+ std::unique_ptr<ShardFiltererFactoryInterface> makeAllNullShardKeyFiltererFactory(
+ const BSONObj& shardKeyPattern) {
+ return std::make_unique<ShardFiltererFactoryMock>(
+ std::make_unique<AllNullShardKeyFilterMock>(shardKeyPattern));
+ }
+
+
+ /**
+ * Makes a new QuerySolutionNode consisting of a ShardingFilterNode and a child VirtualScanNode.
+ */
+ std::unique_ptr<QuerySolutionNode> makeFilterVirtualScanTree(std::vector<BSONArray> docs) {
+ auto virtScan = std::make_unique<VirtualScanNode>(docs, false);
+ auto shardFilter = std::make_unique<ShardingFilterNode>();
+ shardFilter->children.push_back(virtScan.release());
+ return std::move(shardFilter);
+ }
+
+ /**
+ * Runs a unit test with a given shard filterer factory and asserts that the results match the
+ * expected docs.
+ */
+ void runTest(std::vector<BSONArray> docs,
+ const BSONArray& expected,
+ std::unique_ptr<ShardFiltererFactoryInterface> shardFiltererFactory) {
+ // Construct a QuerySolutionNode consisting of a ShardingFilterNode with a single child
+ // VirtualScanNode.
+ auto shardFilter = makeFilterVirtualScanTree(docs);
+ auto querySolution = makeQuerySolution(std::move(shardFilter));
+
+ // Translate the QuerySolution to an sbe::PlanStage.
+ auto [resultSlots, stage, data] =
+ buildPlanStage(std::move(querySolution), false, std::move(shardFiltererFactory));
+
+ // Prepare the sbe::PlanStage for execution and collect all results.
+ auto resultAccessors = prepareTree(&data.ctx, stage.get(), resultSlots);
+ auto [resultsTag, resultsVal] = getAllResults(stage.get(), resultAccessors[0]);
+ sbe::value::ValueGuard resultGuard{resultsTag, resultsVal};
+
+ // Convert the expected results to an sbe value and assert results.
+ auto [expectedTag, expectedVal] = stage_builder::makeValue(expected);
+ sbe::value::ValueGuard expectedGuard{expectedTag, expectedVal};
+ ASSERT_TRUE(valueEquals(resultsTag, resultsVal, expectedTag, expectedVal));
+ }
+};
+
+TEST_F(SbeShardFilterTest, AlwaysPassFilter) {
+ auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1 << "b" << 2)),
+ BSON_ARRAY(BSON("a" << 2 << "b" << 2)),
+ BSON_ARRAY(BSON("a" << 3 << "b" << 2))};
+ auto expected = BSON_ARRAY(BSON("a" << 1 << "b" << 2)
+ << BSON("a" << 2 << "b" << 2) << BSON("a" << 3 << "b" << 2));
+ runTest(docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a" << 1)));
+}
+
+TEST_F(SbeShardFilterTest, AlwaysFailFilter) {
+ auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1 << "b" << 2)),
+ BSON_ARRAY(BSON("a" << 2 << "b" << 2)),
+ BSON_ARRAY(BSON("a" << 3 << "b" << 2))};
+ auto expected = BSONArray();
+ runTest(docs, expected, makeAlwaysFailShardFiltererFactory(BSON("a" << 1)));
+}
+
+TEST_F(SbeShardFilterTest, ArrayAlongLeafShardKeyGetsFiltered) {
+ auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1 << "b" << 2)),
+ BSON_ARRAY(BSON("a" << 2 << "b" << 2)),
+ BSON_ARRAY(BSON("a" << 3 << "b" << BSON_ARRAY(1 << 2)))};
+
+ auto expected = BSON_ARRAY(BSON("a" << 1 << "b" << 2) << BSON("a" << 2 << "b" << 2));
+ runTest(docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a" << 1 << "b" << 1)));
+}
+
+TEST_F(SbeShardFilterTest, TopLevelArrayShardKeyGetsFiltered) {
+ auto docs = std::vector<BSONArray>{
+ BSON_ARRAY(BSON("a" << BSON("b" << 1))),
+ BSON_ARRAY(BSON("a" << BSON("b" << 2))),
+ BSON_ARRAY(BSON("a" << BSON_ARRAY(BSON("b" << 1) << BSON("b" << 2))))};
+
+ auto expected = BSON_ARRAY(BSON("a" << BSON("b" << 1)) << BSON("a" << BSON("b" << 2)));
+ runTest(docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a.b" << 1)));
+}
+
+TEST_F(SbeShardFilterTest, ArrayAlongBiggerShardKeyGetsFiltered) {
+ auto docs = std::vector<BSONArray>{
+ BSON_ARRAY(BSON("a" << 1 << "b" << 2 << "c" << 3 << "d" << 4)),
+ BSON_ARRAY(BSON("a" << 2 << "b" << 2 << "c" << 3 << "d" << 4)),
+ BSON_ARRAY(BSON("a" << BSON_ARRAY(1 << 2) << "b" << 2 << "c" << 3 << "d" << 4)),
+ BSON_ARRAY(BSON("a" << 3 << "b" << 2 << "c" << 3 << "d" << 4))};
+
+ auto expected = BSON_ARRAY(BSON("a" << 1 << "b" << 2 << "c" << 3 << "d" << 4)
+ << BSON("a" << 2 << "b" << 2 << "c" << 3 << "d" << 4)
+ << BSON("a" << 3 << "b" << 2 << "c" << 3 << "d" << 4));
+ runTest(
+ docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a" << 1 << "b" << 1 << "c" << 1)));
+}
+
+TEST_F(SbeShardFilterTest, ArrayInDottedPathKeyGetsFiltered) {
+ auto docs =
+ std::vector<BSONArray>{BSON_ARRAY(BSON("a" << BSON("b" << 1) << "c" << 2)),
+ BSON_ARRAY(BSON("a" << BSON("b" << 2) << "c" << 2)),
+ BSON_ARRAY(BSON("a" << BSON("b" << BSON_ARRAY(1 << 2)) << "c" << 2)),
+ BSON_ARRAY(BSON("a" << BSON("b" << 3) << "c" << 2))};
+
+ auto expected = BSON_ARRAY(BSON("a" << BSON("b" << 1) << "c" << 2)
+ << BSON("a" << BSON("b" << 2) << "c" << 2)
+ << BSON("a" << BSON("b" << 3) << "c" << 2));
+ runTest(docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a.b" << 1)));
+}
+
+TEST_F(SbeShardFilterTest, ArrayAlongDeepDottedPathGetsFiltered) {
+ auto docs = std::vector<BSONArray>{
+ BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << BSON("e" << BSON("f" << 1))))))),
+ BSON_ARRAY(BSON(
+ "a" << BSON(
+ "b" << BSON("c" << BSON("d" << BSON("e" << BSON("f" << BSON_ARRAY(1 << 2))))))))};
+ auto expected =
+ BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << BSON("e" << BSON("f" << 1)))))));
+ runTest(docs, expected, makeAlwaysPassShardFiltererFactory(BSON("a.b.c.d.e.f" << 1)));
+}
+
+TEST_F(SbeShardFilterTest, MissingFieldsAreFilledCorrectly) {
+ auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1 << "b" << 1 << "c" << 2)),
+ BSON_ARRAY(BSON("a" << 2 << "b" << 2 << "c" << 2)),
+ BSON_ARRAY(BSON("c" << 2))};
+
+ auto expected = BSON_ARRAY(BSON("a" << 1 << "b" << 1 << "c" << 2)
+ << BSON("a" << 2 << "b" << 2 << "c" << 2));
+ runTest(docs, expected, makeAllNullShardKeyFiltererFactory(BSON("a" << 1 << "b" << 1)));
+}
+
+TEST_F(SbeShardFilterTest, MissingFieldsDottedPathFilledCorrectly) {
+ auto docs =
+ std::vector<BSONArray>{BSON_ARRAY(BSON("a" << BSON("b" << 1))),
+ BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1)))))};
+
+ auto expected = BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1)))));
+ runTest(docs, expected, makeAllNullShardKeyFiltererFactory(BSON("a.b.c.d" << 1)));
+}
+
+TEST_F(SbeShardFilterTest, MissingFieldsAtTopDottedPathFilledCorrectly) {
+ auto docs =
+ std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1)),
+ BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1)))))};
+
+ auto expected = BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1)))));
+ runTest(docs, expected, makeAllNullShardKeyFiltererFactory(BSON("a.b.c.d" << 1)));
+}
+
+TEST_F(SbeShardFilterTest, MissingFieldsAtBottomDottedPathFilledCorrectly) {
+ auto docs =
+ std::vector<BSONArray>{BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << 1)))),
+ BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1)))))};
+
+ auto expected = BSON_ARRAY(BSON("a" << BSON("b" << BSON("c" << BSON("d" << 1)))));
+ runTest(docs, expected, makeAllNullShardKeyFiltererFactory(BSON("a.b.c.d" << 1)));
+}
+} // namespace mongo
diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp
index 71adc9906b1..320e8289c1d 100644
--- a/src/mongo/db/query/sbe_stage_builder.cpp
+++ b/src/mongo/db/query/sbe_stage_builder.cpp
@@ -46,6 +46,7 @@
#include "mongo/db/exec/sbe/stages/traverse.h"
#include "mongo/db/exec/sbe/stages/union.h"
#include "mongo/db/exec/sbe/stages/unique.h"
+#include "mongo/db/exec/shard_filterer.h"
#include "mongo/db/fts/fts_index_format.h"
#include "mongo/db/fts/fts_query_impl.h"
#include "mongo/db/fts/fts_spec.h"
@@ -56,6 +57,7 @@
#include "mongo/db/query/sbe_stage_builder_index_scan.h"
#include "mongo/db/query/sbe_stage_builder_projection.h"
#include "mongo/db/query/util/make_data_structure.h"
+#include "mongo/db/s/collection_sharding_state.h"
namespace mongo::stage_builder {
std::unique_ptr<sbe::RuntimeEnvironment> makeRuntimeEnvironment(
@@ -119,10 +121,12 @@ SlotBasedStageBuilder::SlotBasedStageBuilder(OperationContext* opCtx,
const CanonicalQuery& cq,
const QuerySolution& solution,
PlanYieldPolicySBE* yieldPolicy,
- bool needsTrialRunProgressTracker)
+ bool needsTrialRunProgressTracker,
+ ShardFiltererFactoryInterface* shardFiltererFactory)
: StageBuilder(opCtx, collection, cq, solution),
_yieldPolicy(yieldPolicy),
- _data(makeRuntimeEnvironment(_opCtx, &_slotIdGenerator)) {
+ _data(makeRuntimeEnvironment(_opCtx, &_slotIdGenerator)),
+ _shardFiltererFactory(shardFiltererFactory) {
if (needsTrialRunProgressTracker) {
const auto maxNumResults{trial_period::getTrialPeriodNumToReturn(_cq)};
@@ -140,6 +144,11 @@ SlotBasedStageBuilder::SlotBasedStageBuilder(OperationContext* opCtx,
_data.shouldTrackResumeToken = csn->requestResumeToken;
_data.shouldUseTailableScan = csn->tailable;
}
+
+ if (auto node = getNodeByType(solution.root(), STAGE_VIRTUAL_SCAN)) {
+ auto vsn = static_cast<const VirtualScanNode*>(node);
+ _shouldProduceRecordIdSlot = vsn->hasRecordId;
+ }
}
std::unique_ptr<sbe::PlanStage> SlotBasedStageBuilder::build(const QuerySolutionNode* root) {
@@ -147,21 +156,23 @@ std::unique_ptr<sbe::PlanStage> SlotBasedStageBuilder::build(const QuerySolution
invariant(!_buildHasStarted);
_buildHasStarted = true;
- // We always produce a 'resultSlot' and a 'recordIdSlot'. If the solution contains a
- // CollectionScanNode with the 'shouldTrackLatestOplogTimestamp' flag set to true, then we
- // will also produce an 'oplogTsSlot'.
+ // We always produce a 'resultSlot' and conditionally produce a 'recordIdSlot' based on the
+ // 'shouldProduceRecordIdSlot'. If the solution contains a CollectionScanNode with the
+ // 'shouldTrackLatestOplogTimestamp' flag set to true, then we will also produce an
+ // 'oplogTsSlot'.
PlanStageReqs reqs;
reqs.set(kResult);
- reqs.set(kRecordId);
+ reqs.setIf(kRecordId, _shouldProduceRecordIdSlot);
reqs.setIf(kOplogTs, _data.shouldTrackLatestOplogTimestamp);
// Build the SBE plan stage tree.
auto [stage, outputs] = build(root, reqs);
- // Assert that we produced a 'resultSlot' and a 'recordIdSlot'. Also assert that we produced
- // an 'oplogTsSlot' if it's needed.
+ // Assert that we produced a 'resultSlot' and that we prouced a 'recordIdSlot' if the
+ // 'shouldProduceRecordIdSlot' flag was set. Also assert that we produced an 'oplogTsSlot' if
+ // it's needed.
invariant(outputs.has(kResult));
- invariant(outputs.has(kRecordId));
+ invariant(!_shouldProduceRecordIdSlot || outputs.has(kRecordId));
invariant(!_data.shouldTrackLatestOplogTimestamp || outputs.has(kOplogTs));
_data.outputs = std::move(outputs);
@@ -203,6 +214,10 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildVirtualScan(
const QuerySolutionNode* root, const PlanStageReqs& reqs) {
auto vsn = static_cast<const VirtualScanNode*>(root);
+ invariant(!reqs.getIndexKeyBitset());
+
+ // Virtual scans cannot produce an oplogTsSlot, so assert that the caller doesn't need it.
+ invariant(!reqs.has(kOplogTs));
auto [inputTag, inputVal] = sbe::value::makeNewArray();
sbe::value::ValueGuard inputGuard{inputTag, inputVal};
@@ -221,11 +236,12 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
if (vsn->hasRecordId) {
invariant(scanSlots.size() == 2);
- outputs.set(PlanStageSlots::kRecordId, scanSlots[0]);
- outputs.set(PlanStageSlots::kResult, scanSlots[1]);
+ outputs.set(kRecordId, scanSlots[0]);
+ outputs.set(kResult, scanSlots[1]);
} else {
invariant(scanSlots.size() == 1);
- outputs.set(PlanStageSlots::kResult, scanSlots[0]);
+ invariant(!reqs.has(kRecordId));
+ outputs.set(kResult, scanSlots[0]);
}
return {std::move(scanStage), std::move(outputs)};
@@ -999,6 +1015,98 @@ SlotBasedStageBuilder::makeUnionForTailableCollScan(const QuerySolutionNode* roo
return {std::move(unionStage), std::move(outputs)};
}
+std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::buildShardFilter(
+ const QuerySolutionNode* root, const PlanStageReqs& reqs) {
+ using namespace std::literals;
+
+ const auto filterNode = static_cast<const ShardingFilterNode*>(root);
+
+ uassert(5071201,
+ "STAGE_SHARD_FILTER is curently only supported in SBE for collection scan plans",
+ filterNode->children[0]->getType() == StageType::STAGE_COLLSCAN ||
+ filterNode->children[0]->getType() == StageType::STAGE_VIRTUAL_SCAN);
+
+ auto childReqs = reqs.copy().set(kResult);
+ auto [stage, outputs] = build(filterNode->children[0], childReqs);
+
+ // If we're sharded make sure that we don't return data that isn't owned by the shard. This
+ // situation can occur when pending documents from in-progress migrations are inserted and when
+ // there are orphaned documents from aborted migrations. To check if the document is owned by
+ // the shard, we need to own a 'ShardFilterer', and extract the document's shard key as a
+ // BSONObj.
+ auto shardFilterer = _shardFiltererFactory->makeShardFilterer(_opCtx);
+
+ // Build an expression to extract the shard key from the document based on the shard key
+ // pattern. To do this, we iterate over the shard key pattern parts and build nested 'getField'
+ // expressions. This will handle single-element paths, and dotted paths for each shard key part.
+ sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> projections;
+ sbe::value::SlotVector fieldSlots;
+ std::vector<std::string> projectFields;
+ std::unique_ptr<sbe::EExpression> bindShardKeyPart;
+
+ BSONObjIterator keyPatternIter(shardFilterer->getKeyPattern().toBSON());
+ while (auto keyPatternElem = keyPatternIter.next()) {
+ auto fieldRef = FieldRef{keyPatternElem.fieldNameStringData()};
+ fieldSlots.push_back(_slotIdGenerator.generate());
+ projectFields.push_back(fieldRef.dottedField().toString());
+
+ auto currentFieldSlot = sbe::makeE<sbe::EVariable>(outputs.get(kResult));
+ auto shardKeyBinding =
+ generateShardKeyBinding(fieldRef, _frameIdGenerator, std::move(currentFieldSlot), 0);
+
+ projections.emplace(fieldSlots.back(), std::move(shardKeyBinding));
+ }
+
+ auto shardKeySlot{_slotIdGenerator.generate()};
+
+ // Build an object which will hold a flattened shard key from the projections above.
+ auto shardKeyObjStage = sbe::makeS<sbe::MakeObjStage>(
+ sbe::makeS<sbe::ProjectStage>(std::move(stage), std::move(projections), root->nodeId()),
+ shardKeySlot,
+ boost::none,
+ std::vector<std::string>{},
+ projectFields,
+ fieldSlots,
+ true,
+ false,
+ root->nodeId());
+
+ // Build a project stage that checks if any of the fieldSlots for the shard key parts are an
+ // Array which is represented by Nothing.
+ invariant(fieldSlots.size() > 0);
+ auto arrayChecks = makeNot(sbe::makeE<sbe::EFunction>(
+ "exists", sbe::makeEs(sbe::makeE<sbe::EVariable>(fieldSlots[0]))));
+ for (size_t ind = 1; ind < fieldSlots.size(); ++ind) {
+ arrayChecks = sbe::makeE<sbe::EPrimBinary>(
+ sbe::EPrimBinary::Op::logicOr,
+ std::move(arrayChecks),
+ makeNot(sbe::makeE<sbe::EFunction>(
+ "exists", sbe::makeEs(sbe::makeE<sbe::EVariable>(fieldSlots[ind])))));
+ }
+ arrayChecks = sbe::makeE<sbe::EIf>(std::move(arrayChecks),
+ sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Nothing, 0),
+ sbe::makeE<sbe::EVariable>(shardKeySlot));
+
+ auto finalShardKeySlot{_slotIdGenerator.generate()};
+
+ auto finalShardKeyObjStage = makeProjectStage(
+ std::move(shardKeyObjStage), root->nodeId(), finalShardKeySlot, std::move(arrayChecks));
+
+ // Build a 'FilterStage' to skip over documents that don't belong to the shard. Shard membership
+ // of the document is checked by invoking 'shardFilter' with the owned 'ShardFilterer' along
+ // with the shard key that sits in the 'finalShardKeySlot' of 'MakeObjStage'.
+ auto shardFilterFn = sbe::makeE<sbe::EFunction>(
+ "shardFilter"sv,
+ sbe::makeEs(sbe::makeE<sbe::EConstant>(
+ sbe::value::TypeTags::shardFilterer,
+ sbe::value::bitcastFrom<ShardFilterer*>(shardFilterer.release())),
+ sbe::makeE<sbe::EVariable>(finalShardKeySlot)));
+
+ return {sbe::makeS<sbe::FilterStage<false>>(
+ std::move(finalShardKeyObjStage), std::move(shardFilterFn), root->nodeId()),
+ std::move(outputs)};
+}
+
// Returns a non-null pointer to the root of a plan tree, or a non-OK status if the PlanStage tree
// could not be constructed.
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder::build(
@@ -1024,7 +1132,8 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
{STAGE_TEXT, &SlotBasedStageBuilder::buildText},
{STAGE_RETURN_KEY, &SlotBasedStageBuilder::buildReturnKey},
{STAGE_EOF, &SlotBasedStageBuilder::buildEof},
- {STAGE_SORT_MERGE, &SlotBasedStageBuilder::buildSortMerge}};
+ {STAGE_SORT_MERGE, &SlotBasedStageBuilder::buildSortMerge},
+ {STAGE_SHARDING_FILTER, &SlotBasedStageBuilder::buildShardFilter}};
uassert(4822884,
str::stream() << "Can't build exec tree for node: " << root->toString(),
diff --git a/src/mongo/db/query/sbe_stage_builder.h b/src/mongo/db/query/sbe_stage_builder.h
index d4b618dba34..ee13f494280 100644
--- a/src/mongo/db/query/sbe_stage_builder.h
+++ b/src/mongo/db/query/sbe_stage_builder.h
@@ -35,6 +35,7 @@
#include "mongo/db/exec/trial_period_utils.h"
#include "mongo/db/exec/trial_run_progress_tracker.h"
#include "mongo/db/query/plan_yield_policy_sbe.h"
+#include "mongo/db/query/shard_filterer_factory_interface.h"
#include "mongo/db/query/stage_builder.h"
namespace mongo::stage_builder {
@@ -255,7 +256,8 @@ public:
const CanonicalQuery& cq,
const QuerySolution& solution,
PlanYieldPolicySBE* yieldPolicy,
- bool needsTrialRunProgressTracker);
+ bool needsTrialRunProgressTracker,
+ ShardFiltererFactoryInterface* shardFilterer);
std::unique_ptr<sbe::PlanStage> build(const QuerySolutionNode* root) final;
@@ -324,6 +326,9 @@ private:
std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> makeUnionForTailableCollScan(
const QuerySolutionNode* root, const PlanStageReqs& reqs);
+ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> buildShardFilter(
+ const QuerySolutionNode* root, const PlanStageReqs& reqs);
+
sbe::value::SlotIdGenerator _slotIdGenerator;
sbe::value::FrameIdGenerator _frameIdGenerator;
sbe::value::SpoolIdGenerator _spoolIdGenerator;
@@ -335,5 +340,9 @@ private:
PlanStageData _data;
bool _buildHasStarted{false};
+ bool _shouldProduceRecordIdSlot{true};
+
+ // A factory to construct shard filters.
+ ShardFiltererFactoryInterface* _shardFiltererFactory;
};
} // namespace mongo::stage_builder
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
index 7e7c45ded2a..acac413ebd9 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
@@ -151,6 +151,59 @@ std::unique_ptr<sbe::EExpression> makeFillEmptyFalse(std::unique_ptr<sbe::EExpre
sbe::value::bitcastFrom<bool>(false))));
}
+std::unique_ptr<sbe::EExpression> makeFillEmptyNull(std::unique_ptr<sbe::EExpression> e) {
+ using namespace std::literals;
+ return sbe::makeE<sbe::EFunction>(
+ "fillEmpty"sv,
+ sbe::makeEs(std::move(e), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)));
+}
+
+std::unique_ptr<sbe::EExpression> makeNothingArrayCheck(
+ std::unique_ptr<sbe::EExpression> isArrayInput, std::unique_ptr<sbe::EExpression> otherwise) {
+ using namespace std::literals;
+ return sbe::makeE<sbe::EIf>(
+ sbe::makeE<sbe::EFunction>("isArray"sv, sbe::makeEs(std::move(isArrayInput))),
+ sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Nothing, 0),
+ std::move(otherwise));
+}
+
+std::unique_ptr<sbe::EExpression> generateShardKeyBinding(
+ const FieldRef& keyPatternField,
+ sbe::value::FrameIdGenerator& frameIdGenerator,
+ std::unique_ptr<sbe::EExpression> inputExpr,
+ int level) {
+ using namespace std::literals;
+ invariant(level >= 0);
+
+ auto makeGetFieldKeyPattern = [&](std::unique_ptr<sbe::EExpression> slot) {
+ return makeFillEmptyNull(sbe::makeE<sbe::EFunction>(
+ "getField"sv,
+ sbe::makeEs(std::move(slot), sbe::makeE<sbe::EConstant>([&]() {
+ const auto fieldName = keyPatternField[level];
+ return std::string_view{fieldName.rawData(), fieldName.size()};
+ }()))));
+ };
+
+ if (level == keyPatternField.numParts() - 1) {
+ auto frameId = frameIdGenerator.generate();
+ auto bindSlot = sbe::makeE<sbe::EVariable>(frameId, 0);
+ return sbe::makeE<sbe::ELocalBind>(
+ frameId,
+ sbe::makeEs(makeGetFieldKeyPattern(std::move(inputExpr))),
+ makeNothingArrayCheck(bindSlot->clone(), bindSlot->clone()));
+ }
+
+ auto frameId = frameIdGenerator.generate();
+ auto nextSlot = sbe::makeE<sbe::EVariable>(frameId, 0);
+ auto shardKeyBinding =
+ generateShardKeyBinding(keyPatternField, frameIdGenerator, nextSlot->clone(), level + 1);
+
+ return sbe::makeE<sbe::ELocalBind>(
+ frameId,
+ sbe::makeEs(makeGetFieldKeyPattern(inputExpr->clone())),
+ makeNothingArrayCheck(nextSlot->clone(), std::move(shardKeyBinding)));
+}
+
EvalStage makeLimitCoScanStage(PlanNodeId planNodeId, long long limit) {
return {makeLimitCoScanTree(planNodeId, limit), sbe::makeSV()};
}
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h
index 454665e65a7..91e201f5bc7 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.h
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.h
@@ -165,6 +165,36 @@ inline std::unique_ptr<sbe::EExpression> makeFunction(std::string_view name, Arg
}
/**
+ * Check if expression returns Nothing and return null if so. Otherwise, return the
+ * expression.
+ */
+std::unique_ptr<sbe::EExpression> makeFillEmptyNull(std::unique_ptr<sbe::EExpression> e);
+
+/**
+ * Check if expression returns an array and return Nothing if so. Otherwise, return the expression.
+ */
+std::unique_ptr<sbe::EExpression> makeNothingArrayCheck(
+ std::unique_ptr<sbe::EExpression> isArrayInput, std::unique_ptr<sbe::EExpression> otherwise);
+
+/**
+ * Creates an expression to extract a shard key part from inputExpr. The generated expression is a
+ * let binding that binds a getField expression to extract the shard key part value from the
+ * inputExpr. The entire let binding evaluates to a constant expression carrying the Nothing value
+ * if the binding is an array. Otherwise, it evaluates to a fillEmpty null expression. Here is an
+ * example expression generated from this function for a shard key pattern {'a.b': 1}:
+ *
+ * let [l1.0 = getField (s1, "a")]
+ * if (isArray (l1.0), NOTHING,
+ * let [l2.0 = getField (l1.0, "b")]
+ * if (isArray (l2.0), NOTHING, fillEmpty (l2.0, null)))
+ */
+std::unique_ptr<sbe::EExpression> generateShardKeyBinding(
+ const FieldRef& keyPatternField,
+ sbe::value::FrameIdGenerator& frameIdGenerator,
+ std::unique_ptr<sbe::EExpression> inputExpr,
+ int level);
+
+/**
* If given 'EvalExpr' already contains a slot, simply returns it. Otherwise, allocates a new slot
* and creates project stage to assign expression to this new slot. After that, new slot and project
* stage are returned.
diff --git a/src/mongo/db/query/sbe_stage_builder_test.cpp b/src/mongo/db/query/sbe_stage_builder_test.cpp
index c830f33bdcb..a35066c6a28 100644
--- a/src/mongo/db/query/sbe_stage_builder_test.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_test.cpp
@@ -29,15 +29,23 @@
#include "mongo/platform/basic.h"
+#include "mongo/db/exec/shard_filterer_mock.h"
#include "mongo/db/query/query_solution.h"
#include "mongo/db/query/sbe_stage_builder_test_fixture.h"
+#include "mongo/db/query/shard_filterer_factory_mock.h"
#include "mongo/unittest/unittest.h"
namespace mongo {
-using SBEStageBuilderTest = SBEStageBuilderTestFixture;
+class SbeStageBuilderTest : public SbeStageBuilderTestFixture {
+protected:
+ std::unique_ptr<ShardFiltererFactoryInterface> makeAlwaysPassShardFiltererInterface() {
+ return std::make_unique<ShardFiltererFactoryMock>(
+ std::make_unique<ConstantFilterMock>(true, BSONObj{}));
+ }
+};
-TEST_F(SBEStageBuilderTest, TestVirtualScan) {
+TEST_F(SbeStageBuilderTest, TestVirtualScan) {
auto docs = std::vector<BSONArray>{BSON_ARRAY(int64_t{0} << BSON("a" << 1 << "b" << 2)),
BSON_ARRAY(int64_t{1} << BSON("a" << 2 << "b" << 2)),
BSON_ARRAY(int64_t{2} << BSON("a" << 3 << "b" << 2))};
@@ -51,7 +59,9 @@ TEST_F(SBEStageBuilderTest, TestVirtualScan) {
ASSERT_EQ(querySolution->root()->nodeId(), 1);
// Translate the QuerySolution tree to an sbe::PlanStage.
- auto [resultSlots, stage, data] = buildPlanStage(std::move(querySolution), true);
+ auto shardFiltererInterface = makeAlwaysPassShardFiltererInterface();
+ auto [resultSlots, stage, data] =
+ buildPlanStage(std::move(querySolution), true, std::move(shardFiltererInterface));
auto resultAccessors = prepareTree(&data.ctx, stage.get(), resultSlots);
int64_t index = 0;
@@ -70,7 +80,7 @@ TEST_F(SBEStageBuilderTest, TestVirtualScan) {
ASSERT_EQ(index, 3);
}
-TEST_F(SBEStageBuilderTest, TestLimitOneVirtualScan) {
+TEST_F(SbeStageBuilderTest, TestLimitOneVirtualScan) {
auto docs = std::vector<BSONArray>{BSON_ARRAY(int64_t{0} << BSON("a" << 1 << "b" << 2)),
BSON_ARRAY(int64_t{1} << BSON("a" << 2 << "b" << 2)),
BSON_ARRAY(int64_t{2} << BSON("a" << 3 << "b" << 2))};
@@ -86,7 +96,9 @@ TEST_F(SBEStageBuilderTest, TestLimitOneVirtualScan) {
auto querySolution = makeQuerySolution(std::move(limitNode));
// Translate the QuerySolution tree to an sbe::PlanStage.
- auto [resultSlots, stage, data] = buildPlanStage(std::move(querySolution), true);
+ auto shardFiltererInterface = makeAlwaysPassShardFiltererInterface();
+ auto [resultSlots, stage, data] =
+ buildPlanStage(std::move(querySolution), true, std::move(shardFiltererInterface));
// Prepare the sbe::PlanStage for execution.
auto resultAccessors = prepareTree(&data.ctx, stage.get(), resultSlots);
diff --git a/src/mongo/db/query/sbe_stage_builder_test_fixture.cpp b/src/mongo/db/query/sbe_stage_builder_test_fixture.cpp
index 31b8b6005d7..97dc4ffcd79 100644
--- a/src/mongo/db/query/sbe_stage_builder_test_fixture.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_test_fixture.cpp
@@ -37,7 +37,7 @@
#include "mongo/db/query/sbe_stage_builder_test_fixture.h"
namespace mongo {
-std::unique_ptr<QuerySolution> SBEStageBuilderTestFixture::makeQuerySolution(
+std::unique_ptr<QuerySolution> SbeStageBuilderTestFixture::makeQuerySolution(
std::unique_ptr<QuerySolutionNode> root) {
auto querySoln = std::make_unique<QuerySolution>();
querySoln->setRoot(std::move(root));
@@ -45,8 +45,10 @@ std::unique_ptr<QuerySolution> SBEStageBuilderTestFixture::makeQuerySolution(
}
std::tuple<sbe::value::SlotVector, std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData>
-SBEStageBuilderTestFixture::buildPlanStage(std::unique_ptr<QuerySolution> querySolution,
- bool hasRecordId) {
+SbeStageBuilderTestFixture::buildPlanStage(
+ std::unique_ptr<QuerySolution> querySolution,
+ bool hasRecordId,
+ std::unique_ptr<ShardFiltererFactoryInterface> shardFiltererInterface) {
auto qr = std::make_unique<QueryRequest>(_nss);
const boost::intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest(_nss));
auto statusWithCQ = CanonicalQuery::canonicalize(opCtx(), std::move(qr), expCtx);
@@ -57,7 +59,8 @@ SBEStageBuilderTestFixture::buildPlanStage(std::unique_ptr<QuerySolution> queryS
*statusWithCQ.getValue(),
*querySolution,
nullptr /* YieldPolicy */,
- false};
+ false,
+ shardFiltererInterface.get()};
auto stage = builder.build(querySolution->root());
auto data = builder.getPlanStageData();
diff --git a/src/mongo/db/query/sbe_stage_builder_test_fixture.h b/src/mongo/db/query/sbe_stage_builder_test_fixture.h
index d38382e8c71..840ab113a4b 100644
--- a/src/mongo/db/query/sbe_stage_builder_test_fixture.h
+++ b/src/mongo/db/query/sbe_stage_builder_test_fixture.h
@@ -34,11 +34,12 @@
#include "mongo/db/exec/sbe/values/slot.h"
#include "mongo/db/exec/sbe/values/value.h"
#include "mongo/db/query/query_solution.h"
+#include "mongo/db/query/shard_filterer_factory_interface.h"
#include "mongo/unittest/unittest.h"
namespace mongo {
/**
- * SBEStageBuilderTestFixture is a unittest fixture that can be used to facilitate testing the
+ * SbeStageBuilderTestFixture is a unittest fixture that can be used to facilitate testing the
* translation of a QuerySolution tree to an sbe PlanStage tree.
*
* The main mechanism that enables the whole sbe::PlanStage tree to be exercised under unittests is
@@ -51,9 +52,9 @@ namespace mongo {
* prepare the sbe::PlanStage tree. The sbe::PlanStageData returned from buildPlanStage() must be
* kept alive across buildPlanStage(), prepareTree() and execution of the plan.
*/
-class SBEStageBuilderTestFixture : public sbe::PlanStageTestFixture {
+class SbeStageBuilderTestFixture : public sbe::PlanStageTestFixture {
public:
- SBEStageBuilderTestFixture() = default;
+ SbeStageBuilderTestFixture() = default;
/**
* Makes a QuerySolution from a QuerySolutionNode.
@@ -66,11 +67,16 @@ public:
* SlotVector. If hasRecordId is 'true' then the returned SlotVector will carry a SlotId in the
* 0th position for the RecordId and a SlotId for the BSONObj representation of the document in
* the 1st position. Otherwise, if hasRecordId is 'false', the SlotVector will contain a single
- * SlotId for the BSONObj representation of the document.
+ * SlotId for the BSONObj representation of the document. A real or mock
+ * ShardFiltererFactoryInterface must be provided so the sbe SlotBasedStageBuilder can build and
+ * utilize a ShardFilterer instance during translation of a ShardingFilterNode.
*/
- std::
- tuple<sbe::value::SlotVector, std::unique_ptr<sbe::PlanStage>, stage_builder::PlanStageData>
- buildPlanStage(std::unique_ptr<QuerySolution> querySolution, bool hasRecordId);
+ std::tuple<sbe::value::SlotVector,
+ std::unique_ptr<sbe::PlanStage>,
+ stage_builder::PlanStageData>
+ buildPlanStage(std::unique_ptr<QuerySolution> querySolution,
+ bool hasRecordId,
+ std::unique_ptr<ShardFiltererFactoryInterface> shardFiltererFactoryInterface);
private:
const NamespaceString _nss = NamespaceString{"testdb.sbe_stage_builder"};
diff --git a/src/mongo/db/query/shard_filterer_factory_impl.cpp b/src/mongo/db/query/shard_filterer_factory_impl.cpp
new file mode 100644
index 00000000000..0db775c411e
--- /dev/null
+++ b/src/mongo/db/query/shard_filterer_factory_impl.cpp
@@ -0,0 +1,44 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/exec/shard_filterer_impl.h"
+#include "mongo/db/query/shard_filterer_factory_impl.h"
+#include "mongo/db/s/collection_sharding_state.h"
+
+namespace mongo {
+
+std::unique_ptr<ShardFilterer> ShardFiltererFactoryImpl::makeShardFilterer(
+ OperationContext* opCtx) const {
+ auto css = CollectionShardingState::get(opCtx, _collection->ns());
+ return std::make_unique<ShardFiltererImpl>(css->getOwnershipFilter(
+ opCtx, CollectionShardingState::OrphanCleanupPolicy::kDisallowOrphanCleanup));
+}
+} // namespace mongo
diff --git a/src/mongo/db/query/shard_filterer_factory_impl.h b/src/mongo/db/query/shard_filterer_factory_impl.h
new file mode 100644
index 00000000000..957b8154dba
--- /dev/null
+++ b/src/mongo/db/query/shard_filterer_factory_impl.h
@@ -0,0 +1,49 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/catalog/collection.h"
+#include "mongo/db/query/shard_filterer_factory_interface.h"
+
+namespace mongo {
+
+/**
+ * An implementation of ShardFiltererFactoryInterface.
+ */
+class ShardFiltererFactoryImpl : public ShardFiltererFactoryInterface {
+public:
+ ShardFiltererFactoryImpl(const CollectionPtr& collection) : _collection(collection) {}
+
+ std::unique_ptr<ShardFilterer> makeShardFilterer(OperationContext* opCtx) const override;
+
+private:
+ const CollectionPtr& _collection;
+};
+} // namespace mongo
diff --git a/src/mongo/db/query/shard_filterer_factory_interface.h b/src/mongo/db/query/shard_filterer_factory_interface.h
new file mode 100644
index 00000000000..acb9e8e81fa
--- /dev/null
+++ b/src/mongo/db/query/shard_filterer_factory_interface.h
@@ -0,0 +1,46 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/exec/shard_filterer.h"
+#include "mongo/db/operation_context.h"
+
+namespace mongo {
+
+/**
+ * An interface that can be used to construct a ShardFilterer.
+ */
+class ShardFiltererFactoryInterface {
+public:
+ virtual ~ShardFiltererFactoryInterface() = default;
+
+ virtual std::unique_ptr<ShardFilterer> makeShardFilterer(OperationContext* opCtx) const = 0;
+};
+} // namespace mongo
diff --git a/src/mongo/db/query/shard_filterer_factory_mock.cpp b/src/mongo/db/query/shard_filterer_factory_mock.cpp
new file mode 100644
index 00000000000..2e28150d1ec
--- /dev/null
+++ b/src/mongo/db/query/shard_filterer_factory_mock.cpp
@@ -0,0 +1,43 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/query/shard_filterer_factory_mock.h"
+
+namespace mongo {
+
+ShardFiltererFactoryMock::ShardFiltererFactoryMock(std::unique_ptr<ShardFilterer> shardFilterer)
+ : _shardFilterer(std::move(shardFilterer)) {}
+
+std::unique_ptr<ShardFilterer> ShardFiltererFactoryMock::makeShardFilterer(
+ OperationContext* opCtx) const {
+ return _shardFilterer->clone();
+}
+} // namespace mongo
diff --git a/src/mongo/db/query/shard_filterer_factory_mock.h b/src/mongo/db/query/shard_filterer_factory_mock.h
new file mode 100644
index 00000000000..356c1cc9774
--- /dev/null
+++ b/src/mongo/db/query/shard_filterer_factory_mock.h
@@ -0,0 +1,52 @@
+/**
+ * Copyright (C) 2020-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/exec/shard_filterer.h"
+#include "mongo/db/query/shard_filterer_factory_interface.h"
+
+namespace mongo {
+
+/**
+ * An implementation of ShardFiltererFactoryInterface for unit testing.
+ */
+class ShardFiltererFactoryMock : public ShardFiltererFactoryInterface {
+public:
+ ShardFiltererFactoryMock(std::unique_ptr<ShardFilterer> shardFilter);
+
+ /*
+ * Makes a new mock ShardFilterer.
+ */
+ std::unique_ptr<ShardFilterer> makeShardFilterer(OperationContext* opCtx) const override;
+
+private:
+ std::unique_ptr<ShardFilterer> _shardFilterer;
+};
+} // namespace mongo
diff --git a/src/mongo/db/query/stage_builder_util.cpp b/src/mongo/db/query/stage_builder_util.cpp
index 58bda2900ca..f318bb127eb 100644
--- a/src/mongo/db/query/stage_builder_util.cpp
+++ b/src/mongo/db/query/stage_builder_util.cpp
@@ -34,6 +34,7 @@
#include "mongo/db/query/classic_stage_builder.h"
#include "mongo/db/query/plan_yield_policy.h"
#include "mongo/db/query/sbe_stage_builder.h"
+#include "mongo/db/query/shard_filterer_factory_impl.h"
namespace mongo::stage_builder {
std::unique_ptr<PlanStage> buildClassicExecutableTree(OperationContext* opCtx,
@@ -69,8 +70,15 @@ buildSlotBasedExecutableTree(OperationContext* opCtx,
auto sbeYieldPolicy = dynamic_cast<PlanYieldPolicySBE*>(yieldPolicy);
invariant(sbeYieldPolicy);
- auto builder = std::make_unique<SlotBasedStageBuilder>(
- opCtx, collection, cq, solution, sbeYieldPolicy, needsTrialRunProgressTracker);
+ auto shardFilterer = std::make_unique<ShardFiltererFactoryImpl>(collection);
+
+ auto builder = std::make_unique<SlotBasedStageBuilder>(opCtx,
+ collection,
+ cq,
+ solution,
+ sbeYieldPolicy,
+ needsTrialRunProgressTracker,
+ shardFilterer.get());
auto root = builder->build(solution.root());
auto data = builder->getPlanStageData();