diff options
author | Justin Seyster <justin.seyster@mongodb.com> | 2020-08-25 13:33:40 -0400 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2020-08-26 20:08:43 +0000 |
commit | b24d4b2a96ef13bc233fd9cb9ecbefb905ee0ca8 (patch) | |
tree | 120734941503340d0e8d35708f8a698f5636ad3c /src/mongo/db | |
parent | 019fa81486e81bb6d0de3fbedd480e40b6929e23 (diff) | |
download | mongo-b24d4b2a96ef13bc233fd9cb9ecbefb905ee0ca8.tar.gz |
SERVER-49342 Support for $switch and $cond in SBE
Co-authored-by: Hirday Gupta <hirday.gupta@mongodb.com>
Diffstat (limited to 'src/mongo/db')
-rw-r--r-- | src/mongo/db/exec/sbe/stages/branch.h | 2 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 325 |
2 files changed, 253 insertions, 74 deletions
diff --git a/src/mongo/db/exec/sbe/stages/branch.h b/src/mongo/db/exec/sbe/stages/branch.h index 9d9005e6d9f..544906c6af0 100644 --- a/src/mongo/db/exec/sbe/stages/branch.h +++ b/src/mongo/db/exec/sbe/stages/branch.h @@ -36,7 +36,7 @@ namespace mongo::sbe { /** * This stage delivers results from either 'then' or 'else' branch depending on the value of the - * 'filer' expression as evaluated during the open() call. + * 'filter' expression as evaluated during the open() call. */ class BranchStage final : public PlanStage { public: diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 9639c8bacd4..359d06c7f74 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -32,6 +32,7 @@ #include "mongo/db/query/sbe_stage_builder_expression.h" #include "mongo/db/query/util/make_data_structure.h" +#include "mongo/db/exec/sbe/stages/branch.h" #include "mongo/db/exec/sbe/stages/co_scan.h" #include "mongo/db/exec/sbe/stages/filter.h" #include "mongo/db/exec/sbe/stages/limit_skip.h" @@ -64,11 +65,14 @@ std::pair<sbe::value::TypeTags, sbe::value::Value> convertFrom(Value val) { struct ExpressionVisitorContext { struct VarsFrame { std::deque<Variables::Id> variablesToBind; - sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> boundVariables; + + // Slots that have been used to bind $let variables. This list is necessary to know which + // slots to remove from the environment when the $let goes out of scope. + std::set<sbe::value::SlotId> slotsForLetVariables; template <class... Args> VarsFrame(Args&&... 
args) - : variablesToBind{std::forward<Args>(args)...}, boundVariables{} {} + : variablesToBind{std::forward<Args>(args)...}, slotsForLetVariables{} {} }; struct LogicalExpressionEvalFrame { @@ -79,6 +83,15 @@ struct ExpressionVisitorContext { std::vector<std::pair<sbe::value::SlotId, std::unique_ptr<sbe::PlanStage>>> branches; + // When traversing the branches of a $switch expression, the in-visitor will see each branch + // of the $switch _twice_: once for the "case" part of the branch (the condition) and once + // for the "then" part (the expression that the $switch will evaluate to if the condition + // evaluates to true). During the first visit, we temporarily store the condition here so + // that it is available to use during the second visit, which constructs the completed + // EExpression for the branch and stores it in the 'branches' vector. + boost::optional<std::pair<sbe::value::SlotId, std::unique_ptr<sbe::PlanStage>>> + switchBranchConditionalStage; + LogicalExpressionEvalFrame(std::unique_ptr<sbe::PlanStage> traverseStage, const sbe::value::SlotVector& relevantSlots, sbe::value::SlotId nextBranchResultSlot) @@ -104,6 +117,69 @@ struct ExpressionVisitorContext { invariant(exprs.size() >= arity); } + /** + * Construct a UnionStage from the PlanStages in the 'branches' list and attach it to the inner + * side of a LoopJoinStage, which iterates over each branch of the UnionStage until it finds one + * that returns a result. Iteration ceases after the first branch that returns a result so that + * the remaining branches are "short circuited" and we don't do unnecessary work for MQL + * expressions that are not evaluated.
+ */ + void generateSubTreeForSelectiveExecution() { + auto& logicalExpressionEvalFrame = logicalExpressionEvalFrameStack.top(); + + std::vector<sbe::value::SlotVector> branchSlots; + std::vector<std::unique_ptr<sbe::PlanStage>> branchStages; + for (auto&& [slot, stage] : logicalExpressionEvalFrame.branches) { + branchSlots.push_back(sbe::makeSV(slot)); + branchStages.push_back(std::move(stage)); + } + + auto unionStageResultSlot = slotIdGenerator->generate(); + auto unionOfBranches = sbe::makeS<sbe::UnionStage>( + std::move(branchStages), std::move(branchSlots), sbe::makeSV(unionStageResultSlot)); + + // Restore 'relevantSlots' to the way it was before we started translating the logic + // operator. + *relevantSlots = std::move(logicalExpressionEvalFrame.savedRelevantSlots); + + // Get a list of slots that are used by $let expressions. These slots need to be available + // to the inner side of the LoopJoinStage, in case any of the branches want to reference one + // of the variables bound by the $let. + sbe::value::SlotVector letBindings; + for (auto&& [_, slot] : environment) { + letBindings.push_back(slot); + } + + // The LoopJoinStage we are creating here will not expose any of the slots from its outer + // side except for the ones we explicitly ask for. For that reason, we maintain the + // 'relevantSlots' list of slots that may still be referenced above this stage. All of the + // slots in 'letBindings' are relevant by this definition, but we track them separately, + // which is why we need to add them in now. + auto relevantSlotsWithLetBindings(*relevantSlots); + relevantSlotsWithLetBindings.insert( + relevantSlotsWithLetBindings.end(), letBindings.begin(), letBindings.end()); + + // Put the union into a nested loop. The inner side of the nested loop will execute exactly + // once, trying each branch of the union until one of them short circuits or until it + // reaches the end.
This process also restores the old 'traverseStage' value from before we + // started translating the logic operator, by placing it below the new nested loop stage. + auto stage = sbe::makeS<sbe::LoopJoinStage>( + std::move(logicalExpressionEvalFrame.savedTraverseStage), + sbe::makeS<sbe::LimitSkipStage>(std::move(unionOfBranches), 1, boost::none), + std::move(relevantSlotsWithLetBindings), + std::move(letBindings), + nullptr /* predicate */); + + // We've already restored all necessary state from the top 'logicalExpressionEvalFrameStack' + // entry, so we are done with it. + logicalExpressionEvalFrameStack.pop(); + + // The final result of the logic operator is stored in the 'unionStageResultSlot' slot. + relevantSlots->push_back(unionStageResultSlot); + pushExpr(sbe::makeE<sbe::EVariable>(unionStageResultSlot), std::move(stage)); + } + + std::unique_ptr<sbe::EExpression> popExpr() { auto expr = std::move(exprs.top()); exprs.pop(); @@ -154,6 +230,26 @@ struct ExpressionVisitorContext { } /** + * Temporarily reset 'traverseStage' and 'relevantSlots' so they are prepared for translating a + * $switch branch. They can be restored later using the 'logicalExpressionEvalFrameStack' top + * entry. Once it is fully constructed, this branch will evaluate to the "then" part of the + * branch if the condition is true or EOF otherwise. As with $and/$or branches (refer to the + * comment describing 'prepareToTranslateShortCircuitingBranch()'), these branches will become + * part of a UnionStage that executes the branches in turn until one yields a value.
+ */ + void prepareToTranslateSwitchBranch(sbe::value::SlotId branchResultSlot) { + invariant(!logicalExpressionEvalFrameStack.empty()); + logicalExpressionEvalFrameStack.top().nextBranchResultSlot = branchResultSlot; + + traverseStage = + sbe::makeS<sbe::LimitSkipStage>(sbe::makeS<sbe::CoScanStage>(), 1, boost::none); + + // Slots created in a previous branch for this $switch are not accessible to any stages in + // this new branch, so we clear them from the 'relevantSlots' list. + *relevantSlots = logicalExpressionEvalFrameStack.top().savedRelevantSlots; + } + + /** * This does the same thing as 'prepareToTranslateShortCircuitingBranch' but is intended to the * last branch in an $and/$or, which cannot short circuit. */ @@ -177,8 +273,14 @@ struct ExpressionVisitorContext { sbe::value::FrameIdGenerator* frameIdGenerator; sbe::value::SlotId rootSlot; std::stack<std::unique_ptr<sbe::EExpression>> exprs; + + // The lexical environment for the expression being traversed. A variable reference takes the + // form "$$variable_name" in MQL's concrete syntax and gets transformed into a numeric + // identifier (Variables::Id) in the AST. During this translation, we directly translate any + // such variable to an SBE slot using this mapping. std::map<Variables::Id, sbe::value::SlotId> environment; std::stack<VarsFrame> varsFrameStack; + std::stack<LogicalExpressionEvalFrame> logicalExpressionEvalFrameStack; // See the comment above the generateExpression() declaration for an explanation of the // 'relevantSlots' list. 
@@ -320,7 +422,9 @@ public: void visit(ExpressionCompare* expr) final {} void visit(ExpressionConcat* expr) final {} void visit(ExpressionConcatArrays* expr) final {} - void visit(ExpressionCond* expr) final {} + void visit(ExpressionCond* expr) final { + visitConditionalExpression(expr); + } void visit(ExpressionDateFromString* expr) final {} void visit(ExpressionDateFromParts* expr) final {} void visit(ExpressionDateToParts* expr) final {} @@ -376,7 +480,9 @@ public: void visit(ExpressionBinarySize* expr) final {} void visit(ExpressionStrLenCP* expr) final {} void visit(ExpressionSubtract* expr) final {} - void visit(ExpressionSwitch* expr) final {} + void visit(ExpressionSwitch* expr) final { + visitConditionalExpression(expr); + } void visit(ExpressionToLower* expr) final {} void visit(ExpressionToUpper* expr) final {} void visit(ExpressionTrim* expr) final {} @@ -448,6 +554,18 @@ private: _context->prepareToTranslateShortCircuitingBranch(logicOp, branchResultSlot); } + /** + * Handle $switch and $cond, which have different syntax but are structurally identical in the + * AST. + */ + void visitConditionalExpression(Expression* expr) { + auto branchResultSlot = _context->slotIdGenerator->generate(); + _context->logicalExpressionEvalFrameStack.emplace( + std::move(_context->traverseStage), *_context->relevantSlots, branchResultSlot); + + _context->prepareToTranslateSwitchBranch(branchResultSlot); + } + ExpressionVisitorContext* _context; }; @@ -475,7 +593,9 @@ public: void visit(ExpressionCompare* expr) final {} void visit(ExpressionConcat* expr) final {} void visit(ExpressionConcatArrays* expr) final {} - void visit(ExpressionCond* expr) final {} + void visit(ExpressionCond* expr) final { + visitConditionalExpression(expr); + } void visit(ExpressionDateFromString* expr) final {} void visit(ExpressionDateFromParts* expr) final {} void visit(ExpressionDateToParts* expr) final {} @@ -507,24 +627,15 @@ public: // We create two bindings. 
First, the initializer result is bound to a slot when this // ProjectStage executes. auto slotToBind = _context->slotIdGenerator->generate(); - invariant(currentFrame.boundVariables.find(slotToBind) == - currentFrame.boundVariables.end()); - _context->ensureArity(1); - currentFrame.boundVariables.insert(std::make_pair(slotToBind, _context->popExpr())); + _context->traverseStage = sbe::makeProjectStage( + std::move(_context->traverseStage), slotToBind, _context->popExpr()); + currentFrame.slotsForLetVariables.insert(slotToBind); // Second, we bind this variables AST-level name (with type Variable::Id) to the SlotId that // will be used for compilation and execution. Once this "stage builder" finishes, these // Variable::Id bindings will no longer be relevant. invariant(_context->environment.find(varToBind) == _context->environment.end()); _context->environment.insert({varToBind, slotToBind}); - - if (currentFrame.variablesToBind.empty()) { - // We have traversed every variable initializer, and this is the last "infix" visit for - // this $let. Add the the ProjectStage that will perform the actual binding now, so that - // it executes before the "in" portion of the $let statement does. 
- _context->traverseStage = sbe::makeS<sbe::ProjectStage>( - std::move(_context->traverseStage), std::move(currentFrame.boundVariables)); - } } void visit(ExpressionLn* expr) final {} void visit(ExpressionLog* expr) final {} @@ -562,7 +673,9 @@ public: void visit(ExpressionBinarySize* expr) final {} void visit(ExpressionStrLenCP* expr) final {} void visit(ExpressionSubtract* expr) final {} - void visit(ExpressionSwitch* expr) final {} + void visit(ExpressionSwitch* expr) final { + visitConditionalExpression(expr); + } void visit(ExpressionToLower* expr) final {} void visit(ExpressionToUpper* expr) final {} void visit(ExpressionTrim* expr) final {} @@ -657,6 +770,90 @@ private: } } + /** + * Handle $switch and $cond, which have different syntax but are structurally identical in the + * AST. + */ + void visitConditionalExpression(Expression* expr) { + invariant(_context->logicalExpressionEvalFrameStack.size() > 0); + + auto& logicalExpressionEvalFrame = _context->logicalExpressionEvalFrameStack.top(); + auto& switchBranchConditionalStage = + logicalExpressionEvalFrame.switchBranchConditionalStage; + + if (switchBranchConditionalStage == boost::none) { + // Here, _context->popExpr() represents the $switch branch's "case" child. + auto frameId = _context->frameIdGenerator->generate(); + auto branchExpr = generateExpressionForLogicBranch(sbe::EVariable{frameId, 0}); + auto conditionExpr = sbe::makeE<sbe::ELocalBind>( + frameId, sbe::makeEs(_context->popExpr()), std::move(branchExpr)); + + auto conditionalEvalStage = + sbe::makeProjectStage(std::move(_context->traverseStage), + logicalExpressionEvalFrame.nextBranchResultSlot, + std::move(conditionExpr)); + + // Store this case eval stage for use when visiting the $switch branch's "then" child. + switchBranchConditionalStage.emplace(logicalExpressionEvalFrame.nextBranchResultSlot, + std::move(conditionalEvalStage)); + } else { + // Here, _context->popExpr() represents the $switch branch's "then" child. 
+ + // Get the "case" child to form the outer part of the Loop Join. + auto [conditionalEvalStageSlot, conditionalEvalStage] = + std::move(*switchBranchConditionalStage); + switchBranchConditionalStage = boost::none; + + // Create the "then" child (a BranchStage) to form the inner nlj stage. + auto branchStageResultSlot = logicalExpressionEvalFrame.nextBranchResultSlot; + auto thenStageResultSlot = _context->slotIdGenerator->generate(); + auto unusedElseStageResultSlot = _context->slotIdGenerator->generate(); + + // Construct a BranchStage tree that will bind the evaluated "then" expression if the + // "case" expression evaluates to true and will EOF otherwise. + auto branchStage = sbe::makeS<sbe::BranchStage>( + sbe::makeProjectStage( + std::move(_context->traverseStage), thenStageResultSlot, _context->popExpr()), + sbe::makeS<sbe::LimitSkipStage>( + sbe::makeProjectStage( + sbe::makeS<sbe::CoScanStage>(), + unusedElseStageResultSlot, + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Nothing, 0)), + 0, + boost::none), + sbe::makeE<sbe::EVariable>(conditionalEvalStageSlot), + sbe::makeSV(thenStageResultSlot), + sbe::makeSV(unusedElseStageResultSlot), + sbe::makeSV(branchStageResultSlot)); + + // Get a list of slots that are used by $let expressions. These slots need to be + // available to the inner side of the LoopJoinStage, in case any of the branches want to + // reference one of the variables bound by the $let. + sbe::value::SlotVector outerCorrelated; + for (auto&& [_, slot] : _context->environment) { + outerCorrelated.push_back(slot); + } + + // The true/false result of the condition, which is evaluated in the outer side of the + // LoopJoinStage, must be available to the inner side.
+ outerCorrelated.push_back(conditionalEvalStageSlot); + + // Create a LoopJoinStage that will evaluate its outer child exactly once, to compute + // the true/false result of the branch condition, and then execute its inner child + // with the result of that condition bound to a correlated slot. + auto loopJoinStage = sbe::makeS<sbe::LoopJoinStage>(std::move(conditionalEvalStage), + std::move(branchStage), + outerCorrelated, + outerCorrelated, + nullptr /* predicate */); + + logicalExpressionEvalFrame.branches.emplace_back(std::make_pair( + logicalExpressionEvalFrame.nextBranchResultSlot, std::move(loopJoinStage))); + } + + _context->prepareToTranslateSwitchBranch(_context->slotIdGenerator->generate()); + } + ExpressionVisitorContext* _context; }; @@ -796,7 +993,7 @@ public: unsupportedExpression(expr->getOpName()); } void visit(ExpressionCond* expr) final { - unsupportedExpression(expr->getOpName()); + visitConditionalExpression(expr); } void visit(ExpressionDateFromString* expr) final { unsupportedExpression("$dateFromString"); @@ -1192,7 +1389,7 @@ public: // scope. auto it = _context->environment.begin(); while (it != _context->environment.end()) { - if (currentFrame.boundVariables.count(it->first)) { + if (currentFrame.slotsForLetVariables.count(it->second)) { it = _context->environment.erase(it); } else { ++it; @@ -1313,7 +1510,7 @@ public: unsupportedExpression(expr->getOpName()); } void visit(ExpressionSwitch* expr) final { - unsupportedExpression("$switch"); + visitConditionalExpression(expr); } void visit(ExpressionToLower* expr) final { unsupportedExpression(expr->getOpName()); @@ -1506,7 +1703,7 @@ private: return; } - auto& LogicalExpressionEvalFrame = _context->logicalExpressionEvalFrameStack.top(); + auto& logicalExpressionEvalFrame = _context->logicalExpressionEvalFrameStack.top(); // The last branch works differently from the others. It just uses a project stage to // produce a true or false value for the branch result. 
@@ -1518,60 +1715,42 @@ private: auto lastBranchResultSlot = _context->slotIdGenerator->generate(); auto lastBranch = sbe::makeProjectStage( std::move(_context->traverseStage), lastBranchResultSlot, std::move(lastBranchExpr)); - LogicalExpressionEvalFrame.branches.emplace_back( + logicalExpressionEvalFrame.branches.emplace_back( std::make_pair(lastBranchResultSlot, std::move(lastBranch))); - std::vector<sbe::value::SlotVector> branchSlots; - std::vector<std::unique_ptr<sbe::PlanStage>> branchStages; - for (auto&& [slot, stage] : LogicalExpressionEvalFrame.branches) { - branchSlots.push_back(sbe::makeSV(slot)); - branchStages.push_back(std::move(stage)); - } - - auto branchResultSlot = _context->slotIdGenerator->generate(); - auto unionOfBranches = sbe::makeS<sbe::UnionStage>( - std::move(branchStages), std::move(branchSlots), sbe::makeSV(branchResultSlot)); - - // Restore 'relevantSlots' to the way it was before we started translating the logic - // operator. - *_context->relevantSlots = std::move(LogicalExpressionEvalFrame.savedRelevantSlots); - - // Get a list of slots that are used by $let expressions. These slots need to be available - // to the inner side of the LoopJoinStage, in case any of the branches want to reference one - // of the variables bound by the $let. - sbe::value::SlotVector letBindings; - for (auto&& [_, slot] : _context->environment) { - letBindings.push_back(slot); - } - - // The LoopJoinStage we are creating here will not expose any of the slots from its outer - // side except for the ones we explicity ask for. For that reason, we maintain the - // 'relevantSlots' list of slots that may still be referenced above this stage. All of the - // slots in 'letBindings' are relevant by this definition, but we track them separately, - // which is why we need to add them in now. 
- auto relevantSlotsWithLetBindings(*_context->relevantSlots); - relevantSlotsWithLetBindings.insert( - relevantSlotsWithLetBindings.end(), letBindings.begin(), letBindings.end()); - - // Put the union into a nested loop. The inner side of the nested loop will execute exactly - // once, trying each branch of the union until one of them short circuits or until it - // reaches the end. This process also restores the old 'traverseStage' value from before we - // started translating the logic operator, by placing it below the new nested loop stage. - auto stage = sbe::makeS<sbe::LoopJoinStage>( - std::move(LogicalExpressionEvalFrame.savedTraverseStage), - sbe::makeS<sbe::LimitSkipStage>(std::move(unionOfBranches), 1, boost::none), - std::move(relevantSlotsWithLetBindings), - std::move(letBindings), - nullptr /* predicate */); - - // We've already restored all necessary state from the top 'logicalExpressionEvalFrameStack' - // entry, so we are done with it. - _context->logicalExpressionEvalFrameStack.pop(); + _context->generateSubTreeForSelectiveExecution(); + } - // The final true/false result of the logic operator is stored in the 'branchResultSlot' - // slot. - _context->relevantSlots->push_back(branchResultSlot); - _context->pushExpr(sbe::makeE<sbe::EVariable>(branchResultSlot), std::move(stage)); + /** + * Handle $switch and $cond, which have different syntax but are structurally identical in the + * AST. + */ + void visitConditionalExpression(Expression* expr) { + invariant(_context->logicalExpressionEvalFrameStack.size() > 0); + auto& logicalExpressionEvalFrame = _context->logicalExpressionEvalFrameStack.top(); + + // If this is not boost::none, that would mean the AST somehow had a branch with a "case" + // condition but without a "then" value. + invariant(logicalExpressionEvalFrame.switchBranchConditionalStage == boost::none); + + // The default case is always the last child in the ExpressionSwitch. 
If it is unspecified + // in the user's query, it is a nullptr. In ExpressionCond, the last child is the "else" + // branch, and it is guaranteed not to be nullptr. + auto defaultExpr = expr->getChildren().back() == nullptr + ? sbe::makeE<sbe::EFail>(ErrorCodes::Error{4934200}, + "$switch could not find a matching branch for an " + "input, and no default was specified.") + : this->_context->popExpr(); + + auto defaultBranchStage = + sbe::makeProjectStage(std::move(_context->traverseStage), + logicalExpressionEvalFrame.nextBranchResultSlot, + std::move(defaultExpr)); + + logicalExpressionEvalFrame.branches.emplace_back(std::make_pair( + logicalExpressionEvalFrame.nextBranchResultSlot, std::move(defaultBranchStage))); + + _context->generateSubTreeForSelectiveExecution(); } void unsupportedExpression(const char* op) const { |