summaryrefslogtreecommitdiff
path: root/src/mongo/db/query
diff options
context:
space:
mode:
authorHana Pearlman <hana.pearlman@mongodb.com>2020-11-23 20:23:15 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2020-12-04 18:02:19 +0000
commit8284eb2c2eedf795293b5ebc349e468e8666073c (patch)
tree9ed9a73e39f125deb1b1ffd7e0ce7b154c64e192 /src/mongo/db/query
parent46befe17d5d41b52808af01ef680a3405e113792 (diff)
downloadmongo-8284eb2c2eedf795293b5ebc349e468e8666073c.tar.gz
SERVER-12573 Extend $ifNull to accept more than two arguments
Diffstat (limited to 'src/mongo/db/query')
-rw-r--r--src/mongo/db/query/sbe_stage_builder_expression.cpp55
1 files changed, 42 insertions, 13 deletions
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 938982b2fd2..ddecb47bb38 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -376,7 +376,9 @@ public:
void visit(ExpressionFieldPath* expr) final {}
void visit(ExpressionFilter* expr) final {}
void visit(ExpressionFloor* expr) final {}
- void visit(ExpressionIfNull* expr) final {}
+ void visit(ExpressionIfNull* expr) final {
+ _context->evalStack.emplaceFrame(EvalStage{});
+ }
void visit(ExpressionIn* expr) final {}
void visit(ExpressionIndexOfArray* expr) final {}
void visit(ExpressionIndexOfBytes* expr) final {}
@@ -545,7 +547,9 @@ public:
_context->evalStack.emplaceFrame(EvalStage{});
}
void visit(ExpressionFloor* expr) final {}
- void visit(ExpressionIfNull* expr) final {}
+ void visit(ExpressionIfNull* expr) final {
+ _context->evalStack.emplaceFrame(EvalStage{});
+ }
void visit(ExpressionIn* expr) final {}
void visit(ExpressionIndexOfArray* expr) final {}
void visit(ExpressionIndexOfBytes* expr) final {}
@@ -1813,21 +1817,46 @@ public:
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(floorExpr)));
}
void visit(ExpressionIfNull* expr) final {
- _context->ensureArity(2);
+ auto numChildren = expr->getChildren().size();
+ invariant(numChildren >= 2);
- auto replacementIfNull = _context->popExpr();
- auto input = _context->popExpr();
+ std::vector<EvalExprStagePair> branches;
+ branches.reserve(numChildren);
+ for (size_t i = 0; i < numChildren; ++i) {
+ auto [expr, stage] = _context->popFrame();
+ branches.emplace_back(std::move(expr), std::move(stage));
+ }
+ std::reverse(branches.begin(), branches.end());
- auto frameId = _context->frameIdGenerator->generate();
- auto binds = sbe::makeEs(std::move(input));
- sbe::EVariable inputRef(frameId, 0);
+ // Prepare to create limit-1/union with N branches (where N is the number of operands). Each
+ // branch will be evaluated from left to right until one of the branches produces a value.
+ auto branchFn = [](EvalExpr evalExpr,
+ EvalStage stage,
+ PlanNodeId planNodeId,
+ sbe::value::SlotIdGenerator* slotIdGenerator) {
+ auto slot = slotIdGenerator->generate();
+ stage = makeProject(std::move(stage), planNodeId, slot, evalExpr.extractExpr());
+
+ // Create a FilterStage for each branch (except the last one). If a branch's filter
+ // condition is true, it will "short-circuit" the evaluation process. For ifNull,
+ // short-circuiting should happen if the current variable is not null or missing.
+ auto filterExpr = makeNot(generateNullOrMissing(slot));
+ auto filterStage =
+ makeFilter<false>(std::move(stage), std::move(filterExpr), planNodeId);
+
+ // Set the current expression as the output to be returned if short-circuiting occurs.
+ return std::make_pair(slot, std::move(filterStage));
+ };
- // If input is null or missing, return replacement expression. Otherwise, return input.
- auto ifNullExpr = sbe::makeE<sbe::EIf>(
- generateNullOrMissing(frameId, 0), std::move(replacementIfNull), inputRef.clone());
+ auto [resultExpr, opStage] = generateSingleResultUnion(
+ std::move(branches), branchFn, _context->planNodeId, _context->slotIdGenerator);
- _context->pushExpr(
- sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(ifNullExpr)));
+ auto loopJoinStage = makeLoopJoin(_context->extractCurrentEvalStage(),
+ std::move(opStage),
+ _context->planNodeId,
+ _context->getLexicalEnvironment());
+
+ _context->pushExpr(resultExpr.extractExpr(), std::move(loopJoinStage));
}
void visit(ExpressionIn* expr) final {
unsupportedExpression(expr->getOpName());