diff options
Diffstat (limited to 'src/mongo/db/query')
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 55 |
1 files changed, 42 insertions, 13 deletions
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 938982b2fd2..ddecb47bb38 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -376,7 +376,9 @@ public: void visit(ExpressionFieldPath* expr) final {} void visit(ExpressionFilter* expr) final {} void visit(ExpressionFloor* expr) final {} - void visit(ExpressionIfNull* expr) final {} + void visit(ExpressionIfNull* expr) final { + _context->evalStack.emplaceFrame(EvalStage{}); + } void visit(ExpressionIn* expr) final {} void visit(ExpressionIndexOfArray* expr) final {} void visit(ExpressionIndexOfBytes* expr) final {} @@ -545,7 +547,9 @@ public: _context->evalStack.emplaceFrame(EvalStage{}); } void visit(ExpressionFloor* expr) final {} - void visit(ExpressionIfNull* expr) final {} + void visit(ExpressionIfNull* expr) final { + _context->evalStack.emplaceFrame(EvalStage{}); + } void visit(ExpressionIn* expr) final {} void visit(ExpressionIndexOfArray* expr) final {} void visit(ExpressionIndexOfBytes* expr) final {} @@ -1813,21 +1817,46 @@ public: sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(floorExpr))); } void visit(ExpressionIfNull* expr) final { - _context->ensureArity(2); + auto numChildren = expr->getChildren().size(); + invariant(numChildren >= 2); - auto replacementIfNull = _context->popExpr(); - auto input = _context->popExpr(); + std::vector<EvalExprStagePair> branches; + branches.reserve(numChildren); + for (size_t i = 0; i < numChildren; ++i) { + auto [expr, stage] = _context->popFrame(); + branches.emplace_back(std::move(expr), std::move(stage)); + } + std::reverse(branches.begin(), branches.end()); - auto frameId = _context->frameIdGenerator->generate(); - auto binds = sbe::makeEs(std::move(input)); - sbe::EVariable inputRef(frameId, 0); + // Prepare to create limit-1/union with N branches (where N is the number of operands). Each + // branch will be evaluated from left to right until one of the branches produces a value. + auto branchFn = [](EvalExpr evalExpr, + EvalStage stage, + PlanNodeId planNodeId, + sbe::value::SlotIdGenerator* slotIdGenerator) { + auto slot = slotIdGenerator->generate(); + stage = makeProject(std::move(stage), planNodeId, slot, evalExpr.extractExpr()); + + // Create a FilterStage for each branch (except the last one). If a branch's filter + // condition is true, it will "short-circuit" the evaluation process. For ifNull, + // short-circuiting should happen if the current variable is not null or missing. + auto filterExpr = makeNot(generateNullOrMissing(slot)); + auto filterStage = + makeFilter<false>(std::move(stage), std::move(filterExpr), planNodeId); + + // Set the current expression as the output to be returned if short-circuiting occurs. + return std::make_pair(slot, std::move(filterStage)); + }; - // If input is null or missing, return replacement expression. Otherwise, return input. - auto ifNullExpr = sbe::makeE<sbe::EIf>( - generateNullOrMissing(frameId, 0), std::move(replacementIfNull), inputRef.clone()); + auto [resultExpr, opStage] = generateSingleResultUnion( + std::move(branches), branchFn, _context->planNodeId, _context->slotIdGenerator); - _context->pushExpr( - sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(ifNullExpr))); + auto loopJoinStage = makeLoopJoin(_context->extractCurrentEvalStage(), + std::move(opStage), + _context->planNodeId, + _context->getLexicalEnvironment()); + + _context->pushExpr(resultExpr.extractExpr(), std::move(loopJoinStage)); } void visit(ExpressionIn* expr) final { unsupportedExpression(expr->getOpName()); |