diff options
author | Svilen Mihaylov <svilen.mihaylov@mongodb.com> | 2022-10-26 22:49:00 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-10-27 00:11:42 +0000 |
commit | 108528c1fdc795fa1f9571f2073668d4fc7f05c3 (patch) | |
tree | b4385a2c5ea7fcf940a160ca7e7ba5cc9d9f9869 /src/mongo/db/query | |
parent | 3a31638a3c0bc80fe03faf0bafba5aefaefc2b84 (diff) | |
download | mongo-108528c1fdc795fa1f9571f2073668d4fc7f05c3.tar.gz |
SERVER-70749 [CQF] Fix reference tracker behavior for join algorithms
Diffstat (limited to 'src/mongo/db/query')
8 files changed, 253 insertions, 102 deletions
diff --git a/src/mongo/db/query/optimizer/cascades/enforcers.cpp b/src/mongo/db/query/optimizer/cascades/enforcers.cpp index a218043c881..801ae716202 100644 --- a/src/mongo/db/query/optimizer/cascades/enforcers.cpp +++ b/src/mongo/db/query/optimizer/cascades/enforcers.cpp @@ -67,14 +67,12 @@ public: PropEnforcerVisitor(const GroupIdType groupId, const Metadata& metadata, const RIDProjectionsMap& ridProjections, - PrefixId& prefixId, PhysRewriteQueue& queue, const PhysProps& physProps, const LogicalProps& logicalProps) : _groupId(groupId), _metadata(metadata), _ridProjections(ridProjections), - _prefixId(prefixId), _queue(queue), _physProps(physProps), _logicalProps(logicalProps) {} @@ -148,7 +146,8 @@ public: return; } - if (prop.getDistributionAndProjections()._type == DistributionType::UnknownPartitioning) { + const auto& requiredDistrAndProj = prop.getDistributionAndProjections(); + if (requiredDistrAndProj._type == DistributionType::UnknownPartitioning) { // Cannot exchange into unknown partitioning. return; } @@ -163,22 +162,22 @@ public: return; } - const auto& distributions = + const auto& availableDistrs = getPropertyConst<DistributionAvailability>(_logicalProps).getDistributionSet(); - for (const auto& distribution : distributions) { - if (distribution == prop.getDistributionAndProjections()) { + for (const auto& availableDistr : availableDistrs) { + if (availableDistr == requiredDistrAndProj) { // Same distribution. continue; } - if (distribution._type == DistributionType::Replicated) { + if (availableDistr._type == DistributionType::Replicated) { // Cannot switch "away" from replicated distribution. continue; } PhysProps childProps = _physProps; - setPropertyOverwrite<DistributionRequirement>(childProps, distribution); + setPropertyOverwrite<DistributionRequirement>(childProps, availableDistr); - addProjectionsToProperties(childProps, distribution._projectionNames); + addProjectionsToProperties(childProps, requiredDistrAndProj._projectionNames); getProperty<DistributionRequirement>(childProps).setDisableExchanges(true); ABT enforcer = make<ExchangeNode>(prop, make<MemoLogicalDelegatorNode>(_groupId)); @@ -261,7 +260,6 @@ private: // We don't own any of those. const Metadata& _metadata; const RIDProjectionsMap& _ridProjections; - PrefixId& _prefixId; PhysRewriteQueue& _queue; const PhysProps& _physProps; const LogicalProps& _logicalProps; @@ -270,12 +268,10 @@ private: void addEnforcers(const GroupIdType groupId, const Metadata& metadata, const RIDProjectionsMap& ridProjections, - PrefixId& prefixId, PhysRewriteQueue& queue, const PhysProps& physProps, const LogicalProps& logicalProps) { - PropEnforcerVisitor visitor( - groupId, metadata, ridProjections, prefixId, queue, physProps, logicalProps); + PropEnforcerVisitor visitor(groupId, metadata, ridProjections, queue, physProps, logicalProps); for (const auto& entry : physProps) { entry.second.visit(visitor); } diff --git a/src/mongo/db/query/optimizer/cascades/enforcers.h b/src/mongo/db/query/optimizer/cascades/enforcers.h index 855d0be4d4a..15bd4d139dd 100644 --- a/src/mongo/db/query/optimizer/cascades/enforcers.h +++ b/src/mongo/db/query/optimizer/cascades/enforcers.h @@ -40,7 +40,6 @@ namespace mongo::optimizer::cascades { void addEnforcers(GroupIdType groupId, const Metadata& metadata, const RIDProjectionsMap& ridProjections, - PrefixId& prefixId, PhysRewriteQueue& queue, const properties::PhysProps& physProps, const properties::LogicalProps& logicalProps); diff --git a/src/mongo/db/query/optimizer/cascades/implementers.cpp b/src/mongo/db/query/optimizer/cascades/implementers.cpp index 8e3b2d36c8c..6039ac3cb43 100644 --- a/src/mongo/db/query/optimizer/cascades/implementers.cpp +++ b/src/mongo/db/query/optimizer/cascades/implementers.cpp @@ -851,43 +851,58 @@ public: node.getRightChild()); }; - optimizeFn(); + const auto& leftDistributions = + getPropertyConst<DistributionAvailability>(leftLogicalProps).getDistributionSet(); + const auto& rightDistributions = + getPropertyConst<DistributionAvailability>(rightLogicalProps).getDistributionSet(); + + const bool leftDistrOK = leftDistributions.count(distrAndProjections) > 0; + const bool rightDistrOK = rightDistributions.count(distrAndProjections) > 0; + const bool seekWithRoundRobin = !isIndex && + distrAndProjections._type == DistributionType::RoundRobin && + rightDistributions.count(DistributionType::UnknownPartitioning) > 0; + + if (leftDistrOK && (rightDistrOK || seekWithRoundRobin)) { + // If we are not changing the distributions, both the left and right children need to + // have it available. For example, if optimizing under HashPartitioning on "var1", we + // need to check that this distribution is available in both child groups. If not, and + // it is available in one group, we can try replicating the other group (below). + // If optimizing the seek side specifically, we allow for a RoundRobin distribution + // which can match a collection with UnknownPartitioning. + optimizeFn(); + } - if (isIndex) { - switch (distrAndProjections._type) { - case DistributionType::HashPartitioning: - case DistributionType::RangePartitioning: { - // Specifically for index intersection, try propagating the requirement on one - // side and replicating the other. - - const auto& leftDistributions = - getPropertyConst<DistributionAvailability>(leftLogicalProps) - .getDistributionSet(); - const auto& rightDistributions = - getPropertyConst<DistributionAvailability>(rightLogicalProps) - .getDistributionSet(); - - if (leftDistributions.count(distrAndProjections) > 0) { - setPropertyOverwrite<DistributionRequirement>(leftPhysProps, - distribRequirement); - setPropertyOverwrite<DistributionRequirement>( - rightPhysProps, DistributionRequirement{DistributionType::Replicated}); - optimizeFn(); - } + if (!isIndex) { + // Nothing more to do for Complete target. The index side needs to be collocated with + // the seek side. + return; + } - if (rightDistributions.count(distrAndProjections) > 0) { - setPropertyOverwrite<DistributionRequirement>( - leftPhysProps, DistributionRequirement{DistributionType::Replicated}); - setPropertyOverwrite<DistributionRequirement>(rightPhysProps, - distribRequirement); - optimizeFn(); - } - break; + // Specifically for index intersection, try propagating the requirement on one + // side and replicating the other. + switch (distrAndProjections._type) { + case DistributionType::HashPartitioning: + case DistributionType::RangePartitioning: { + if (leftDistrOK) { + setPropertyOverwrite<DistributionRequirement>(leftPhysProps, + distribRequirement); + setPropertyOverwrite<DistributionRequirement>( + rightPhysProps, DistributionRequirement{DistributionType::Replicated}); + optimizeFn(); } - default: - break; + if (rightDistrOK) { + setPropertyOverwrite<DistributionRequirement>( + leftPhysProps, DistributionRequirement{DistributionType::Replicated}); + setPropertyOverwrite<DistributionRequirement>(rightPhysProps, + distribRequirement); + optimizeFn(); + } + break; } + + default: + break; } } diff --git a/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp b/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp index e0015876bd4..c0b3423ab0c 100644 --- a/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp +++ b/src/mongo/db/query/optimizer/cascades/physical_rewriter.cpp @@ -347,13 +347,8 @@ PhysicalRewriter::OptimizeGroupResult PhysicalRewriter::optimizeGroup(const Grou // Enforcement rewrites run just once, and are independent of the logical nodes. if (groupId != _rootGroupId) { // Verify properties can be enforced and add enforcers if necessary. - addEnforcers(groupId, - _metadata, - _ridProjections, - prefixId, - queue._queue, - bestResult._physProps, - logicalProps); + addEnforcers( + groupId, _metadata, _ridProjections, queue._queue, bestResult._physProps, logicalProps); } // Iterate until we perform all logical for the group and physical rewrites for our best plan. diff --git a/src/mongo/db/query/optimizer/opt_phase_manager.cpp b/src/mongo/db/query/optimizer/opt_phase_manager.cpp index 0f893e5d230..4c4fd0deed3 100644 --- a/src/mongo/db/query/optimizer/opt_phase_manager.cpp +++ b/src/mongo/db/query/optimizer/opt_phase_manager.cpp @@ -32,7 +32,6 @@ #include "mongo/db/query/optimizer/cascades/ce_heuristic.h" #include "mongo/db/query/optimizer/cascades/cost_derivation.h" #include "mongo/db/query/optimizer/cascades/logical_props_derivation.h" -#include "mongo/db/query/optimizer/explain.h" #include "mongo/db/query/optimizer/rewrites/const_eval.h" #include "mongo/db/query/optimizer/rewrites/path.h" #include "mongo/db/query/optimizer/rewrites/path_lower.h" @@ -95,6 +94,17 @@ OptPhaseManager::OptPhaseManager(OptPhaseManager::PhaseSet phaseSet, } } +static std::string generateFreeVarsAssertMsg(const VariableEnvironment& env) { + std::string result; + for (const auto& name : env.freeVariableNames()) { + if (!result.empty()) { + result += ", "; + } + result += name; + } + return result; +} + template <OptPhase phase, class C> void OptPhaseManager::runStructuralPhase(C instance, VariableEnvironment& env, ABT& input) { if (!hasPhase(phase)) { @@ -108,7 +118,9 @@ void OptPhaseManager::runStructuralPhase(C instance, VariableEnvironment& env, A !_debugInfo.exceedsIterationLimit(iterationCount)); } - tassert(6808709, "Environment has free variables.", !env.hasFreeVariables()); + if (env.hasFreeVariables()) { + tasserted(6808709, "Plan has free variables: " + generateFreeVarsAssertMsg(env)); + } } template <OptPhase phase1, OptPhase phase2, class C1, class C2> @@ -141,7 +153,9 @@ void OptPhaseManager::runStructuralPhases(C1 instance1, } } - tassert(6808701, "Environment has free variables.", !env.hasFreeVariables()); + if (env.hasFreeVariables()) { + tasserted(6808701, "Plan has free variables: " + generateFreeVarsAssertMsg(env)); + } } void OptPhaseManager::runMemoLogicalRewrite(const OptPhase phase, @@ -179,7 +193,9 @@ void OptPhaseManager::runMemoLogicalRewrite(const OptPhase phase, env.rebuild(input); } - tassert(6808703, "Environment has free variables.", !env.hasFreeVariables()); + if (env.hasFreeVariables()) { + tasserted(6808703, "Plan has free variables: " + generateFreeVarsAssertMsg(env)); + } } void OptPhaseManager::runMemoPhysicalRewrite(const OptPhase phase, @@ -202,7 +218,7 @@ void OptPhaseManager::runMemoPhysicalRewrite(const OptPhase phase, if (_requireRID) { const auto& rootLogicalProps = _memo.getLogicalProps(rootGroupId); tassert(6808705, - "We cannot optain rid for this query.", + "We cannot obtain rid for this query.", hasProperty<IndexingAvailability>(rootLogicalProps)); const auto& scanDefName = @@ -232,8 +248,9 @@ void OptPhaseManager::runMemoPhysicalRewrite(const OptPhase phase, std::tie(input, _nodeToGroupPropsMap) = extractPhysicalPlan(_physicalNodeId, _metadata, _memo); env.rebuild(input); - - tassert(6808707, "Environment has free variables.", !env.hasFreeVariables()); + if (env.hasFreeVariables()) { + tasserted(6808707, "Plan has free variables: " + generateFreeVarsAssertMsg(env)); + } } void OptPhaseManager::runMemoRewritePhases(VariableEnvironment& env, ABT& input) { @@ -263,24 +280,10 @@ void OptPhaseManager::runMemoRewritePhases(VariableEnvironment& env, ABT& input) void OptPhaseManager::optimize(ABT& input) { VariableEnvironment env = VariableEnvironment::build(input); - - std::string freeVariables = ""; if (env.hasFreeVariables()) { - bool first = true; - for (auto& name : env.freeVariableNames()) { - if (first) { - first = false; - } else { - freeVariables += ", "; - } - freeVariables += name; - } + tasserted(6808711, "Plan has free variables: " + generateFreeVarsAssertMsg(env)); } - tassert(6808711, - "Environment has the following free variables: " + freeVariables + ".", - !env.hasFreeVariables()); - const auto sargableCheckFn = [this](const ABT& expr) { return convertExprToPartialSchemaReq(expr, false /*isFilterContext*/, _pathToInterval) .has_value(); @@ -314,7 +317,9 @@ void OptPhaseManager::optimize(ABT& input) { } env.rebuild(input); - tassert(6808710, "Environment has free variables.", !env.hasFreeVariables()); + if (env.hasFreeVariables()) { + tasserted(6808710, "Plan has free variables: " + generateFreeVarsAssertMsg(env)); + } } bool OptPhaseManager::hasPhase(const OptPhase phase) const { diff --git a/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp index 9c43d199527..1079e45acfa 100644 --- a/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp +++ b/src/mongo/db/query/optimizer/physical_rewriter_optimizer_test.cpp @@ -4478,7 +4478,7 @@ TEST(PhysRewriter, IndexPartitioning) { ABT optimized = rootNode; phaseManager.optimize(optimized); - ASSERT_BETWEEN(75, 150, phaseManager.getMemo().getStats()._physPlanExplorationCount); + ASSERT_BETWEEN(75, 125, phaseManager.getMemo().getStats()._physPlanExplorationCount); ASSERT_EXPLAIN_V2( "Root []\n" @@ -4539,6 +4539,10 @@ TEST(PhysRewriter, IndexPartitioning1) { using namespace properties; PrefixId prefixId; + PartialSchemaSelHints hints; + hints.emplace(PartialSchemaKey{"root", make<PathGet>("a", make<PathIdentity>())}, 0.02); + hints.emplace(PartialSchemaKey{"root", make<PathGet>("b", make<PathIdentity>())}, 0.01); + ABT scanNode = make<ScanNode>("root", "c1"); ABT projectionANode = make<EvaluationNode>( @@ -4574,6 +4578,7 @@ TEST(PhysRewriter, IndexPartitioning1) { OptPhase::MemoExplorationPhase, OptPhase::MemoImplementationPhase}, prefixId, + false /*requireRID*/, {{{"c1", createScanDef( {}, @@ -4592,11 +4597,15 @@ TEST(PhysRewriter, IndexPartitioning1) { ConstEval::constFold, {DistributionType::HashPartitioning, makeSeq(makeNonMultikeyIndexPath("c"))})}}, 5 /*numberOfPartitions*/}, + std::make_unique<HintedCE>(std::move(hints)), + std::make_unique<DefaultCosting>(), + {} /*pathToInterval*/, + ConstEval::constFold, {true /*debugMode*/, 2 /*debugLevel*/, DebugInfo::kIterationLimitForTests}); ABT optimized = rootNode; phaseManager.optimize(optimized); - ASSERT_BETWEEN(150, 350, phaseManager.getMemo().getStats()._physPlanExplorationCount); + ASSERT_BETWEEN(125, 175, phaseManager.getMemo().getStats()._physPlanExplorationCount); const BSONObj& result = ExplainGenerator::explainBSONObj(optimized); @@ -4607,15 +4616,17 @@ TEST(PhysRewriter, IndexPartitioning1) { ASSERT_BSON_PATH("\"GroupBy\"", result, "child.child.nodeType"); ASSERT_BSON_PATH("\"HashJoin\"", result, "child.child.child.nodeType"); ASSERT_BSON_PATH("\"Exchange\"", result, "child.child.child.leftChild.nodeType"); - ASSERT_BSON_PATH( - "{ type: \"HashPartitioning\", disableExchanges: false, projections: [ \"pa\" ] }", - result, - "child.child.child.leftChild.distribution"); + ASSERT_BSON_PATH("{ type: \"Replicated\", disableExchanges: false }", + result, + "child.child.child.leftChild.distribution"); ASSERT_BSON_PATH("\"IndexScan\"", result, "child.child.child.leftChild.child.nodeType"); + ASSERT_BSON_PATH("\"index2\"", result, "child.child.child.leftChild.child.indexDefName"); ASSERT_BSON_PATH("\"Union\"", result, "child.child.child.rightChild.nodeType"); ASSERT_BSON_PATH("\"Evaluation\"", result, "child.child.child.rightChild.children.0.nodeType"); ASSERT_BSON_PATH( "\"IndexScan\"", result, "child.child.child.rightChild.children.0.child.nodeType"); + ASSERT_BSON_PATH( + "\"index1\"", result, "child.child.child.rightChild.children.0.child.indexDefName"); } TEST(PhysRewriter, LocalGlobalAgg) { diff --git a/src/mongo/db/query/optimizer/reference_tracker.cpp b/src/mongo/db/query/optimizer/reference_tracker.cpp index c4b70b51a22..a21aa06b25d 100644 --- a/src/mongo/db/query/optimizer/reference_tracker.cpp +++ b/src/mongo/db/query/optimizer/reference_tracker.cpp @@ -69,18 +69,22 @@ struct CollectedInfo { /** * This is a destructive merge, the 'other' will be siphoned out. */ + template <bool resolveFreeVarsWithOther = true> void merge(CollectedInfo&& other) { - // Incoming (other) info has some definitions. So let's try to resolve our free variables. - if (!other.defs.empty() && !freeVars.empty()) { - for (auto&& [name, def] : other.defs) { - resolveFreeVars(name, def); + if constexpr (resolveFreeVarsWithOther) { + // Incoming (other) info has some definitions. So let's try to resolve our free + // variables. + if (!other.defs.empty() && !freeVars.empty()) { + for (auto&& [name, def] : other.defs) { + resolveFreeVars(name, def); + } } - } - // We have some definitions so let try to resolve other's free variables. - if (!defs.empty() && !other.freeVars.empty()) { - for (auto&& [name, def] : defs) { - other.resolveFreeVars(name, def); + // We have some definitions so let try to resolve other's free variables. + if (!defs.empty() && !other.freeVars.empty()) { + for (auto&& [name, def] : defs) { + other.resolveFreeVars(name, def); + } } } @@ -273,6 +277,7 @@ struct Collector { CollectedInfo transport(const ABT&, const T& op, Ts&&... ts) { // The default behavior resolves free variables, merges known definitions and propagates // them up unmodified. + // TODO: SERVER-70880: Remove default ABT type handler in the reference tracker. CollectedInfo result{}; (result.merge(std::forward<Ts>(ts)), ...); @@ -493,10 +498,22 @@ struct Collector { } } - // The correlated projections will be resolved automatically by the merging. We need to - // propagate the right child projections here, since these may be useful to ancestor ndoes. result.merge(std::move(leftChildResult)); - result.merge(std::move(rightChildResult)); + + if (!result.defs.empty() && !rightChildResult.freeVars.empty()) { + // Manually resolve free variables in the right child using the correlated variables + // from the left child. + for (auto&& [name, def] : result.defs) { + if (correlatedProjNames.count(name) > 0) { + rightChildResult.resolveFreeVars(name, def); + } + } + } + + // Do not resolve further free variables. We also need to propagate the right child + // projections here, since these may be useful to ancestor nodes. + result.merge<false /*resolveFreeVarsWithOther*/>(std::move(rightChildResult)); + result.mergeNoDefs(std::move(filterResult)); result.nodeDefs[&binaryJoinNode] = result.defs; @@ -505,6 +522,40 @@ struct Collector { } CollectedInfo transport(const ABT& n, + const HashJoinNode& hashJoinNode, + CollectedInfo leftChildResult, + CollectedInfo rightChildResult, + CollectedInfo refsResult) { + CollectedInfo result{}; + + result.merge(std::move(leftChildResult)); + // Do not resolve further free variables. + result.merge<false /*resolveFreeVarsWithOther*/>(std::move(rightChildResult)); + result.mergeNoDefs(std::move(refsResult)); + + result.nodeDefs[&hashJoinNode] = result.defs; + + return result; + } + + CollectedInfo transport(const ABT& n, + const MergeJoinNode& mergeJoinNode, + CollectedInfo leftChildResult, + CollectedInfo rightChildResult, + CollectedInfo refsResult) { + CollectedInfo result{}; + + result.merge(std::move(leftChildResult)); + // Do not resolve further free variables. + result.merge<false /*resolveFreeVarsWithOther*/>(std::move(rightChildResult)); + result.mergeNoDefs(std::move(refsResult)); + + result.nodeDefs[&mergeJoinNode] = result.defs; + + return result; + } + + CollectedInfo transport(const ABT& n, const UnionNode& unionNode, std::vector<CollectedInfo> childResults, CollectedInfo bindResult, diff --git a/src/mongo/db/query/optimizer/reference_tracker_test.cpp b/src/mongo/db/query/optimizer/reference_tracker_test.cpp index 98c34c50bc5..f3efa486be4 100644 --- a/src/mongo/db/query/optimizer/reference_tracker_test.cpp +++ b/src/mongo/db/query/optimizer/reference_tracker_test.cpp @@ -28,11 +28,11 @@ */ #include "mongo/db/query/optimizer/reference_tracker.h" -#include "mongo/db/query/optimizer/utils/unit_test_utils.h" #include "mongo/unittest/death_test.h" #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" + namespace mongo::optimizer { namespace { @@ -360,7 +360,13 @@ TEST(ReferenceTrackerTest, FreeVariablesBinaryJoin) { "evalProjB", make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("scanProj1")), std::move(scanNodeLeft))); + ABT evalNodeLeft1 = make<EvaluationNode>( + "evalProjA1", + make<EvalPath>(make<PathGet>("a1", make<PathIdentity>()), make<Variable>("scanProj1")), + std::move(evalNodeLeft)); + // "evalProjA" needs to come from the left child and IS set to be correlated in the binary join + // below. ABT evalNodeRight = make<EvaluationNode>( "evalProjC", make<EvalPath>(make<PathGet>("c", make<PathIdentity>()), make<Variable>("scanProj2")), @@ -368,24 +374,97 @@ TEST(ReferenceTrackerTest, FreeVariablesBinaryJoin) { make<EvalPath>(make<PathIdentity>(), make<Variable>("evalProjA")), std::move(scanNodeRight))); + // "evalProjA1" needs to come from the left child and IS NOT set to be correlated in the binary + // join below. + ABT evalNodeRight1 = make<EvaluationNode>( + "evalProjC1", + make<EvalPath>(make<PathGet>("c1", make<PathIdentity>()), make<Variable>("evalProjA1")), + std::move(evalNodeRight)); + + ABT joinNode = make<BinaryJoinNode>( JoinType::Inner, - ProjectionNameSet{}, + ProjectionNameSet{"evalProjA"}, make<BinaryOp>(Operations::Eq, make<Variable>("evalProjA"), make<Variable>("evalProjC")), - std::move(evalNodeLeft), - std::move(evalNodeRight)); + std::move(evalNodeLeft1), + std::move(evalNodeRight1)); - // Check that the binary join resolves the free variables in the filter and the right child. + // Check that the binary join resolves "evalProjA" but not "evalProjA1" in the right child and + // the filter. auto env = VariableEnvironment::build(joinNode); - ASSERT(!env.hasFreeVariables()); + ASSERT_EQ(env.freeOccurences("evalProjA1"), 1); // Check that the binary join node propagates up left and right projections. auto binaryProjs = env.getProjections(joinNode.ref()); - ProjectionNameSet expectedBinaryProjSet = { - "evalProjA", "evalProjB", "scanProj1", "evalProjC", "evalProjD", "scanProj2"}; + ProjectionNameSet expectedBinaryProjSet{"evalProjA", + "evalProjA1", + "evalProjB", + "scanProj1", + "evalProjC", + "evalProjC1", + "evalProjD", + "scanProj2"}; ASSERT(expectedBinaryProjSet == binaryProjs); } +TEST(ReferenceTrackerTest, HashJoin) { + ABT scanNodeLeft = make<ScanNode>("scanProj1", "coll"); + ABT scanNodeRight = make<ScanNode>("scanProj2", "coll"); + + ABT evalNodeLeft = make<EvaluationNode>( + "evalProjA", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("scanProj1")), + std::move(scanNodeLeft)); + + ABT evalNodeRight = make<EvaluationNode>( + "evalProjB", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("scanProj1")), + std::move(scanNodeRight)); + + ABT joinNode = make<HashJoinNode>(JoinType::Inner, + ProjectionNameVector{"evalProjA"}, + ProjectionNameVector{"evalProjB"}, + std::move(evalNodeLeft), + std::move(evalNodeRight)); + + auto env = VariableEnvironment::build(joinNode); + ASSERT_EQ(env.freeOccurences("scanProj1"), 1); + + // Check that we propagate left and right projections. + auto joinProjs = env.getProjections(joinNode.ref()); + ProjectionNameSet expectedProjSet{"evalProjA", "evalProjB", "scanProj1", "scanProj2"}; + ASSERT(expectedProjSet == joinProjs); +} + +TEST(ReferenceTrackerTest, MergeJoin) { + ABT scanNodeLeft = make<ScanNode>("scanProj1", "coll"); + ABT scanNodeRight = make<ScanNode>("scanProj2", "coll"); + + ABT evalNodeLeft = make<EvaluationNode>( + "evalProjA", + make<EvalPath>(make<PathGet>("a", make<PathIdentity>()), make<Variable>("scanProj1")), + std::move(scanNodeLeft)); + + ABT evalNodeRight = make<EvaluationNode>( + "evalProjB", + make<EvalPath>(make<PathGet>("b", make<PathIdentity>()), make<Variable>("scanProj1")), + std::move(scanNodeRight)); + + ABT joinNode = make<MergeJoinNode>(ProjectionNameVector{"evalProjA"}, + ProjectionNameVector{"evalProjB"}, + std::vector<CollationOp>{CollationOp::Ascending}, + std::move(evalNodeLeft), + std::move(evalNodeRight)); + + auto env = VariableEnvironment::build(joinNode); + ASSERT_EQ(env.freeOccurences("scanProj1"), 1); + + // Check that we propagate left and right projections. + auto joinProjs = env.getProjections(joinNode.ref()); + ProjectionNameSet expectedProjSet{"evalProjA", "evalProjB", "scanProj1", "scanProj2"}; + ASSERT(expectedProjSet == joinProjs); +} + TEST(ReferenceTrackerTest, SingleVarNotLastRef) { // There are no last refs in an ABT that doesn't "finalize" any last refs. ABT justVar = make<Variable>("var"); |