summaryrefslogtreecommitdiff
path: root/src/mongo/db
diff options
context:
space:
mode:
authorJustin Seyster <justin.seyster@mongodb.com>2020-08-25 13:33:40 -0400
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2020-08-26 20:08:43 +0000
commitb24d4b2a96ef13bc233fd9cb9ecbefb905ee0ca8 (patch)
tree120734941503340d0e8d35708f8a698f5636ad3c /src/mongo/db
parent019fa81486e81bb6d0de3fbedd480e40b6929e23 (diff)
downloadmongo-b24d4b2a96ef13bc233fd9cb9ecbefb905ee0ca8.tar.gz
SERVER-49342 Support for $switch and $cond in SBE
Co-authored-by: Hirday Gupta <hirday.gupta@mongodb.com>
Diffstat (limited to 'src/mongo/db')
-rw-r--r--src/mongo/db/exec/sbe/stages/branch.h2
-rw-r--r--src/mongo/db/query/sbe_stage_builder_expression.cpp325
2 files changed, 253 insertions, 74 deletions
diff --git a/src/mongo/db/exec/sbe/stages/branch.h b/src/mongo/db/exec/sbe/stages/branch.h
index 9d9005e6d9f..544906c6af0 100644
--- a/src/mongo/db/exec/sbe/stages/branch.h
+++ b/src/mongo/db/exec/sbe/stages/branch.h
@@ -36,7 +36,7 @@
namespace mongo::sbe {
/**
* This stage delivers results from either 'then' or 'else' branch depending on the value of the
- * 'filer' expression as evaluated during the open() call.
+ * 'filter' expression as evaluated during the open() call.
*/
class BranchStage final : public PlanStage {
public:
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 9639c8bacd4..359d06c7f74 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -32,6 +32,7 @@
#include "mongo/db/query/sbe_stage_builder_expression.h"
#include "mongo/db/query/util/make_data_structure.h"
+#include "mongo/db/exec/sbe/stages/branch.h"
#include "mongo/db/exec/sbe/stages/co_scan.h"
#include "mongo/db/exec/sbe/stages/filter.h"
#include "mongo/db/exec/sbe/stages/limit_skip.h"
@@ -64,11 +65,14 @@ std::pair<sbe::value::TypeTags, sbe::value::Value> convertFrom(Value val) {
struct ExpressionVisitorContext {
struct VarsFrame {
std::deque<Variables::Id> variablesToBind;
- sbe::value::SlotMap<std::unique_ptr<sbe::EExpression>> boundVariables;
+
+ // Slots that have been used to bind $let variables. This list is necessary to know which
+ // slots to remove from the environment when the $let goes out of scope.
+ std::set<sbe::value::SlotId> slotsForLetVariables;
template <class... Args>
VarsFrame(Args&&... args)
- : variablesToBind{std::forward<Args>(args)...}, boundVariables{} {}
+ : variablesToBind{std::forward<Args>(args)...}, slotsForLetVariables{} {}
};
struct LogicalExpressionEvalFrame {
@@ -79,6 +83,15 @@ struct ExpressionVisitorContext {
std::vector<std::pair<sbe::value::SlotId, std::unique_ptr<sbe::PlanStage>>> branches;
+ // When traversing the branches of a $switch expression, the in-visitor will see each branch
+ // of the $switch _twice_: once for the "case" part of the branch (the condition) and once
+ // for the "then" part (the expression that the $switch will evaluate to if the condition
+ // evaluates to true). During the first visit, we temporarily store the condition here so
+ // that it is available to use during the second visit, which constructs the completed
+ // EExpression for the branch and stores it in the 'branches' vector.
+ boost::optional<std::pair<sbe::value::SlotId, std::unique_ptr<sbe::PlanStage>>>
+ switchBranchConditionalStage;
+
LogicalExpressionEvalFrame(std::unique_ptr<sbe::PlanStage> traverseStage,
const sbe::value::SlotVector& relevantSlots,
sbe::value::SlotId nextBranchResultSlot)
@@ -104,6 +117,69 @@ struct ExpressionVisitorContext {
invariant(exprs.size() >= arity);
}
+ /**
+ * Construct a UnionStage from the PlanStages in the 'branches' list and attach it to the inner
+ * side of a LoopJoinStage, which iterates over each branch of the UnionStage until it finds one
+ * that returns a result. Iteration ceases after the first branch that returns a result so that
+ * the remaining branches are "short circuited" and we don't do unnecessary work for for MQL
+ * expressions that are not evaluated.
+ */
+ void generateSubTreeForSelectiveExecution() {
+ auto& logicalExpressionEvalFrame = logicalExpressionEvalFrameStack.top();
+
+ std::vector<sbe::value::SlotVector> branchSlots;
+ std::vector<std::unique_ptr<sbe::PlanStage>> branchStages;
+ for (auto&& [slot, stage] : logicalExpressionEvalFrame.branches) {
+ branchSlots.push_back(sbe::makeSV(slot));
+ branchStages.push_back(std::move(stage));
+ }
+
+ auto unionStageResultSlot = slotIdGenerator->generate();
+ auto unionOfBranches = sbe::makeS<sbe::UnionStage>(
+ std::move(branchStages), std::move(branchSlots), sbe::makeSV(unionStageResultSlot));
+
+ // Restore 'relevantSlots' to the way it was before we started translating the logic
+ // operator.
+ *relevantSlots = std::move(logicalExpressionEvalFrame.savedRelevantSlots);
+
+ // Get a list of slots that are used by $let expressions. These slots need to be available
+ // to the inner side of the LoopJoinStage, in case any of the branches want to reference one
+ // of the variables bound by the $let.
+ sbe::value::SlotVector letBindings;
+ for (auto&& [_, slot] : environment) {
+ letBindings.push_back(slot);
+ }
+
+ // The LoopJoinStage we are creating here will not expose any of the slots from its outer
+ // side except for the ones we explicity ask for. For that reason, we maintain the
+ // 'relevantSlots' list of slots that may still be referenced above this stage. All of the
+ // slots in 'letBindings' are relevant by this definition, but we track them separately,
+ // which is why we need to add them in now.
+ auto relevantSlotsWithLetBindings(*relevantSlots);
+ relevantSlotsWithLetBindings.insert(
+ relevantSlotsWithLetBindings.end(), letBindings.begin(), letBindings.end());
+
+ // Put the union into a nested loop. The inner side of the nested loop will execute exactly
+ // once, trying each branch of the union until one of them short circuits or until it
+ // reaches the end. This process also restores the old 'traverseStage' value from before we
+ // started translating the logic operator, by placing it below the new nested loop stage.
+ auto stage = sbe::makeS<sbe::LoopJoinStage>(
+ std::move(logicalExpressionEvalFrame.savedTraverseStage),
+ sbe::makeS<sbe::LimitSkipStage>(std::move(unionOfBranches), 1, boost::none),
+ std::move(relevantSlotsWithLetBindings),
+ std::move(letBindings),
+ nullptr /* predicate */);
+
+ // We've already restored all necessary state from the top 'logicalExpressionEvalFrameStack'
+ // entry, so we are done with it.
+ logicalExpressionEvalFrameStack.pop();
+
+ // The final result of the logic operator is stored in the 'branchResultSlot' slot.
+ relevantSlots->push_back(unionStageResultSlot);
+ pushExpr(sbe::makeE<sbe::EVariable>(unionStageResultSlot), std::move(stage));
+ }
+
+
std::unique_ptr<sbe::EExpression> popExpr() {
auto expr = std::move(exprs.top());
exprs.pop();
@@ -154,6 +230,26 @@ struct ExpressionVisitorContext {
}
/**
+ * Temporarily reset 'traverseStage' and 'relevantSlots' so they are prepared for translating a
+ * $switch branch. They can be restored later using the 'logicalExpressionEvalFrameStack' top
+ * entry. Once it is fully constructed, this branch will evaluate to the "then" part of the
+ * branch if the condition is true or EOF otherwise. As with $and/$or branches (refer to the
+ * comment describing 'prepareToTranslateShortCircuitingBranch()'), these branches will become
+ * part of a UnionStage that executes the branches in turn until one yields a value.
+ */
+ void prepareToTranslateSwitchBranch(sbe::value::SlotId branchResultSlot) {
+ invariant(!logicalExpressionEvalFrameStack.empty());
+ logicalExpressionEvalFrameStack.top().nextBranchResultSlot = branchResultSlot;
+
+ traverseStage =
+ sbe::makeS<sbe::LimitSkipStage>(sbe::makeS<sbe::CoScanStage>(), 1, boost::none);
+
+ // Slots created in a previous branch for this $switch are not accessible to any stages in
+ // this new branch, so we clear them from the 'relevantSlots' list.
+ *relevantSlots = logicalExpressionEvalFrameStack.top().savedRelevantSlots;
+ }
+
+ /**
* This does the same thing as 'prepareToTranslateShortCircuitingBranch' but is intended to the
* last branch in an $and/$or, which cannot short circuit.
*/
@@ -177,8 +273,14 @@ struct ExpressionVisitorContext {
sbe::value::FrameIdGenerator* frameIdGenerator;
sbe::value::SlotId rootSlot;
std::stack<std::unique_ptr<sbe::EExpression>> exprs;
+
+ // The lexical environment for the expression being traversed. A variable reference takes the
+ // form "$$variable_name" in MQL's concrete syntax and gets transformed into a numeric
+ // identifier (Variables::Id) in the AST. During this translation, we directly translate any
+ // such variable to an SBE slot using this mapping.
std::map<Variables::Id, sbe::value::SlotId> environment;
std::stack<VarsFrame> varsFrameStack;
+
std::stack<LogicalExpressionEvalFrame> logicalExpressionEvalFrameStack;
// See the comment above the generateExpression() declaration for an explanation of the
// 'relevantSlots' list.
@@ -320,7 +422,9 @@ public:
void visit(ExpressionCompare* expr) final {}
void visit(ExpressionConcat* expr) final {}
void visit(ExpressionConcatArrays* expr) final {}
- void visit(ExpressionCond* expr) final {}
+ void visit(ExpressionCond* expr) final {
+ visitConditionalExpression(expr);
+ }
void visit(ExpressionDateFromString* expr) final {}
void visit(ExpressionDateFromParts* expr) final {}
void visit(ExpressionDateToParts* expr) final {}
@@ -376,7 +480,9 @@ public:
void visit(ExpressionBinarySize* expr) final {}
void visit(ExpressionStrLenCP* expr) final {}
void visit(ExpressionSubtract* expr) final {}
- void visit(ExpressionSwitch* expr) final {}
+ void visit(ExpressionSwitch* expr) final {
+ visitConditionalExpression(expr);
+ }
void visit(ExpressionToLower* expr) final {}
void visit(ExpressionToUpper* expr) final {}
void visit(ExpressionTrim* expr) final {}
@@ -448,6 +554,18 @@ private:
_context->prepareToTranslateShortCircuitingBranch(logicOp, branchResultSlot);
}
+ /**
+ * Handle $switch and $cond, which have different syntax but are structurally identical in the
+ * AST.
+ */
+ void visitConditionalExpression(Expression* expr) {
+ auto branchResultSlot = _context->slotIdGenerator->generate();
+ _context->logicalExpressionEvalFrameStack.emplace(
+ std::move(_context->traverseStage), *_context->relevantSlots, branchResultSlot);
+
+ _context->prepareToTranslateSwitchBranch(branchResultSlot);
+ }
+
ExpressionVisitorContext* _context;
};
@@ -475,7 +593,9 @@ public:
void visit(ExpressionCompare* expr) final {}
void visit(ExpressionConcat* expr) final {}
void visit(ExpressionConcatArrays* expr) final {}
- void visit(ExpressionCond* expr) final {}
+ void visit(ExpressionCond* expr) final {
+ visitConditionalExpression(expr);
+ }
void visit(ExpressionDateFromString* expr) final {}
void visit(ExpressionDateFromParts* expr) final {}
void visit(ExpressionDateToParts* expr) final {}
@@ -507,24 +627,15 @@ public:
// We create two bindings. First, the initializer result is bound to a slot when this
// ProjectStage executes.
auto slotToBind = _context->slotIdGenerator->generate();
- invariant(currentFrame.boundVariables.find(slotToBind) ==
- currentFrame.boundVariables.end());
- _context->ensureArity(1);
- currentFrame.boundVariables.insert(std::make_pair(slotToBind, _context->popExpr()));
+ _context->traverseStage = sbe::makeProjectStage(
+ std::move(_context->traverseStage), slotToBind, _context->popExpr());
+ currentFrame.slotsForLetVariables.insert(slotToBind);
// Second, we bind this variables AST-level name (with type Variable::Id) to the SlotId that
// will be used for compilation and execution. Once this "stage builder" finishes, these
// Variable::Id bindings will no longer be relevant.
invariant(_context->environment.find(varToBind) == _context->environment.end());
_context->environment.insert({varToBind, slotToBind});
-
- if (currentFrame.variablesToBind.empty()) {
- // We have traversed every variable initializer, and this is the last "infix" visit for
- // this $let. Add the the ProjectStage that will perform the actual binding now, so that
- // it executes before the "in" portion of the $let statement does.
- _context->traverseStage = sbe::makeS<sbe::ProjectStage>(
- std::move(_context->traverseStage), std::move(currentFrame.boundVariables));
- }
}
void visit(ExpressionLn* expr) final {}
void visit(ExpressionLog* expr) final {}
@@ -562,7 +673,9 @@ public:
void visit(ExpressionBinarySize* expr) final {}
void visit(ExpressionStrLenCP* expr) final {}
void visit(ExpressionSubtract* expr) final {}
- void visit(ExpressionSwitch* expr) final {}
+ void visit(ExpressionSwitch* expr) final {
+ visitConditionalExpression(expr);
+ }
void visit(ExpressionToLower* expr) final {}
void visit(ExpressionToUpper* expr) final {}
void visit(ExpressionTrim* expr) final {}
@@ -657,6 +770,90 @@ private:
}
}
+ /**
+ * Handle $switch and $cond, which have different syntax but are structurally identical in the
+ * AST.
+ */
+ void visitConditionalExpression(Expression* expr) {
+ invariant(_context->logicalExpressionEvalFrameStack.size() > 0);
+
+ auto& logicalExpressionEvalFrame = _context->logicalExpressionEvalFrameStack.top();
+ auto& switchBranchConditionalStage =
+ logicalExpressionEvalFrame.switchBranchConditionalStage;
+
+ if (switchBranchConditionalStage == boost::none) {
+ // Here, _context->popExpr() represents the $switch branch's "case" child.
+ auto frameId = _context->frameIdGenerator->generate();
+ auto branchExpr = generateExpressionForLogicBranch(sbe::EVariable{frameId, 0});
+ auto conditionExpr = sbe::makeE<sbe::ELocalBind>(
+ frameId, sbe::makeEs(_context->popExpr()), std::move(branchExpr));
+
+ auto conditionalEvalStage =
+ sbe::makeProjectStage(std::move(_context->traverseStage),
+ logicalExpressionEvalFrame.nextBranchResultSlot,
+ std::move(conditionExpr));
+
+ // Store this case eval stage for use when visiting the $switch branch's "then" child.
+ switchBranchConditionalStage.emplace(logicalExpressionEvalFrame.nextBranchResultSlot,
+ std::move(conditionalEvalStage));
+ } else {
+ // Here, _context->popExpr() represents the $switch branch's "then" child.
+
+ // Get the "case" child to form the outer part of the Loop Join.
+ auto [conditionalEvalStageSlot, conditionalEvalStage] =
+ std::move(*switchBranchConditionalStage);
+ switchBranchConditionalStage = boost::none;
+
+ // Create the "then" child (a BranchStage) to form the inner nlj stage.
+ auto branchStageResultSlot = logicalExpressionEvalFrame.nextBranchResultSlot;
+ auto thenStageResultSlot = _context->slotIdGenerator->generate();
+ auto unusedElseStageResultSlot = _context->slotIdGenerator->generate();
+
+ // Construct a BranchStage tree that will bind the evaluated "then" expression if the
+ // "case" expression evalautes to true and will EOF otherwise.
+ auto branchStage = sbe::makeS<sbe::BranchStage>(
+ sbe::makeProjectStage(
+ std::move(_context->traverseStage), thenStageResultSlot, _context->popExpr()),
+ sbe::makeS<sbe::LimitSkipStage>(
+ sbe::makeProjectStage(
+ sbe::makeS<sbe::CoScanStage>(),
+ unusedElseStageResultSlot,
+ sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Nothing, 0)),
+ 0,
+ boost::none),
+ sbe::makeE<sbe::EVariable>(conditionalEvalStageSlot),
+ sbe::makeSV(thenStageResultSlot),
+ sbe::makeSV(unusedElseStageResultSlot),
+ sbe::makeSV(branchStageResultSlot));
+
+ // Get a list of slots that are used by $let expressions. These slots need to be
+ // available to the inner side of the LoopJoinStage, in case any of the branches want to
+ // reference one of the variables bound by the $let.
+ sbe::value::SlotVector outerCorrelated;
+ for (auto&& [_, slot] : _context->environment) {
+ outerCorrelated.push_back(slot);
+ }
+
+ // The true/false result of the condition, which is evaluated in the outer side of the
+ // LoopJoinStage, must be available to the inner side.
+ outerCorrelated.push_back(conditionalEvalStageSlot);
+
+ // Create a LoopJoinStage that will evaluate its outer child exactly once, to compute
+ // the true/false result of the branch condition, and then execute its inner child
+ // with the result of that condition bound to a correlated slot.
+ auto loopJoinStage = sbe::makeS<sbe::LoopJoinStage>(std::move(conditionalEvalStage),
+ std::move(branchStage),
+ outerCorrelated,
+ outerCorrelated,
+ nullptr /* predicate */);
+
+ logicalExpressionEvalFrame.branches.emplace_back(std::make_pair(
+ logicalExpressionEvalFrame.nextBranchResultSlot, std::move(loopJoinStage)));
+ }
+
+ _context->prepareToTranslateSwitchBranch(_context->slotIdGenerator->generate());
+ }
+
ExpressionVisitorContext* _context;
};
@@ -796,7 +993,7 @@ public:
unsupportedExpression(expr->getOpName());
}
void visit(ExpressionCond* expr) final {
- unsupportedExpression(expr->getOpName());
+ visitConditionalExpression(expr);
}
void visit(ExpressionDateFromString* expr) final {
unsupportedExpression("$dateFromString");
@@ -1192,7 +1389,7 @@ public:
// scope.
auto it = _context->environment.begin();
while (it != _context->environment.end()) {
- if (currentFrame.boundVariables.count(it->first)) {
+ if (currentFrame.slotsForLetVariables.count(it->second)) {
it = _context->environment.erase(it);
} else {
++it;
@@ -1313,7 +1510,7 @@ public:
unsupportedExpression(expr->getOpName());
}
void visit(ExpressionSwitch* expr) final {
- unsupportedExpression("$switch");
+ visitConditionalExpression(expr);
}
void visit(ExpressionToLower* expr) final {
unsupportedExpression(expr->getOpName());
@@ -1506,7 +1703,7 @@ private:
return;
}
- auto& LogicalExpressionEvalFrame = _context->logicalExpressionEvalFrameStack.top();
+ auto& logicalExpressionEvalFrame = _context->logicalExpressionEvalFrameStack.top();
// The last branch works differently from the others. It just uses a project stage to
// produce a true or false value for the branch result.
@@ -1518,60 +1715,42 @@ private:
auto lastBranchResultSlot = _context->slotIdGenerator->generate();
auto lastBranch = sbe::makeProjectStage(
std::move(_context->traverseStage), lastBranchResultSlot, std::move(lastBranchExpr));
- LogicalExpressionEvalFrame.branches.emplace_back(
+ logicalExpressionEvalFrame.branches.emplace_back(
std::make_pair(lastBranchResultSlot, std::move(lastBranch)));
- std::vector<sbe::value::SlotVector> branchSlots;
- std::vector<std::unique_ptr<sbe::PlanStage>> branchStages;
- for (auto&& [slot, stage] : LogicalExpressionEvalFrame.branches) {
- branchSlots.push_back(sbe::makeSV(slot));
- branchStages.push_back(std::move(stage));
- }
-
- auto branchResultSlot = _context->slotIdGenerator->generate();
- auto unionOfBranches = sbe::makeS<sbe::UnionStage>(
- std::move(branchStages), std::move(branchSlots), sbe::makeSV(branchResultSlot));
-
- // Restore 'relevantSlots' to the way it was before we started translating the logic
- // operator.
- *_context->relevantSlots = std::move(LogicalExpressionEvalFrame.savedRelevantSlots);
-
- // Get a list of slots that are used by $let expressions. These slots need to be available
- // to the inner side of the LoopJoinStage, in case any of the branches want to reference one
- // of the variables bound by the $let.
- sbe::value::SlotVector letBindings;
- for (auto&& [_, slot] : _context->environment) {
- letBindings.push_back(slot);
- }
-
- // The LoopJoinStage we are creating here will not expose any of the slots from its outer
- // side except for the ones we explicity ask for. For that reason, we maintain the
- // 'relevantSlots' list of slots that may still be referenced above this stage. All of the
- // slots in 'letBindings' are relevant by this definition, but we track them separately,
- // which is why we need to add them in now.
- auto relevantSlotsWithLetBindings(*_context->relevantSlots);
- relevantSlotsWithLetBindings.insert(
- relevantSlotsWithLetBindings.end(), letBindings.begin(), letBindings.end());
-
- // Put the union into a nested loop. The inner side of the nested loop will execute exactly
- // once, trying each branch of the union until one of them short circuits or until it
- // reaches the end. This process also restores the old 'traverseStage' value from before we
- // started translating the logic operator, by placing it below the new nested loop stage.
- auto stage = sbe::makeS<sbe::LoopJoinStage>(
- std::move(LogicalExpressionEvalFrame.savedTraverseStage),
- sbe::makeS<sbe::LimitSkipStage>(std::move(unionOfBranches), 1, boost::none),
- std::move(relevantSlotsWithLetBindings),
- std::move(letBindings),
- nullptr /* predicate */);
-
- // We've already restored all necessary state from the top 'logicalExpressionEvalFrameStack'
- // entry, so we are done with it.
- _context->logicalExpressionEvalFrameStack.pop();
+ _context->generateSubTreeForSelectiveExecution();
+ }
- // The final true/false result of the logic operator is stored in the 'branchResultSlot'
- // slot.
- _context->relevantSlots->push_back(branchResultSlot);
- _context->pushExpr(sbe::makeE<sbe::EVariable>(branchResultSlot), std::move(stage));
+ /**
+ * Handle $switch and $cond, which have different syntax but are structurally identical in the
+ * AST.
+ */
+ void visitConditionalExpression(Expression* expr) {
+ invariant(_context->logicalExpressionEvalFrameStack.size() > 0);
+ auto& logicalExpressionEvalFrame = _context->logicalExpressionEvalFrameStack.top();
+
+ // If this is not boost::none, that would mean the AST somehow had a branch with a "case"
+ // condition but without a "then" value.
+ invariant(logicalExpressionEvalFrame.switchBranchConditionalStage == boost::none);
+
+ // The default case is always the last child in the ExpressionSwitch. If it is unspecified
+ // in the user's query, it is a nullptr. In ExpressionCond, the last child is the "else"
+ // branch, and it is guaranteed not to be nullptr.
+ auto defaultExpr = expr->getChildren().back() == nullptr
+ ? sbe::makeE<sbe::EFail>(ErrorCodes::Error{4934200},
+ "$switch could not find a matching branch for an "
+ "input, and no default was specified.")
+ : this->_context->popExpr();
+
+ auto defaultBranchStage =
+ sbe::makeProjectStage(std::move(_context->traverseStage),
+ logicalExpressionEvalFrame.nextBranchResultSlot,
+ std::move(defaultExpr));
+
+ logicalExpressionEvalFrame.branches.emplace_back(std::make_pair(
+ logicalExpressionEvalFrame.nextBranchResultSlot, std::move(defaultBranchStage)));
+
+ _context->generateSubTreeForSelectiveExecution();
}
void unsupportedExpression(const char* op) const {