summaryrefslogtreecommitdiff
path: root/src/mongo/db
diff options
context:
space:
mode:
authorHana Pearlman <hana.pearlman@mongodb.com>2020-11-23 20:23:15 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2020-12-04 18:02:19 +0000
commit8284eb2c2eedf795293b5ebc349e468e8666073c (patch)
tree9ed9a73e39f125deb1b1ffd7e0ce7b154c64e192 /src/mongo/db
parent46befe17d5d41b52808af01ef680a3405e113792 (diff)
downloadmongo-8284eb2c2eedf795293b5ebc349e468e8666073c.tar.gz
SERVER-12573 Extend $ifNull to accept more than two arguments
Diffstat (limited to 'src/mongo/db')
-rw-r--r--src/mongo/db/pipeline/expression.cpp19
-rw-r--r--src/mongo/db/pipeline/expression.h5
-rw-r--r--src/mongo/db/query/sbe_stage_builder_expression.cpp55
3 files changed, 58 insertions, 21 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 6d95102b5e4..06481c5d953 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -2956,13 +2956,20 @@ const char* ExpressionMultiply::getOpName() const {
/* ----------------------- ExpressionIfNull ---------------------------- */
-Value ExpressionIfNull::evaluate(const Document& root, Variables* variables) const {
- Value pLeft(_children[0]->evaluate(root, variables));
- if (!pLeft.nullish())
- return pLeft;
+void ExpressionIfNull::validateArguments(const ExpressionVector& args) const {
+ uassert(1257300,
+ str::stream() << "$ifNull needs at least two arguments, had: " << args.size(),
+ args.size() >= 2);
+}
- Value pRight(_children[1]->evaluate(root, variables));
- return pRight;
+Value ExpressionIfNull::evaluate(const Document& root, Variables* variables) const {
+ const size_t n = _children.size();
+ for (size_t i = 0; i < n; ++i) {
+ Value pValue(_children[i]->evaluate(root, variables));
+ if (!pValue.nullish() || i == n - 1)
+ return pValue;
+ }
+ return Value();
}
REGISTER_EXPRESSION(ifNull, ExpressionIfNull::parse);
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index 8c326db6360..cbcf74dd618 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -1601,13 +1601,14 @@ public:
};
-class ExpressionIfNull final : public ExpressionFixedArity<ExpressionIfNull, 2> {
+class ExpressionIfNull final : public ExpressionVariadic<ExpressionIfNull> {
public:
explicit ExpressionIfNull(ExpressionContext* const expCtx)
- : ExpressionFixedArity<ExpressionIfNull, 2>(expCtx) {}
+ : ExpressionVariadic<ExpressionIfNull>(expCtx) {}
Value evaluate(const Document& root, Variables* variables) const final;
const char* getOpName() const final;
+ void validateArguments(const ExpressionVector& args) const final;
void acceptVisitor(ExpressionVisitor* visitor) final {
return visitor->visit(this);
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 938982b2fd2..ddecb47bb38 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -376,7 +376,9 @@ public:
void visit(ExpressionFieldPath* expr) final {}
void visit(ExpressionFilter* expr) final {}
void visit(ExpressionFloor* expr) final {}
- void visit(ExpressionIfNull* expr) final {}
+ void visit(ExpressionIfNull* expr) final {
+ _context->evalStack.emplaceFrame(EvalStage{});
+ }
void visit(ExpressionIn* expr) final {}
void visit(ExpressionIndexOfArray* expr) final {}
void visit(ExpressionIndexOfBytes* expr) final {}
@@ -545,7 +547,9 @@ public:
_context->evalStack.emplaceFrame(EvalStage{});
}
void visit(ExpressionFloor* expr) final {}
- void visit(ExpressionIfNull* expr) final {}
+ void visit(ExpressionIfNull* expr) final {
+ _context->evalStack.emplaceFrame(EvalStage{});
+ }
void visit(ExpressionIn* expr) final {}
void visit(ExpressionIndexOfArray* expr) final {}
void visit(ExpressionIndexOfBytes* expr) final {}
@@ -1813,21 +1817,46 @@ public:
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(floorExpr)));
}
void visit(ExpressionIfNull* expr) final {
- _context->ensureArity(2);
+ auto numChildren = expr->getChildren().size();
+ invariant(numChildren >= 2);
- auto replacementIfNull = _context->popExpr();
- auto input = _context->popExpr();
+ std::vector<EvalExprStagePair> branches;
+ branches.reserve(numChildren);
+ for (size_t i = 0; i < numChildren; ++i) {
+ auto [expr, stage] = _context->popFrame();
+ branches.emplace_back(std::move(expr), std::move(stage));
+ }
+ std::reverse(branches.begin(), branches.end());
- auto frameId = _context->frameIdGenerator->generate();
- auto binds = sbe::makeEs(std::move(input));
- sbe::EVariable inputRef(frameId, 0);
+ // Prepare to create limit-1/union with N branches (where N is the number of operands). Each
+ // branch will be evaluated from left to right until one of the branches produces a value.
+ auto branchFn = [](EvalExpr evalExpr,
+ EvalStage stage,
+ PlanNodeId planNodeId,
+ sbe::value::SlotIdGenerator* slotIdGenerator) {
+ auto slot = slotIdGenerator->generate();
+ stage = makeProject(std::move(stage), planNodeId, slot, evalExpr.extractExpr());
+
+ // Create a FilterStage for each branch (except the last one). If a branch's filter
+ // condition is true, it will "short-circuit" the evaluation process. For ifNull,
+ // short-circuiting should happen if the current variable is not null or missing.
+ auto filterExpr = makeNot(generateNullOrMissing(slot));
+ auto filterStage =
+ makeFilter<false>(std::move(stage), std::move(filterExpr), planNodeId);
+
+ // Set the current expression as the output to be returned if short-circuiting occurs.
+ return std::make_pair(slot, std::move(filterStage));
+ };
- // If input is null or missing, return replacement expression. Otherwise, return input.
- auto ifNullExpr = sbe::makeE<sbe::EIf>(
- generateNullOrMissing(frameId, 0), std::move(replacementIfNull), inputRef.clone());
+ auto [resultExpr, opStage] = generateSingleResultUnion(
+ std::move(branches), branchFn, _context->planNodeId, _context->slotIdGenerator);
- _context->pushExpr(
- sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(ifNullExpr)));
+ auto loopJoinStage = makeLoopJoin(_context->extractCurrentEvalStage(),
+ std::move(opStage),
+ _context->planNodeId,
+ _context->getLexicalEnvironment());
+
+ _context->pushExpr(resultExpr.extractExpr(), std::move(loopJoinStage));
}
void visit(ExpressionIn* expr) final {
unsupportedExpression(expr->getOpName());