diff options
author | Hana Pearlman <hana.pearlman@mongodb.com> | 2020-11-23 20:23:15 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2020-12-04 18:02:19 +0000 |
commit | 8284eb2c2eedf795293b5ebc349e468e8666073c (patch) | |
tree | 9ed9a73e39f125deb1b1ffd7e0ce7b154c64e192 /src/mongo/db | |
parent | 46befe17d5d41b52808af01ef680a3405e113792 (diff) | |
download | mongo-8284eb2c2eedf795293b5ebc349e468e8666073c.tar.gz |
SERVER-12573 Extend $ifNull to accept more than two arguments
Diffstat (limited to 'src/mongo/db')
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 19 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.h | 5 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 55 |
3 files changed, 58 insertions, 21 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 6d95102b5e4..06481c5d953 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -2956,13 +2956,20 @@ const char* ExpressionMultiply::getOpName() const { /* ----------------------- ExpressionIfNull ---------------------------- */ -Value ExpressionIfNull::evaluate(const Document& root, Variables* variables) const { - Value pLeft(_children[0]->evaluate(root, variables)); - if (!pLeft.nullish()) - return pLeft; +void ExpressionIfNull::validateArguments(const ExpressionVector& args) const { + uassert(1257300, + str::stream() << "$ifNull needs at least two arguments, had: " << args.size(), + args.size() >= 2); +} - Value pRight(_children[1]->evaluate(root, variables)); - return pRight; +Value ExpressionIfNull::evaluate(const Document& root, Variables* variables) const { + const size_t n = _children.size(); + for (size_t i = 0; i < n; ++i) { + Value pValue(_children[i]->evaluate(root, variables)); + if (!pValue.nullish() || i == n - 1) + return pValue; + } + return Value(); } REGISTER_EXPRESSION(ifNull, ExpressionIfNull::parse); diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 8c326db6360..cbcf74dd618 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -1601,13 +1601,14 @@ public: }; -class ExpressionIfNull final : public ExpressionFixedArity<ExpressionIfNull, 2> { +class ExpressionIfNull final : public ExpressionVariadic<ExpressionIfNull> { public: explicit ExpressionIfNull(ExpressionContext* const expCtx) - : ExpressionFixedArity<ExpressionIfNull, 2>(expCtx) {} + : ExpressionVariadic<ExpressionIfNull>(expCtx) {} Value evaluate(const Document& root, Variables* variables) const final; const char* getOpName() const final; + void validateArguments(const ExpressionVector& args) const final; void acceptVisitor(ExpressionVisitor* visitor) final { return visitor->visit(this); diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 938982b2fd2..ddecb47bb38 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -376,7 +376,9 @@ public: void visit(ExpressionFieldPath* expr) final {} void visit(ExpressionFilter* expr) final {} void visit(ExpressionFloor* expr) final {} - void visit(ExpressionIfNull* expr) final {} + void visit(ExpressionIfNull* expr) final { + _context->evalStack.emplaceFrame(EvalStage{}); + } void visit(ExpressionIn* expr) final {} void visit(ExpressionIndexOfArray* expr) final {} void visit(ExpressionIndexOfBytes* expr) final {} @@ -545,7 +547,9 @@ public: _context->evalStack.emplaceFrame(EvalStage{}); } void visit(ExpressionFloor* expr) final {} - void visit(ExpressionIfNull* expr) final {} + void visit(ExpressionIfNull* expr) final { + _context->evalStack.emplaceFrame(EvalStage{}); + } void visit(ExpressionIn* expr) final {} void visit(ExpressionIndexOfArray* expr) final {} void visit(ExpressionIndexOfBytes* expr) final {} @@ -1813,21 +1817,46 @@ public: sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(floorExpr))); } void visit(ExpressionIfNull* expr) final { - _context->ensureArity(2); + auto numChildren = expr->getChildren().size(); + invariant(numChildren >= 2); - auto replacementIfNull = _context->popExpr(); - auto input = _context->popExpr(); + std::vector<EvalExprStagePair> branches; + branches.reserve(numChildren); + for (size_t i = 0; i < numChildren; ++i) { + auto [expr, stage] = _context->popFrame(); + branches.emplace_back(std::move(expr), std::move(stage)); + } + std::reverse(branches.begin(), branches.end()); - auto frameId = _context->frameIdGenerator->generate(); - auto binds = sbe::makeEs(std::move(input)); - sbe::EVariable inputRef(frameId, 0); + // Prepare to create limit-1/union with N branches (where N is the number of operands). Each + // branch will be evaluated from left to right until one of the branches produces a value. + auto branchFn = [](EvalExpr evalExpr, + EvalStage stage, + PlanNodeId planNodeId, + sbe::value::SlotIdGenerator* slotIdGenerator) { + auto slot = slotIdGenerator->generate(); + stage = makeProject(std::move(stage), planNodeId, slot, evalExpr.extractExpr()); + + // Create a FilterStage for each branch (except the last one). If a branch's filter + // condition is true, it will "short-circuit" the evaluation process. For ifNull, + // short-circuiting should happen if the current variable is not null or missing. + auto filterExpr = makeNot(generateNullOrMissing(slot)); + auto filterStage = + makeFilter<false>(std::move(stage), std::move(filterExpr), planNodeId); + + // Set the current expression as the output to be returned if short-circuiting occurs. + return std::make_pair(slot, std::move(filterStage)); + }; - // If input is null or missing, return replacement expression. Otherwise, return input. - auto ifNullExpr = sbe::makeE<sbe::EIf>( - generateNullOrMissing(frameId, 0), std::move(replacementIfNull), inputRef.clone()); + auto [resultExpr, opStage] = generateSingleResultUnion( + std::move(branches), branchFn, _context->planNodeId, _context->slotIdGenerator); - _context->pushExpr( - sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(ifNullExpr))); + auto loopJoinStage = makeLoopJoin(_context->extractCurrentEvalStage(), + std::move(opStage), + _context->planNodeId, + _context->getLexicalEnvironment()); + + _context->pushExpr(resultExpr.extractExpr(), std::move(loopJoinStage)); } void visit(ExpressionIn* expr) final { unsupportedExpression(expr->getOpName()); |