diff options
author | Alya Berciu <alyacarina@gmail.com> | 2020-11-16 14:17:01 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2020-11-23 16:21:33 +0000 |
commit | 4cf9000c24591166f4c093f4702a522a4a62097f (patch) | |
tree | 0e4638e715b8b1051bf2d10097e1c27742af444a | |
parent | 71fb74aab300a852761e0ae3b0782c207f4aef52 (diff) | |
download | mongo-4cf9000c24591166f4c093f4702a522a4a62097f.tar.gz |
SERVER-51534 Support concatArrays in SBE
-rw-r--r-- | jstests/aggregation/bugs/server14872.js | 9 | ||||
-rw-r--r-- | jstests/aggregation/expressions/concat_arrays.js | 154 | ||||
-rw-r--r-- | jstests/libs/sbe_assert_error_override.js | 1 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/stages/hash_agg.cpp | 5 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 345 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_helpers.cpp | 60 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_helpers.h | 48 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_projection.cpp | 2 |
8 files changed, 488 insertions, 136 deletions
diff --git a/jstests/aggregation/bugs/server14872.js b/jstests/aggregation/bugs/server14872.js index 8f4b84701da..f3e4abf85b6 100644 --- a/jstests/aggregation/bugs/server14872.js +++ b/jstests/aggregation/bugs/server14872.js @@ -1,14 +1,11 @@ // SERVER-14872: Aggregation expression to concatenate multiple arrays into one -// @tags: [ -// sbe_incompatible, -// ] - -// For assertErrorCode. -load('jstests/aggregation/extras/utils.js'); (function() { 'use strict'; +load('jstests/aggregation/extras/utils.js'); // For assertErrorCode. +load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs. + var coll = db.agg_concat_arrays_expr; coll.drop(); diff --git a/jstests/aggregation/expressions/concat_arrays.js b/jstests/aggregation/expressions/concat_arrays.js new file mode 100644 index 00000000000..37f4b6c4a72 --- /dev/null +++ b/jstests/aggregation/expressions/concat_arrays.js @@ -0,0 +1,154 @@ +// Confirm correctness of $concatArrays expression evaluation. + +(function() { +"use strict"; + +load("jstests/aggregation/extras/utils.js"); // For assertArrayEq. +load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs. 
+ +const coll = db.projection_expr_concat_arrays; +coll.drop(); + +assert.commandWorked(coll.insertOne({ + int_arr: [1, 2, 3, 4], + dbl_arr: [10.0, 20.1, 20.4, 50.5], + nested_arr: [["an", "array"], "arr", [[], [[], "a", "b"]]], + str_arr: ["a", "b", "c"], + obj_arr: [{a: 1, b: 2}, {c: 3}, {d: 4, e: 5}], + null_arr: [null, null, null], + one_null_arr: [null], + one_str_arr: ["one"], + empty_arr: [], + null_val: null, + str_val: "a string", + dbl_val: 2.0, + int_val: 1, + obj_val: {a: 1, b: "two"} +})); + +function runAndAssert(operands, expectedResult) { + assertArrayEq({ + actual: coll.aggregate([{$project: {f: {$concatArrays: operands}}}]).map(doc => doc.f), + expected: expectedResult + }); +} + +function runAndAssertNull(operands) { + runAndAssert(operands, [null]); +} + +function runAndAssertThrows(operands) { + const error = + assert.throws(() => coll.aggregate([{$project: {f: {$concatArrays: operands}}}]).toArray()); + assert.commandFailedWithCode(error, 28664); +} + +runAndAssert(["$int_arr"], [[1, 2, 3, 4]]); +runAndAssert([[0], "$int_arr", [5, 6, 7]], [[0, 1, 2, 3, 4, 5, 6, 7]]); +runAndAssert(["$int_arr", "$str_arr"], [[1, 2, 3, 4, "a", "b", "c"]]); +runAndAssert( + ["$obj_arr", "$obj_arr", "$null_arr"], + [[{a: 1, b: 2}, {c: 3}, {d: 4, e: 5}, {a: 1, b: 2}, {c: 3}, {d: 4, e: 5}, null, null, null]]); +runAndAssert(["$int_arr", "$str_arr", "$nested_arr"], + [[1, 2, 3, 4, "a", "b", "c", ["an", "array"], "arr", [[], [[], "a", "b"]]]]); +runAndAssert(["$int_arr", "$obj_arr"], [[1, 2, 3, 4, {a: 1, b: 2}, {c: 3}, {d: 4, e: 5}]]); +runAndAssert(["$obj_arr"], [[{a: 1, b: 2}, {c: 3}, {d: 4, e: 5}]]); +runAndAssert(["$obj_arr", [{o: 123, b: 1}, {y: "o", d: "a"}]], + [[{a: 1, b: 2}, {c: 3}, {d: 4, e: 5}, {o: 123, b: 1}, {y: "o", d: "a"}]]); + +// Confirm that arrays containing null can be concatenated. 
+runAndAssert(["$null_arr"], [[null, null, null]]); +runAndAssert([[null], "$null_arr"], [[null, null, null, null]]); +runAndAssert("$one_null_arr", [[null]]); +runAndAssert(["$null_arr", "$one_null_arr", "$int_arr", "$null_arr"], + [[null, null, null, null, 1, 2, 3, 4, null, null, null]]); + +// Test operands that form more complex expressions. +runAndAssert([{$concatArrays: "$int_arr"}], [[1, 2, 3, 4]]); +runAndAssert([{$concatArrays: "$int_arr"}, {$concatArrays: {$concatArrays: "$str_arr"}}], + [[1, 2, 3, 4, "a", "b", "c"]]); +runAndAssert(["$str_arr", {$filter: {input: "$int_arr", + as: "num", + cond: { $and: [ + { $gte: [ "$$num", 2 ] }, + { $lte: [ "$$num", 3 ] } + ] }}}, "$int_arr"], + [["a", "b", "c", 2, 3, 1, 2, 3, 4]]); + +// Confirm that having any combination of null or missing inputs and valid inputs produces null. +runAndAssertNull(["$int_arr", "$null_val"]); +runAndAssertNull(["$int_arr", null]); +runAndAssertNull([null, "$int_arr", "$str_arr"]); +runAndAssertNull(["$int_arr", null, "$str_arr"]); +runAndAssertNull(["$null_val", "$str_arr", "$int_arr"]); +runAndAssertNull(["$str_arr", "$null_val", "$int_arr"]); +runAndAssertNull(["$int_arr", "$not_a_field"]); +runAndAssertNull(["$not_a_field", "$str_arr", "$int_arr"]); +runAndAssertNull(["$not_a_field"]); +runAndAssertNull(["$null_val"]); +runAndAssertNull(["$not_a_field", "$null_val"]); +runAndAssertNull(["$null_val", "$not_a_field"]); +runAndAssertNull([ + {$concatArrays: "$int_arr"}, + null, + {$concatArrays: {$concatArrays: ["$obj_arr", "$str_arr"]}} +]); + +// Confirm edge case where if null precedes non-array input, null is returned. +runAndAssertNull(["$int_arr", "$null_val", "$int_val"]); + +// +// Confirm error cases. +// + +// Confirm concatenating non-array, non-null values produces an error. 
+runAndAssertThrows(["$dbl_val"]); +runAndAssertThrows(["$str_val"]); +runAndAssertThrows(["$int_val"]); +runAndAssertThrows([123]); +runAndAssertThrows(["some_val", [1, 2, 3]]); +runAndAssertThrows(["$obj_val"]); +runAndAssertThrows(["$int_arr", "$int_val"]); +runAndAssertThrows(["$dbl_arr", "$dbl_val"]); + +// Confirm edge case where if invalid input precedes null or missing inputs, the command fails. +runAndAssertThrows(["$int_arr", "$dbl_val", "$null_val"]); +runAndAssertThrows(["$int_arr", "some_string_value", "$null_val"]); +runAndAssertThrows(["$int_arr", 32]); +runAndAssertThrows(["$dbl_val", "$null_val"]); +runAndAssertThrows(["$int_arr", "$int_val", "$not_a_field"]); +runAndAssertThrows(["$int_val", "$not_a_field"]); +runAndAssertThrows(["$int_val", "$not_a_field", "$null_val"]); + +// Clear collection. +assert(coll.drop()); + +// Test case where find returns multiple documents. +assert.commandWorked(coll.insertMany([ + {arr1: [42, 35.0, 197865432], arr2: ["albatross", "abbacus", "alien"]}, + {arr1: [1], arr2: ["albatross", "abbacus", "alien"]}, + {arr1: [1, 2, 3, 4, 5, 6, 11, 12, 23], arr2: []}, + {arr1: [], arr2: ["foo", "bar"]}, + {arr1: [], arr2: []}, + {arr1: [1, 2, 3, 4, 5, 6, 11, 12, 23], arr2: null}, + {some_field: "foo"}, +])); +runAndAssert(["$arr1", "$arr2"], [ + [42, 35.0, 197865432, "albatross", "abbacus", "alien"], + [1, "albatross", "abbacus", "alien"], + [1, 2, 3, 4, 5, 6, 11, 12, 23], + ["foo", "bar"], + [], + null, + null +]); +runAndAssert(["$arr1", [1, 2, 3], "$arr2"], [ + [42, 35.0, 197865432, 1, 2, 3, "albatross", "abbacus", "alien"], + [1, 1, 2, 3, "albatross", "abbacus", "alien"], + [1, 2, 3, 4, 5, 6, 11, 12, 23, 1, 2, 3], + ["foo", 1, 2, 3, "bar"], + [1, 2, 3], + null, + null +]); +}()); diff --git a/jstests/libs/sbe_assert_error_override.js b/jstests/libs/sbe_assert_error_override.js index ce655b77aa2..cd98fec00fe 100644 --- a/jstests/libs/sbe_assert_error_override.js +++ b/jstests/libs/sbe_assert_error_override.js @@ -32,6 +32,7 
@@ const equivalentErrorCodesList = [ [16609, 5073101], [16610, 4848403], [16555, 5073102], + [28664, 5153400], [28680, 4903701], [28689, 5126701], [28690, 5126702], diff --git a/src/mongo/db/exec/sbe/stages/hash_agg.cpp b/src/mongo/db/exec/sbe/stages/hash_agg.cpp index fc97e86426d..c9f6da29462 100644 --- a/src/mongo/db/exec/sbe/stages/hash_agg.cpp +++ b/src/mongo/db/exec/sbe/stages/hash_agg.cpp @@ -105,6 +105,10 @@ void HashAggStage::open(bool reOpen) { _commonStats.opens++; _children[0]->open(reOpen); + if (reOpen) { + _ht.clear(); + } + while (_children[0]->getNext() == PlanState::ADVANCED) { value::MaterializedRow key{_inKeyAccessors.size()}; // Copy keys in order to do the lookup. @@ -161,6 +165,7 @@ const SpecificStats* HashAggStage::getSpecificStats() const { void HashAggStage::close() { _commonStats.closes++; + _ht.clear(); } std::vector<DebugPrinter::Block> HashAggStage::debugPrint() const { diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 42b7a086e90..731b30e86d9 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -36,6 +36,7 @@ #include "mongo/db/exec/sbe/stages/branch.h" #include "mongo/db/exec/sbe/stages/co_scan.h" #include "mongo/db/exec/sbe/stages/filter.h" +#include "mongo/db/exec/sbe/stages/hash_agg.h" #include "mongo/db/exec/sbe/stages/limit_skip.h" #include "mongo/db/exec/sbe/stages/loop_join.h" #include "mongo/db/exec/sbe/stages/project.h" @@ -182,12 +183,11 @@ std::pair<sbe::value::SlotId, EvalStage> generateTraverseHelper( std::move(inputStage), planNodeId, fieldSlot, - sbe::makeE<sbe::EFunction>( - "getField"sv, - sbe::makeEs(sbe::makeE<sbe::EVariable>(inputSlot), sbe::makeE<sbe::EConstant>([&]() { - auto fieldName = fp.getFieldName(level); - return std::string_view{fieldName.rawData(), fieldName.size()}; - }())))); + makeFunction( + "getField"sv, sbe::makeE<sbe::EVariable>(inputSlot), 
sbe::makeE<sbe::EConstant>([&]() { + auto fieldName = fp.getFieldName(level); + return std::string_view{fieldName.rawData(), fieldName.size()}; + }()))); EvalStage innerBranch; if (level == fp.getPathLength() - 1) { @@ -283,9 +283,7 @@ void generateStringCaseConversionExpression(ExpressionVisitorContext* _context, auto caseConversionExpr = sbe::makeE<sbe::EIf>( std::move(checkValidTypeExpr), - sbe::makeE<sbe::EFunction>(caseConversionFunction, - sbe::makeEs(sbe::makeE<sbe::EFunction>( - "coerceToString", sbe::makeEs(inputRef.clone())))), + makeFunction(caseConversionFunction, makeFunction("coerceToString", inputRef.clone())), sbe::makeE<sbe::EFail>(ErrorCodes::Error{5066300}, str::stream() << "$" << caseConversionFunction << " input type is not supported")); @@ -311,16 +309,14 @@ void buildArrayAccessByConstantIndex(ExpressionVisitorContext* context, auto indexExpr = sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::NumberInt32, sbe::value::bitcastFrom<int32_t>(index)); - auto argumentIsNotArray = - makeNot(sbe::makeE<sbe::EFunction>("isArray", sbe::makeEs(arrayRef.clone()))); + auto argumentIsNotArray = makeNot(makeFunction("isArray", arrayRef.clone())); auto resultExpr = buildMultiBranchConditional( CaseValuePair{generateNullOrMissing(arrayRef), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)}, CaseValuePair{std::move(argumentIsNotArray), sbe::makeE<sbe::EFail>(ErrorCodes::Error{5126704}, exprName + " argument must be an array")}, - sbe::makeE<sbe::EFunction>("getElement", - sbe::makeEs(arrayRef.clone(), std::move(indexExpr)))); + makeFunction("getElement", arrayRef.clone(), std::move(indexExpr))); context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(resultExpr))); @@ -364,7 +360,9 @@ public: void visit(ExpressionCoerceToBool* expr) final {} void visit(ExpressionCompare* expr) final {} void visit(ExpressionConcat* expr) final {} - void visit(ExpressionConcatArrays* expr) final {} + void visit(ExpressionConcatArrays* expr) 
final { + _context->evalStack.emplaceFrame(EvalStage{}); + } void visit(ExpressionCond* expr) final { _context->evalStack.emplaceFrame(EvalStage{}); } @@ -519,7 +517,9 @@ public: void visit(ExpressionCoerceToBool* expr) final {} void visit(ExpressionCompare* expr) final {} void visit(ExpressionConcat* expr) final {} - void visit(ExpressionConcatArrays* expr) final {} + void visit(ExpressionConcatArrays* expr) final { + _context->evalStack.emplaceFrame(EvalStage{}); + } void visit(ExpressionCond* expr) final { _context->evalStack.emplaceFrame(EvalStage{}); } @@ -724,7 +724,7 @@ public: CaseValuePair{generateLongLongMinCheck(inputRef), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903701}, "can't take $abs of long long min")}, - sbe::makeE<sbe::EFunction>("abs", sbe::makeEs(inputRef.clone()))); + makeFunction("abs", inputRef.clone())); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(absExpr))); @@ -739,12 +739,10 @@ public: sbe::EVariable var{frameId, slotId}; return sbe::makeE<sbe::EPrimBinary>( sbe::EPrimBinary::logicAnd, - sbe::makeE<sbe::EPrimUnary>( - sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(var.clone()))), - sbe::makeE<sbe::EPrimUnary>( - sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>("isDate", sbe::makeEs(var.clone())))); + sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, + makeFunction("isNumber", var.clone())), + sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, + makeFunction("isDate", var.clone()))); }; if (arity == 2) { @@ -767,10 +765,9 @@ public: ErrorCodes::Error{4974201}, "only numbers and dates are allowed in an $add expression"), sbe::makeE<sbe::EIf>( - sbe::makeE<sbe::EPrimBinary>( - sbe::EPrimBinary::logicAnd, - sbe::makeE<sbe::EFunction>("isDate", sbe::makeEs(lhsVar.clone())), - sbe::makeE<sbe::EFunction>("isDate", sbe::makeEs(rhsVar.clone()))), + sbe::makeE<sbe::EPrimBinary>(sbe::EPrimBinary::logicAnd, + makeFunction("isDate", lhsVar.clone()), + 
makeFunction("isDate", rhsVar.clone())), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4974202}, "only one date allowed in an $add expression"), sbe::makeE<sbe::EPrimBinary>( @@ -863,7 +860,7 @@ public: sbe::EVariable convertedIndexRef{frameId, 0}; auto inExpression = sbe::makeE<sbe::EIf>( - sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(convertedIndexRef.clone())), + makeFunction("exists", convertedIndexRef.clone()), convertedIndexRef.clone(), sbe::makeE<sbe::EFail>( ErrorCodes::Error{5126703}, @@ -876,8 +873,7 @@ public: sbe::makeE<sbe::EPrimBinary>(sbe::EPrimBinary::logicOr, generateNullOrMissing(arrayRef), generateNullOrMissing(indexRef)); - auto firstArgumentIsNotArray = - makeNot(sbe::makeE<sbe::EFunction>("isArray", sbe::makeEs(arrayRef.clone()))); + auto firstArgumentIsNotArray = makeNot(makeFunction("isArray", arrayRef.clone())); auto secondArgumentIsNotNumeric = generateNonNumericCheck(indexRef); auto arrayElemAtExpr = buildMultiBranchConditional( CaseValuePair{std::move(anyOfArgumentsIsNullish), @@ -888,8 +884,7 @@ public: CaseValuePair{std::move(secondArgumentIsNotNumeric), sbe::makeE<sbe::EFail>(ErrorCodes::Error{5126702}, "$arrayElemAt second argument must be a number")}, - sbe::makeE<sbe::EFunction>("getElement", - sbe::makeEs(arrayRef.clone(), std::move(int32Index)))); + makeFunction("getElement", arrayRef.clone(), std::move(int32Index))); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(arrayElemAtExpr))); @@ -923,7 +918,7 @@ public: CaseValuePair{generateNonObjectCheck(inputRef), sbe::makeE<sbe::EFail>(ErrorCodes::Error{5043001}, "$bsonSize requires a document input")}, - sbe::makeE<sbe::EFunction>("bsonSize", sbe::makeEs(inputRef.clone()))); + makeFunction("bsonSize", inputRef.clone())); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(bsonSizeExpr))); @@ -939,7 +934,7 @@ public: CaseValuePair{generateNonNumericCheck(inputRef), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903702}, 
"$ceil only supports numeric types")}, - sbe::makeE<sbe::EFunction>("ceil", sbe::makeEs(inputRef.clone()))); + makeFunction("ceil", inputRef.clone())); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(ceilExpr))); @@ -997,13 +992,13 @@ public: // will also evaluate to "Nothing." MQL comparisons, however, treat "Nothing" as if it is a // value that is less than everything other than MinKey. (Notably, two expressions that // evaluate to "Nothing" are considered equal to each other.) - auto nothingFallbackCmp = sbe::makeE<sbe::EPrimBinary>( - comparisonOperator, - sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(lhsRef.clone())), - sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(rhsRef.clone()))); + auto nothingFallbackCmp = + sbe::makeE<sbe::EPrimBinary>(comparisonOperator, + makeFunction("exists", lhsRef.clone()), + makeFunction("exists", rhsRef.clone())); - auto cmpWithFallback = sbe::makeE<sbe::EFunction>( - "fillEmpty", sbe::makeEs(std::move(cmp), std::move(nothingFallbackCmp))); + auto cmpWithFallback = + makeFunction("fillEmpty", std::move(cmp), std::move(nothingFallbackCmp)); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(operands), std::move(cmpWithFallback))); @@ -1023,8 +1018,7 @@ public: sbe::EVariable var(frameId, slot); binds.push_back(_context->popExpr()); checkNullArg.push_back(generateNullOrMissing(frameId, slot)); - checkStringArg.push_back( - sbe::makeE<sbe::EFunction>("isString", sbe::makeEs(var.clone()))); + checkStringArg.push_back(makeFunction("isString", var.clone())); argVars.push_back(var.clone()); } std::reverse(std::begin(binds), std::end(binds)); @@ -1060,7 +1054,132 @@ public: } void visit(ExpressionConcatArrays* expr) final { - unsupportedExpression(expr->getOpName()); + // Pop eval frames pushed by pre and in visitors off the stack. 
+ std::vector<EvalExprStagePair> branches; + auto numChildren = expr->getChildren().size(); + branches.reserve(numChildren); + for (size_t idx = 0; idx < numChildren; ++idx) { + auto [branchExpr, branchEvalStage] = _context->popFrame(); + branches.emplace_back(std::move(branchExpr), std::move(branchEvalStage)); + } + std::reverse(branches.begin(), branches.end()); + + auto getUnionOutputSlot = [](EvalExpr& unionEvalExpr) { + auto slot = *(unionEvalExpr.getSlot()); + invariant(slot); + return slot; + }; + + auto makeNullLimitCoscanTree = [&]() { + auto outputSlot = _context->slotIdGenerator->generate(); + auto nullEvalStage = + makeProject({makeLimitCoScanTree(_context->planNodeId), sbe::makeSV()}, + _context->planNodeId, + outputSlot, + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)); + return EvalExprStagePair{outputSlot, std::move(nullEvalStage)}; + }; + + // Build a union stage to consolidate array input branches into a stream. + auto [unionEvalExpr, unionEvalStage] = + generateUnion(std::move(branches), {}, _context->planNodeId, _context->slotIdGenerator); + auto unionSlot = getUnionOutputSlot(unionEvalExpr); + sbe::EVariable unionVar{unionSlot}; + + // Filter stage to EFail if an element is not an array, null, or missing, and EOF if an + // element is null or missing: not(isNullOrMissing) && (isArray || EFail). + auto filterExpr = sbe::makeE<sbe::EPrimBinary>( + sbe::EPrimBinary::logicAnd, + makeNot(generateNullOrMissing(unionVar)), + sbe::makeE<sbe::EPrimBinary>( + sbe::EPrimBinary::logicOr, + makeFunction("isArray", unionVar.clone()), + sbe::makeE<sbe::EFail>(ErrorCodes::Error{5153400}, + "$concatArrays only supports arrays"))); + auto filter = makeFilter<false, true>( + std::move(unionEvalStage), std::move(filterExpr), _context->planNodeId); + + // Create a union stage to replace any values filtered out by the previous stage with null. + // For example, [a, b, null, c, d] would become [a, b, null]. 
+ std::vector<EvalExprStagePair> unionWithNullBranches; + unionWithNullBranches.emplace_back(sbe::makeE<sbe::EVariable>(unionSlot), + std::move(filter)); + unionWithNullBranches.emplace_back(makeNullLimitCoscanTree()); + auto [unionWithNullExpr, unionWithNullStage] = generateUnion( + std::move(unionWithNullBranches), {}, _context->planNodeId, _context->slotIdGenerator); + auto unionWithNullSlot = getUnionOutputSlot(unionWithNullExpr); + + // Create a limit stage to EOF once numChildren results have been obtained. + auto limitNumChildren = + makeLimitTree(std::move(unionWithNullStage.stage), _context->planNodeId, numChildren); + + // Create a group stage to aggregate elements into a single array. + auto addToArrayExpr = + makeFunction("addToArray", sbe::makeE<sbe::EVariable>(unionWithNullSlot)); + auto groupSlot = _context->slotIdGenerator->generate(); + auto groupStage = + sbe::makeS<sbe::HashAggStage>(std::move(limitNumChildren), + sbe::makeSV(), + sbe::makeEM(groupSlot, std::move(addToArrayExpr)), + _context->planNodeId); + EvalStage groupEvalStage = {std::move(groupStage), sbe::makeSV(groupSlot)}; + + // Build subtree to handle nulls. If an input is null, return null. Otherwise, unwind the + // input twice, and concatenate it into an array using addToArray. This is necessary to + // implement the MQL behavior where one null or missing input results in a null output. + + // Create two unwind stages to unwind the array that was built from inputs + // and unwind each input array into its constituent elements. We need a limit 1/coscan stage + // here to call getNext() on, but we use the output slot of groupStage to obtain the array + // of inputs. 
+ auto unwindEvalStage = makeUnwind( + makeUnwind({makeLimitCoScanStage(_context->planNodeId).stage, sbe::makeSV(groupSlot)}, + _context->slotIdGenerator, + _context->planNodeId), + _context->slotIdGenerator, + _context->planNodeId); + auto unwindSlot = unwindEvalStage.outSlots.front(); + + // Create a group stage to append all streamed elements into one array. This is the final + // output when the input consists entirely of arrays. + auto finalAddToArrayExpr = + makeFunction("addToArray", sbe::makeE<sbe::EVariable>(unwindSlot)); + auto finalGroupSlot = _context->slotIdGenerator->generate(); + auto finalGroupStage = sbe::makeS<sbe::HashAggStage>( + std::move(unwindEvalStage.stage), + sbe::makeSV(), + sbe::makeEM(finalGroupSlot, std::move(finalAddToArrayExpr)), + _context->planNodeId); + + // Create a branch stage to select between the branch that produces one null if any elements + // in the original input were null or missing, or otherwise select the branch that unwinds + // and concatenates elements into the output array. + auto [nullExpr, nullStage] = makeNullLimitCoscanTree(); + auto nullIsMemberExpr = + makeFunction("isMember", + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0), + sbe::makeE<sbe::EVariable>(groupSlot)); + auto branchNullEvalStage = + makeBranch(std::move(nullIsMemberExpr), + std::move(nullStage), + {std::move(finalGroupStage), sbe::makeSV(finalGroupSlot)}, + _context->slotIdGenerator, + _context->planNodeId); + auto branchSlot = branchNullEvalStage.outSlots.front(); + + // Create nlj to connect outer group with inner branch that handles null input. + auto nljStage = makeLoopJoin(std::move(groupEvalStage), + std::move(branchNullEvalStage), + _context->planNodeId, + _context->getLexicalEnvironment()); + + // Top level nlj to inject input slots. 
+ auto finalNljStage = makeLoopJoin(_context->extractCurrentEvalStage(), + std::move(nljStage), + _context->planNodeId, + _context->getLexicalEnvironment()); + + _context->pushExpr(sbe::makeE<sbe::EVariable>(branchSlot), std::move(finalNljStage)); } void visit(ExpressionCond* expr) final { visitConditionalExpression(expr); @@ -1174,19 +1293,16 @@ public: sbe::makeE<sbe::EIf>( sbe::makeE<sbe::EPrimBinary>( sbe::EPrimBinary::logicOr, - sbe::makeE<sbe::EPrimUnary>( - sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>("exists", - sbe::makeEs(outerSlotRef.clone()))), - sbe::makeE<sbe::EFunction>("isNull", sbe::makeEs(outerSlotRef.clone()))), + sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, + makeFunction("exists", outerSlotRef.clone())), + makeFunction("isNull", outerSlotRef.clone())), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0), sbe::makeE<sbe::ELocalBind>( innerFrameId, sbe::makeEs(sbe::makeE<sbe::ENumericConvert>( outerSlotRef.clone(), sbe::value::TypeTags::NumberInt64)), sbe::makeE<sbe::EIf>( - sbe::makeE<sbe::EFunction>("exists", - sbe::makeEs(convertedFieldRef.clone())), + makeFunction("exists", convertedFieldRef.clone()), convertedFieldRef.clone(), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4848979}, str::stream() @@ -1307,7 +1423,7 @@ public: tzFrameId, sbe::makeEs(std::move(eTimezone)), sbe::makeE<sbe::EIf>( - sbe::makeE<sbe::EFunction>("isString", sbe::makeEs(timeZoneRef.clone())), + makeFunction("isString", timeZoneRef.clone()), timezoneRef.clone(), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4848980}, str::stream() @@ -1342,17 +1458,16 @@ public: // runtime environment so we pass the corresponding slot to the datePartsWeekYear and // dateParts functions as a variable. auto timeZoneDBSlot = _context->runtimeEnvironment->getSlot("timeZoneDB"_sd); - auto computeDate = - sbe::makeE<sbe::EFunction>(isIsoWeekYear ? 
"datePartsWeekYear" : "dateParts", - sbe::makeEs(sbe::makeE<sbe::EVariable>(timeZoneDBSlot), - yearRef.clone(), - monthRef.clone(), - dayRef.clone(), - hourRef.clone(), - minRef.clone(), - secRef.clone(), - millisecRef.clone(), - timeZoneRef.clone())); + auto computeDate = makeFunction(isIsoWeekYear ? "datePartsWeekYear" : "dateParts", + sbe::makeE<sbe::EVariable>(timeZoneDBSlot), + yearRef.clone(), + monthRef.clone(), + dayRef.clone(), + hourRef.clone(), + minRef.clone(), + secRef.clone(), + millisecRef.clone(), + timeZoneRef.clone()); using iterPair_t = std::vector<std::pair<std::unique_ptr<sbe::EExpression>, std::unique_ptr<sbe::EExpression>>>::iterator; @@ -1449,19 +1564,17 @@ public: CaseValuePair{generateNullOrMissing(frameId, 1), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)}, CaseValuePair{ - sbe::makeE<sbe::EPrimUnary>( - sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>("isString", sbe::makeEs(timezoneRef.clone()))), + sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, + makeFunction("isString", timezoneRef.clone())), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4997701}, "$dateToParts timezone must be a string")}, CaseValuePair{ sbe::makeE<sbe::EPrimUnary>( sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>( - "isTimezone", - sbe::makeEs(sbe::makeE<sbe::EVariable>( - _context->runtimeEnvironment->getSlot("timeZoneDB"_sd)), - timezoneRef.clone()))), + makeFunction("isTimezone", + sbe::makeE<sbe::EVariable>( + _context->runtimeEnvironment->getSlot("timeZoneDB"_sd)), + timezoneRef.clone())), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4997704}, "$dateToParts timezone must be a valid timezone")}, CaseValuePair{generateNullOrMissing(frameId, 2), @@ -1497,10 +1610,9 @@ public: sbe::EVariable lhsRef{frameId, 0}; sbe::EVariable rhsRef{frameId, 1}; - auto checkIsNumber = sbe::makeE<sbe::EPrimBinary>( - sbe::EPrimBinary::logicAnd, - sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(lhsRef.clone())), - sbe::makeE<sbe::EFunction>("isNumber", 
sbe::makeEs(rhsRef.clone()))); + auto checkIsNumber = sbe::makeE<sbe::EPrimBinary>(sbe::EPrimBinary::logicAnd, + makeFunction("isNumber", lhsRef.clone()), + makeFunction("isNumber", rhsRef.clone())); auto checkIsNullOrMissing = sbe::makeE<sbe::EPrimBinary>(sbe::EPrimBinary::logicOr, generateNullOrMissing(lhsRef), @@ -1529,7 +1641,7 @@ public: CaseValuePair{generateNonNumericCheck(inputRef), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903703}, "$exp only supports numeric types")}, - sbe::makeE<sbe::EFunction>("exp", sbe::makeEs(inputRef.clone()))); + makeFunction("exp", inputRef.clone())); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(expExpr))); @@ -1605,10 +1717,10 @@ public: auto binds = sbe::makeEs(std::move(input)); sbe::EVariable inputRef(frameId, 0); - auto inputIsArrayOrNullish = sbe::makeE<sbe::EPrimBinary>( - sbe::EPrimBinary::logicOr, - generateNullOrMissing(inputRef), - sbe::makeE<sbe::EFunction>("isArray", sbe::makeEs(inputRef.clone()))); + auto inputIsArrayOrNullish = + sbe::makeE<sbe::EPrimBinary>(sbe::EPrimBinary::logicOr, + generateNullOrMissing(inputRef), + makeFunction("isArray", inputRef.clone())); auto checkInputArrayType = sbe::makeE<sbe::EIf>(std::move(inputIsArrayOrNullish), inputRef.clone(), @@ -1673,10 +1785,9 @@ public: // If input array is null or missing, 'in' stage of traverse will return EOF. In this case // traverse sets output slot (filteredArraySlot) to Nothing. We replace it with Null to // match $filter expression behaviour. 
- auto result = sbe::makeE<sbe::EFunction>( - "fillEmpty", - sbe::makeEs(sbe::makeE<sbe::EVariable>(filteredArraySlot), - sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0))); + auto result = makeFunction("fillEmpty", + sbe::makeE<sbe::EVariable>(filteredArraySlot), + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)); _context->pushExpr(std::move(result), std::move(traverseStage)); } @@ -1691,7 +1802,7 @@ public: CaseValuePair{generateNonNumericCheck(inputRef), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903704}, "$floor only supports numeric types")}, - sbe::makeE<sbe::EFunction>("floor", sbe::makeEs(inputRef.clone()))); + makeFunction("floor", inputRef.clone())); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(floorExpr))); @@ -1732,11 +1843,11 @@ public: auto binds = sbe::makeEs(_context->popExpr()); sbe::EVariable inputRef(frameId, 0); - auto exprIsNum = sbe::makeE<sbe::EIf>( - sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(inputRef.clone())), - sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(inputRef.clone())), - sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Boolean, - sbe::value::bitcastFrom<bool>(false))); + auto exprIsNum = + sbe::makeE<sbe::EIf>(makeFunction("exists", inputRef.clone()), + makeFunction("isNumber", inputRef.clone()), + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Boolean, + sbe::value::bitcastFrom<bool>(false))); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(exprIsNum))); @@ -1786,7 +1897,7 @@ public: CaseValuePair{generateNonPositiveCheck(inputRef), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903706}, "$ln's argument must be a positive number")}, - sbe::makeE<sbe::EFunction>("ln", sbe::makeEs(inputRef.clone()))); + makeFunction("ln", inputRef.clone())); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(lnExpr))); @@ -1813,7 +1924,7 @@ public: CaseValuePair{generateNonPositiveCheck(inputRef), 
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903708}, "$log10's argument must be a positive number")}, - sbe::makeE<sbe::EFunction>("log10", sbe::makeEs(inputRef.clone()))); + makeFunction("log10", inputRef.clone())); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(log10Expr))); @@ -1845,8 +1956,7 @@ public: sbe::EPrimUnary::logicNot, sbe::makeE<sbe::ETypeMatch>( lhsVar.clone(), getBSONTypeMask(sbe::value::TypeTags::NumberDouble)))), - sbe::makeE<sbe::EFunction>( - "fillEmpty", sbe::makeEs(std::move(numericConvert32), rhsVar.clone()))}, + makeFunction("fillEmpty", std::move(numericConvert32), rhsVar.clone())}, rhsVar.clone()); auto modExpr = buildMultiBranchConditional( @@ -1859,7 +1969,7 @@ public: generateNonNumericCheck(rhsVar)), sbe::makeE<sbe::EFail>(ErrorCodes::Error{5154000}, "$mod only supports numeric types")}, - sbe::makeE<sbe::EFunction>("mod", sbe::makeEs(lhsVar.clone(), std::move(rhsExpr)))); + makeFunction("mod", lhsVar.clone(), std::move(rhsExpr))); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(modExpr))); @@ -1884,8 +1994,7 @@ public: variables.push_back(currentVariable.clone()); checkExprsNull.push_back(generateNullOrMissing(currentVariable)); - checkExprsNumber.push_back( - sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(currentVariable.clone()))); + checkExprsNumber.push_back(makeFunction("isNumber", currentVariable.clone())); } // At this point 'binds' vector contains arguments of $multiply expression in the reversed @@ -2076,7 +2185,7 @@ public: generateNegativeCheck(inputRef), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903710}, "$sqrt's argument must be greater than or equal to 0")}, - sbe::makeE<sbe::EFunction>("sqrt", sbe::makeEs(inputRef.clone()))); + makeFunction("sqrt", inputRef.clone())); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(lnExpr))); @@ -2438,15 +2547,13 @@ private: 
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4998200}, str::stream() << "$" << exprName.toString() << " timezone must be a string")}, - CaseValuePair{sbe::makeE<sbe::EPrimUnary>( - sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>( - "isTimezone", - sbe::makeEs(sbe::makeE<sbe::EVariable>(timeZoneDBSlot), - timezoneRef.clone()))), - sbe::makeE<sbe::EFail>(ErrorCodes::Error{4998201}, - str::stream() - << "$" << exprName.toString() + CaseValuePair{ + sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, + makeFunction("isTimezone", + sbe::makeE<sbe::EVariable>(timeZoneDBSlot), + timezoneRef.clone())), + sbe::makeE<sbe::EFail>(ErrorCodes::Error{4998201}, + str::stream() << "$" << exprName.toString() << " timezone must be a valid timezone")}, CaseValuePair{generateNullOrMissing(dateRef), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)}, @@ -2474,12 +2581,12 @@ private: auto genericTrignomentricExpr = sbe::makeE<sbe::EIf>( generateNullOrMissing(frameId, 0), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0), - sbe::makeE<sbe::EIf>( - sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(inputRef.clone())), - sbe::makeE<sbe::EFunction>(exprName.toString(), sbe::makeEs(inputRef.clone())), - sbe::makeE<sbe::EFail>(ErrorCodes::Error{4995501}, - str::stream() << "$" << exprName.toString() - << " supports only numeric types"))); + sbe::makeE<sbe::EIf>(makeFunction("isNumber", inputRef.clone()), + makeFunction(exprName.toString(), inputRef.clone()), + sbe::makeE<sbe::EFail>(ErrorCodes::Error{4995501}, + str::stream() + << "$" << exprName.toString() + << " supports only numeric types"))); _context->pushExpr(sbe::makeE<sbe::ELocalBind>( frameId, std::move(binds), std::move(genericTrignomentricExpr))); @@ -2517,15 +2624,14 @@ private: generateNullOrMissing(frameId, 0), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0), sbe::makeE<sbe::EIf>( - sbe::makeE<sbe::EPrimUnary>( - sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>("isNumber", 
sbe::makeEs(inputRef.clone()))), + sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, + makeFunction("isNumber", inputRef.clone())), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4995502}, str::stream() << "$" << exprName.toString() << " supports only numeric types"), sbe::makeE<sbe::EIf>( std::move(checkBounds), - sbe::makeE<sbe::EFunction>(exprName.toString(), sbe::makeEs(inputRef.clone())), + makeFunction(exprName.toString(), inputRef.clone()), sbe::makeE<sbe::EFail>(ErrorCodes::Error{4995503}, str::stream() << "Cannot apply $" << exprName.toString() << ", value must be in " @@ -2661,7 +2767,7 @@ private: auto generateNotArray = [frameId](const sbe::value::SlotId slotId) { sbe::EVariable var{frameId, slotId}; - return makeNot(sbe::makeE<sbe::EFunction>("isArray", sbe::makeEs(var.clone()))); + return makeNot(makeFunction("isArray", var.clone())); }; std::vector<std::unique_ptr<sbe::EExpression>> binds; @@ -2841,10 +2947,9 @@ std::unique_ptr<sbe::EExpression> generateCoerceToBoolExpression(sbe::EVariable // If any of these are false, the branch is considered false for the purposes of the // any logical expression. 
- auto checkExists = sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(branchRef.clone())); - auto checkNotNull = sbe::makeE<sbe::EPrimUnary>( - sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>("isNull", sbe::makeEs(branchRef.clone()))); + auto checkExists = makeFunction("exists", branchRef.clone()); + auto checkNotNull = sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, + makeFunction("isNull", branchRef.clone())); auto checkNotFalse = makeNeqCheck(sbe::makeE<sbe::EConstant>( sbe::value::TypeTags::Boolean, sbe::value::bitcastFrom<bool>(false))); auto checkNotZero = makeNeqCheck(sbe::makeE<sbe::EConstant>( diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp index 77a58529d35..7e7c45ded2a 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp +++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp @@ -32,6 +32,7 @@ #include "mongo/db/query/sbe_stage_builder_helpers.h" #include "mongo/db/exec/sbe/expressions/expression.h" +#include "mongo/db/exec/sbe/stages/branch.h" #include "mongo/db/exec/sbe/stages/co_scan.h" #include "mongo/db/exec/sbe/stages/limit_skip.h" #include "mongo/db/exec/sbe/stages/loop_join.h" @@ -126,6 +127,12 @@ std::unique_ptr<sbe::EExpression> buildMultiBranchConditional( return defaultCase; } +std::unique_ptr<sbe::PlanStage> makeLimitTree(std::unique_ptr<sbe::PlanStage> inputStage, + PlanNodeId planNodeId, + long long limit) { + return sbe::makeS<sbe::LimitSkipStage>(std::move(inputStage), limit, boost::none, planNodeId); +} + std::unique_ptr<sbe::PlanStage> makeLimitCoScanTree(PlanNodeId planNodeId, long long limit) { return sbe::makeS<sbe::LimitSkipStage>( sbe::makeS<sbe::CoScanStage>(planNodeId), limit, boost::none, planNodeId); @@ -202,6 +209,36 @@ EvalStage makeLoopJoin(EvalStage left, std::move(outSlots)}; } +EvalStage makeUnwind(EvalStage inputEvalStage, + sbe::value::SlotIdGenerator* slotIdGenerator, + PlanNodeId planNodeId, + bool preserveNullAndEmptyArrays) { + 
auto unwindSlot = slotIdGenerator->generate(); + auto unwindStage = sbe::makeS<sbe::UnwindStage>(std::move(inputEvalStage.stage), + inputEvalStage.outSlots.front(), + unwindSlot, + slotIdGenerator->generate(), + preserveNullAndEmptyArrays, + planNodeId); + return {std::move(unwindStage), sbe::makeSV(unwindSlot)}; +} + +EvalStage makeBranch(std::unique_ptr<sbe::EExpression> ifExpr, + EvalStage thenStage, + EvalStage elseStage, + sbe::value::SlotIdGenerator* slotIdGenerator, + PlanNodeId planNodeId) { + auto outSlots = slotIdGenerator->generateMultiple(thenStage.outSlots.size()); + auto branchStage = sbe::makeS<sbe::BranchStage>(std::move(thenStage.stage), + std::move(elseStage.stage), + std::move(ifExpr), + thenStage.outSlots, + elseStage.outSlots, + outSlots, + planNodeId); + return {std::move(branchStage), std::move(outSlots)}; +} + EvalStage makeTraverse(EvalStage outer, EvalStage inner, sbe::value::SlotId inField, @@ -238,10 +275,10 @@ EvalStage makeTraverse(EvalStage outer, std::move(outSlots)}; } -EvalExprStagePair generateSingleResultUnion(std::vector<EvalExprStagePair> branches, - BranchFn branchFn, - PlanNodeId planNodeId, - sbe::value::SlotIdGenerator* slotIdGenerator) { +EvalExprStagePair generateUnion(std::vector<EvalExprStagePair> branches, + BranchFn branchFn, + PlanNodeId planNodeId, + sbe::value::SlotIdGenerator* slotIdGenerator) { std::vector<std::unique_ptr<sbe::PlanStage>> stages; std::vector<sbe::value::SlotVector> inputs; stages.reserve(branches.size()); @@ -266,13 +303,22 @@ EvalExprStagePair generateSingleResultUnion(std::vector<EvalExprStagePair> branc auto outputSlot = slotIdGenerator->generate(); auto unionStage = sbe::makeS<sbe::UnionStage>( std::move(stages), std::move(inputs), sbe::makeSV(outputSlot), planNodeId); - EvalStage outputStage = { - sbe::makeS<sbe::LimitSkipStage>(std::move(unionStage), 1, boost::none, planNodeId), - sbe::makeSV(outputSlot)}; + EvalStage outputStage{std::move(unionStage), sbe::makeSV(outputSlot)}; return 
{outputSlot, std::move(outputStage)}; } +EvalExprStagePair generateSingleResultUnion(std::vector<EvalExprStagePair> branches, + BranchFn branchFn, + PlanNodeId planNodeId, + sbe::value::SlotIdGenerator* slotIdGenerator) { + auto [unionEvalExpr, unionEvalStage] = + generateUnion(std::move(branches), std::move(branchFn), planNodeId, slotIdGenerator); + return {std::move(unionEvalExpr), + EvalStage{makeLimitTree(std::move(unionEvalStage.stage), planNodeId), + std::move(unionEvalStage.outSlots)}}; +} + EvalExprStagePair generateShortCircuitingLogicalOp(sbe::EPrimBinary::Op logicOp, std::vector<EvalExprStagePair> branches, PlanNodeId planNodeId, diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h index 6ab645883bd..454665e65a7 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.h +++ b/src/mongo/db/query/sbe_stage_builder_helpers.h @@ -124,6 +124,13 @@ std::unique_ptr<sbe::EExpression> buildMultiBranchConditional( std::unique_ptr<sbe::EExpression> defaultCase); /** + * Insert a limit stage on top of the 'input' stage. + */ +std::unique_ptr<sbe::PlanStage> makeLimitTree(std::unique_ptr<sbe::PlanStage> inputStage, + PlanNodeId planNodeId, + long long limit = 1); + +/** * Create tree consisting of coscan stage followed by limit stage. */ std::unique_ptr<sbe::PlanStage> makeLimitCoScanTree(PlanNodeId planNodeId, long long limit = 1); @@ -150,6 +157,14 @@ std::unique_ptr<sbe::EExpression> makeNot(std::unique_ptr<sbe::EExpression> e); std::unique_ptr<sbe::EExpression> makeFillEmptyFalse(std::unique_ptr<sbe::EExpression> e); /** + * Creates an EFunction expression with the given name and arguments. + */ +template <typename... Args> +inline std::unique_ptr<sbe::EExpression> makeFunction(std::string_view name, Args&&... args) { + return sbe::makeE<sbe::EFunction>(name, sbe::makeEs(std::forward<Args>(args)...)); +} + +/** * If given 'EvalExpr' already contains a slot, simply returns it. 
Otherwise, allocates a new slot * and creates project stage to assign expression to this new slot. After that, new slot and project * stage are returned. @@ -160,13 +175,13 @@ std::pair<sbe::value::SlotId, EvalStage> projectEvalExpr( PlanNodeId planNodeId, sbe::value::SlotIdGenerator* slotIdGenerator); -template <bool IsConst> +template <bool IsConst, bool IsEof = false> EvalStage makeFilter(EvalStage stage, std::unique_ptr<sbe::EExpression> filter, PlanNodeId planNodeId) { stage = stageOrLimitCoScan(std::move(stage), planNodeId); - return {sbe::makeS<sbe::FilterStage<IsConst>>( + return {sbe::makeS<sbe::FilterStage<IsConst, IsEof>>( std::move(stage.stage), std::move(filter), planNodeId), std::move(stage.outSlots)}; } @@ -199,6 +214,27 @@ EvalStage makeLoopJoin(EvalStage left, const sbe::value::SlotVector& lexicalEnvironment = {}); /** + * Creates an unwind stage and an output slot for it using the first slot in the outSlots vector of + * the inputEvalStage as the input slot to the new stage. The preserveNullAndEmptyArrays is passed + * to the UnwindStage constructor to specify the treatment of null or missing inputs. + */ +EvalStage makeUnwind(EvalStage inputEvalStage, + sbe::value::SlotIdGenerator* slotIdGenerator, + PlanNodeId planNodeId, + bool preserveNullAndEmptyArrays = true); + +/** + * Creates a branch stage with the specified condition ifExpr and creates output slots for the + * branch stage. This forwards the outputs of the thenStage to the output slots of the branchStage + * if the condition evaluates to true, and forwards the elseStage outputs if the condition is false. + */ +EvalStage makeBranch(std::unique_ptr<sbe::EExpression> ifExpr, + EvalStage thenStage, + EvalStage elseStage, + sbe::value::SlotIdGenerator* slotIdGenerator, + PlanNodeId planNodeId); + +/** * Creates traverse stage. All 'outSlots' from 'outer' argument (except for 'inField') along with * slots from the 'lexicalEnvironment' argument are passed as correlated. 
*/ @@ -220,6 +256,14 @@ using BranchFn = std::function<std::pair<sbe::value::SlotId, EvalStage>( sbe::value::SlotIdGenerator* slotIdGenerator)>; /** + * Creates a union stage with specified branches. Each branch is passed to 'branchFn' first. If + * 'branchFn' is not set, expression from branch is simply projected to a slot. + */ +EvalExprStagePair generateUnion(std::vector<EvalExprStagePair> branches, + BranchFn branchFn, + PlanNodeId planNodeId, + sbe::value::SlotIdGenerator* slotIdGenerator); +/** * Creates limit-1/union stage with specified branches. Each branch is passed to 'branchFn' first. * If 'branchFn' is not set, expression from branch is simply projected to a slot. */ diff --git a/src/mongo/db/query/sbe_stage_builder_projection.cpp b/src/mongo/db/query/sbe_stage_builder_projection.cpp index e01de77c953..78308a7a809 100644 --- a/src/mongo/db/query/sbe_stage_builder_projection.cpp +++ b/src/mongo/db/query/sbe_stage_builder_projection.cpp @@ -233,7 +233,7 @@ public: // existing 'fieldPathExpressionsTraverseStage' sub-tree. auto [outputSlot, expr, stage] = generateExpression(_context->opCtx, - node->expressionRaw(), + node->expression()->optimize().get(), std::move(_context->topLevel().fieldPathExpressionsTraverseStage), _context->slotIdGenerator, _context->frameIdGenerator, |