summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Alya Berciu <alyacarina@gmail.com> 2020-11-16 14:17:01 +0000
committer Evergreen Agent <no-reply@evergreen.mongodb.com> 2020-11-23 16:21:33 +0000
commit 4cf9000c24591166f4c093f4702a522a4a62097f (patch)
tree 0e4638e715b8b1051bf2d10097e1c27742af444a
parent 71fb74aab300a852761e0ae3b0782c207f4aef52 (diff)
download mongo-4cf9000c24591166f4c093f4702a522a4a62097f.tar.gz
SERVER-51534 Support concatArrays in SBE
-rw-r--r-- jstests/aggregation/bugs/server14872.js 9
-rw-r--r-- jstests/aggregation/expressions/concat_arrays.js 154
-rw-r--r-- jstests/libs/sbe_assert_error_override.js 1
-rw-r--r-- src/mongo/db/exec/sbe/stages/hash_agg.cpp 5
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_expression.cpp 345
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_helpers.cpp 60
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_helpers.h 48
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_projection.cpp 2
8 files changed, 488 insertions, 136 deletions
diff --git a/jstests/aggregation/bugs/server14872.js b/jstests/aggregation/bugs/server14872.js
index 8f4b84701da..f3e4abf85b6 100644
--- a/jstests/aggregation/bugs/server14872.js
+++ b/jstests/aggregation/bugs/server14872.js
@@ -1,14 +1,11 @@
// SERVER-14872: Aggregation expression to concatenate multiple arrays into one
-// @tags: [
-// sbe_incompatible,
-// ]
-
-// For assertErrorCode.
-load('jstests/aggregation/extras/utils.js');
(function() {
'use strict';
+load('jstests/aggregation/extras/utils.js'); // For assertErrorCode.
+load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs.
+
var coll = db.agg_concat_arrays_expr;
coll.drop();
diff --git a/jstests/aggregation/expressions/concat_arrays.js b/jstests/aggregation/expressions/concat_arrays.js
new file mode 100644
index 00000000000..37f4b6c4a72
--- /dev/null
+++ b/jstests/aggregation/expressions/concat_arrays.js
@@ -0,0 +1,154 @@
+// Confirm correctness of $concatArrays expression evaluation.
+
+(function() {
+"use strict";
+
+load("jstests/aggregation/extras/utils.js"); // For assertArrayEq.
+load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs.
+
+const coll = db.projection_expr_concat_arrays;
+coll.drop();
+
+assert.commandWorked(coll.insertOne({
+ int_arr: [1, 2, 3, 4],
+ dbl_arr: [10.0, 20.1, 20.4, 50.5],
+ nested_arr: [["an", "array"], "arr", [[], [[], "a", "b"]]],
+ str_arr: ["a", "b", "c"],
+ obj_arr: [{a: 1, b: 2}, {c: 3}, {d: 4, e: 5}],
+ null_arr: [null, null, null],
+ one_null_arr: [null],
+ one_str_arr: ["one"],
+ empty_arr: [],
+ null_val: null,
+ str_val: "a string",
+ dbl_val: 2.0,
+ int_val: 1,
+ obj_val: {a: 1, b: "two"}
+}));
+
+function runAndAssert(operands, expectedResult) {
+ assertArrayEq({
+ actual: coll.aggregate([{$project: {f: {$concatArrays: operands}}}]).map(doc => doc.f),
+ expected: expectedResult
+ });
+}
+
+function runAndAssertNull(operands) {
+ runAndAssert(operands, [null]);
+}
+
+function runAndAssertThrows(operands) {
+ const error =
+ assert.throws(() => coll.aggregate([{$project: {f: {$concatArrays: operands}}}]).toArray());
+ assert.commandFailedWithCode(error, 28664);
+}
+
+runAndAssert(["$int_arr"], [[1, 2, 3, 4]]);
+runAndAssert([[0], "$int_arr", [5, 6, 7]], [[0, 1, 2, 3, 4, 5, 6, 7]]);
+runAndAssert(["$int_arr", "$str_arr"], [[1, 2, 3, 4, "a", "b", "c"]]);
+runAndAssert(
+ ["$obj_arr", "$obj_arr", "$null_arr"],
+ [[{a: 1, b: 2}, {c: 3}, {d: 4, e: 5}, {a: 1, b: 2}, {c: 3}, {d: 4, e: 5}, null, null, null]]);
+runAndAssert(["$int_arr", "$str_arr", "$nested_arr"],
+ [[1, 2, 3, 4, "a", "b", "c", ["an", "array"], "arr", [[], [[], "a", "b"]]]]);
+runAndAssert(["$int_arr", "$obj_arr"], [[1, 2, 3, 4, {a: 1, b: 2}, {c: 3}, {d: 4, e: 5}]]);
+runAndAssert(["$obj_arr"], [[{a: 1, b: 2}, {c: 3}, {d: 4, e: 5}]]);
+runAndAssert(["$obj_arr", [{o: 123, b: 1}, {y: "o", d: "a"}]],
+ [[{a: 1, b: 2}, {c: 3}, {d: 4, e: 5}, {o: 123, b: 1}, {y: "o", d: "a"}]]);
+
+// Confirm that arrays containing null can be concatenated.
+runAndAssert(["$null_arr"], [[null, null, null]]);
+runAndAssert([[null], "$null_arr"], [[null, null, null, null]]);
+runAndAssert("$one_null_arr", [[null]]);
+runAndAssert(["$null_arr", "$one_null_arr", "$int_arr", "$null_arr"],
+ [[null, null, null, null, 1, 2, 3, 4, null, null, null]]);
+
+// Test operands that form more complex expressions.
+runAndAssert([{$concatArrays: "$int_arr"}], [[1, 2, 3, 4]]);
+runAndAssert([{$concatArrays: "$int_arr"}, {$concatArrays: {$concatArrays: "$str_arr"}}],
+ [[1, 2, 3, 4, "a", "b", "c"]]);
+runAndAssert(["$str_arr", {$filter: {input: "$int_arr",
+ as: "num",
+ cond: { $and: [
+ { $gte: [ "$$num", 2 ] },
+ { $lte: [ "$$num", 3 ] }
+ ] }}}, "$int_arr"],
+ [["a", "b", "c", 2, 3, 1, 2, 3, 4]]);
+
+// Confirm that having any combination of null or missing inputs and valid inputs produces null.
+runAndAssertNull(["$int_arr", "$null_val"]);
+runAndAssertNull(["$int_arr", null]);
+runAndAssertNull([null, "$int_arr", "$str_arr"]);
+runAndAssertNull(["$int_arr", null, "$str_arr"]);
+runAndAssertNull(["$null_val", "$str_arr", "$int_arr"]);
+runAndAssertNull(["$str_arr", "$null_val", "$int_arr"]);
+runAndAssertNull(["$int_arr", "$not_a_field"]);
+runAndAssertNull(["$not_a_field", "$str_arr", "$int_arr"]);
+runAndAssertNull(["$not_a_field"]);
+runAndAssertNull(["$null_val"]);
+runAndAssertNull(["$not_a_field", "$null_val"]);
+runAndAssertNull(["$null_val", "$not_a_field"]);
+runAndAssertNull([
+ {$concatArrays: "$int_arr"},
+ null,
+ {$concatArrays: {$concatArrays: ["$obj_arr", "$str_arr"]}}
+]);
+
+// Confirm edge case where if null precedes non-array input, null is returned.
+runAndAssertNull(["$int_arr", "$null_val", "$int_val"]);
+
+//
+// Confirm error cases.
+//
+
+// Confirm that concatenating a non-array, non-null value produces an error.
+runAndAssertThrows(["$dbl_val"]);
+runAndAssertThrows(["$str_val"]);
+runAndAssertThrows(["$int_val"]);
+runAndAssertThrows([123]);
+runAndAssertThrows(["some_val", [1, 2, 3]]);
+runAndAssertThrows(["$obj_val"]);
+runAndAssertThrows(["$int_arr", "$int_val"]);
+runAndAssertThrows(["$dbl_arr", "$dbl_val"]);
+
+// Confirm edge case where if invalid input precedes null or missing inputs, the command fails.
+runAndAssertThrows(["$int_arr", "$dbl_val", "$null_val"]);
+runAndAssertThrows(["$int_arr", "some_string_value", "$null_val"]);
+runAndAssertThrows(["$int_arr", 32]);
+runAndAssertThrows(["$dbl_val", "$null_val"]);
+runAndAssertThrows(["$int_arr", "$int_val", "$not_a_field"]);
+runAndAssertThrows(["$int_val", "$not_a_field"]);
+runAndAssertThrows(["$int_val", "$not_a_field", "$null_val"]);
+
+// Clear collection.
+assert(coll.drop());
+
+// Test case where the aggregation returns multiple documents.
+assert.commandWorked(coll.insertMany([
+ {arr1: [42, 35.0, 197865432], arr2: ["albatross", "abbacus", "alien"]},
+ {arr1: [1], arr2: ["albatross", "abbacus", "alien"]},
+ {arr1: [1, 2, 3, 4, 5, 6, 11, 12, 23], arr2: []},
+ {arr1: [], arr2: ["foo", "bar"]},
+ {arr1: [], arr2: []},
+ {arr1: [1, 2, 3, 4, 5, 6, 11, 12, 23], arr2: null},
+ {some_field: "foo"},
+]));
+runAndAssert(["$arr1", "$arr2"], [
+ [42, 35.0, 197865432, "albatross", "abbacus", "alien"],
+ [1, "albatross", "abbacus", "alien"],
+ [1, 2, 3, 4, 5, 6, 11, 12, 23],
+ ["foo", "bar"],
+ [],
+ null,
+ null
+]);
+runAndAssert(["$arr1", [1, 2, 3], "$arr2"], [
+ [42, 35.0, 197865432, 1, 2, 3, "albatross", "abbacus", "alien"],
+ [1, 1, 2, 3, "albatross", "abbacus", "alien"],
+ [1, 2, 3, 4, 5, 6, 11, 12, 23, 1, 2, 3],
+ ["foo", 1, 2, 3, "bar"],
+ [1, 2, 3],
+ null,
+ null
+]);
+}());
diff --git a/jstests/libs/sbe_assert_error_override.js b/jstests/libs/sbe_assert_error_override.js
index ce655b77aa2..cd98fec00fe 100644
--- a/jstests/libs/sbe_assert_error_override.js
+++ b/jstests/libs/sbe_assert_error_override.js
@@ -32,6 +32,7 @@ const equivalentErrorCodesList = [
[16609, 5073101],
[16610, 4848403],
[16555, 5073102],
+ [28664, 5153400],
[28680, 4903701],
[28689, 5126701],
[28690, 5126702],
diff --git a/src/mongo/db/exec/sbe/stages/hash_agg.cpp b/src/mongo/db/exec/sbe/stages/hash_agg.cpp
index fc97e86426d..c9f6da29462 100644
--- a/src/mongo/db/exec/sbe/stages/hash_agg.cpp
+++ b/src/mongo/db/exec/sbe/stages/hash_agg.cpp
@@ -105,6 +105,10 @@ void HashAggStage::open(bool reOpen) {
_commonStats.opens++;
_children[0]->open(reOpen);
+ if (reOpen) {
+ _ht.clear();
+ }
+
while (_children[0]->getNext() == PlanState::ADVANCED) {
value::MaterializedRow key{_inKeyAccessors.size()};
// Copy keys in order to do the lookup.
@@ -161,6 +165,7 @@ const SpecificStats* HashAggStage::getSpecificStats() const {
void HashAggStage::close() {
_commonStats.closes++;
+ _ht.clear();
}
std::vector<DebugPrinter::Block> HashAggStage::debugPrint() const {
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 42b7a086e90..731b30e86d9 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -36,6 +36,7 @@
#include "mongo/db/exec/sbe/stages/branch.h"
#include "mongo/db/exec/sbe/stages/co_scan.h"
#include "mongo/db/exec/sbe/stages/filter.h"
+#include "mongo/db/exec/sbe/stages/hash_agg.h"
#include "mongo/db/exec/sbe/stages/limit_skip.h"
#include "mongo/db/exec/sbe/stages/loop_join.h"
#include "mongo/db/exec/sbe/stages/project.h"
@@ -182,12 +183,11 @@ std::pair<sbe::value::SlotId, EvalStage> generateTraverseHelper(
std::move(inputStage),
planNodeId,
fieldSlot,
- sbe::makeE<sbe::EFunction>(
- "getField"sv,
- sbe::makeEs(sbe::makeE<sbe::EVariable>(inputSlot), sbe::makeE<sbe::EConstant>([&]() {
- auto fieldName = fp.getFieldName(level);
- return std::string_view{fieldName.rawData(), fieldName.size()};
- }()))));
+ makeFunction(
+ "getField"sv, sbe::makeE<sbe::EVariable>(inputSlot), sbe::makeE<sbe::EConstant>([&]() {
+ auto fieldName = fp.getFieldName(level);
+ return std::string_view{fieldName.rawData(), fieldName.size()};
+ }())));
EvalStage innerBranch;
if (level == fp.getPathLength() - 1) {
@@ -283,9 +283,7 @@ void generateStringCaseConversionExpression(ExpressionVisitorContext* _context,
auto caseConversionExpr = sbe::makeE<sbe::EIf>(
std::move(checkValidTypeExpr),
- sbe::makeE<sbe::EFunction>(caseConversionFunction,
- sbe::makeEs(sbe::makeE<sbe::EFunction>(
- "coerceToString", sbe::makeEs(inputRef.clone())))),
+ makeFunction(caseConversionFunction, makeFunction("coerceToString", inputRef.clone())),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{5066300},
str::stream() << "$" << caseConversionFunction
<< " input type is not supported"));
@@ -311,16 +309,14 @@ void buildArrayAccessByConstantIndex(ExpressionVisitorContext* context,
auto indexExpr = sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::NumberInt32,
sbe::value::bitcastFrom<int32_t>(index));
- auto argumentIsNotArray =
- makeNot(sbe::makeE<sbe::EFunction>("isArray", sbe::makeEs(arrayRef.clone())));
+ auto argumentIsNotArray = makeNot(makeFunction("isArray", arrayRef.clone()));
auto resultExpr = buildMultiBranchConditional(
CaseValuePair{generateNullOrMissing(arrayRef),
sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)},
CaseValuePair{std::move(argumentIsNotArray),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{5126704},
exprName + " argument must be an array")},
- sbe::makeE<sbe::EFunction>("getElement",
- sbe::makeEs(arrayRef.clone(), std::move(indexExpr))));
+ makeFunction("getElement", arrayRef.clone(), std::move(indexExpr)));
context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(resultExpr)));
@@ -364,7 +360,9 @@ public:
void visit(ExpressionCoerceToBool* expr) final {}
void visit(ExpressionCompare* expr) final {}
void visit(ExpressionConcat* expr) final {}
- void visit(ExpressionConcatArrays* expr) final {}
+ void visit(ExpressionConcatArrays* expr) final {
+ _context->evalStack.emplaceFrame(EvalStage{});
+ }
void visit(ExpressionCond* expr) final {
_context->evalStack.emplaceFrame(EvalStage{});
}
@@ -519,7 +517,9 @@ public:
void visit(ExpressionCoerceToBool* expr) final {}
void visit(ExpressionCompare* expr) final {}
void visit(ExpressionConcat* expr) final {}
- void visit(ExpressionConcatArrays* expr) final {}
+ void visit(ExpressionConcatArrays* expr) final {
+ _context->evalStack.emplaceFrame(EvalStage{});
+ }
void visit(ExpressionCond* expr) final {
_context->evalStack.emplaceFrame(EvalStage{});
}
@@ -724,7 +724,7 @@ public:
CaseValuePair{generateLongLongMinCheck(inputRef),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903701},
"can't take $abs of long long min")},
- sbe::makeE<sbe::EFunction>("abs", sbe::makeEs(inputRef.clone())));
+ makeFunction("abs", inputRef.clone()));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(absExpr)));
@@ -739,12 +739,10 @@ public:
sbe::EVariable var{frameId, slotId};
return sbe::makeE<sbe::EPrimBinary>(
sbe::EPrimBinary::logicAnd,
- sbe::makeE<sbe::EPrimUnary>(
- sbe::EPrimUnary::logicNot,
- sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(var.clone()))),
- sbe::makeE<sbe::EPrimUnary>(
- sbe::EPrimUnary::logicNot,
- sbe::makeE<sbe::EFunction>("isDate", sbe::makeEs(var.clone()))));
+ sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot,
+ makeFunction("isNumber", var.clone())),
+ sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot,
+ makeFunction("isDate", var.clone())));
};
if (arity == 2) {
@@ -767,10 +765,9 @@ public:
ErrorCodes::Error{4974201},
"only numbers and dates are allowed in an $add expression"),
sbe::makeE<sbe::EIf>(
- sbe::makeE<sbe::EPrimBinary>(
- sbe::EPrimBinary::logicAnd,
- sbe::makeE<sbe::EFunction>("isDate", sbe::makeEs(lhsVar.clone())),
- sbe::makeE<sbe::EFunction>("isDate", sbe::makeEs(rhsVar.clone()))),
+ sbe::makeE<sbe::EPrimBinary>(sbe::EPrimBinary::logicAnd,
+ makeFunction("isDate", lhsVar.clone()),
+ makeFunction("isDate", rhsVar.clone())),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4974202},
"only one date allowed in an $add expression"),
sbe::makeE<sbe::EPrimBinary>(
@@ -863,7 +860,7 @@ public:
sbe::EVariable convertedIndexRef{frameId, 0};
auto inExpression = sbe::makeE<sbe::EIf>(
- sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(convertedIndexRef.clone())),
+ makeFunction("exists", convertedIndexRef.clone()),
convertedIndexRef.clone(),
sbe::makeE<sbe::EFail>(
ErrorCodes::Error{5126703},
@@ -876,8 +873,7 @@ public:
sbe::makeE<sbe::EPrimBinary>(sbe::EPrimBinary::logicOr,
generateNullOrMissing(arrayRef),
generateNullOrMissing(indexRef));
- auto firstArgumentIsNotArray =
- makeNot(sbe::makeE<sbe::EFunction>("isArray", sbe::makeEs(arrayRef.clone())));
+ auto firstArgumentIsNotArray = makeNot(makeFunction("isArray", arrayRef.clone()));
auto secondArgumentIsNotNumeric = generateNonNumericCheck(indexRef);
auto arrayElemAtExpr = buildMultiBranchConditional(
CaseValuePair{std::move(anyOfArgumentsIsNullish),
@@ -888,8 +884,7 @@ public:
CaseValuePair{std::move(secondArgumentIsNotNumeric),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{5126702},
"$arrayElemAt second argument must be a number")},
- sbe::makeE<sbe::EFunction>("getElement",
- sbe::makeEs(arrayRef.clone(), std::move(int32Index))));
+ makeFunction("getElement", arrayRef.clone(), std::move(int32Index)));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(arrayElemAtExpr)));
@@ -923,7 +918,7 @@ public:
CaseValuePair{generateNonObjectCheck(inputRef),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{5043001},
"$bsonSize requires a document input")},
- sbe::makeE<sbe::EFunction>("bsonSize", sbe::makeEs(inputRef.clone())));
+ makeFunction("bsonSize", inputRef.clone()));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(bsonSizeExpr)));
@@ -939,7 +934,7 @@ public:
CaseValuePair{generateNonNumericCheck(inputRef),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903702},
"$ceil only supports numeric types")},
- sbe::makeE<sbe::EFunction>("ceil", sbe::makeEs(inputRef.clone())));
+ makeFunction("ceil", inputRef.clone()));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(ceilExpr)));
@@ -997,13 +992,13 @@ public:
// will also evaluate to "Nothing." MQL comparisons, however, treat "Nothing" as if it is a
// value that is less than everything other than MinKey. (Notably, two expressions that
// evaluate to "Nothing" are considered equal to each other.)
- auto nothingFallbackCmp = sbe::makeE<sbe::EPrimBinary>(
- comparisonOperator,
- sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(lhsRef.clone())),
- sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(rhsRef.clone())));
+ auto nothingFallbackCmp =
+ sbe::makeE<sbe::EPrimBinary>(comparisonOperator,
+ makeFunction("exists", lhsRef.clone()),
+ makeFunction("exists", rhsRef.clone()));
- auto cmpWithFallback = sbe::makeE<sbe::EFunction>(
- "fillEmpty", sbe::makeEs(std::move(cmp), std::move(nothingFallbackCmp)));
+ auto cmpWithFallback =
+ makeFunction("fillEmpty", std::move(cmp), std::move(nothingFallbackCmp));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(operands), std::move(cmpWithFallback)));
@@ -1023,8 +1018,7 @@ public:
sbe::EVariable var(frameId, slot);
binds.push_back(_context->popExpr());
checkNullArg.push_back(generateNullOrMissing(frameId, slot));
- checkStringArg.push_back(
- sbe::makeE<sbe::EFunction>("isString", sbe::makeEs(var.clone())));
+ checkStringArg.push_back(makeFunction("isString", var.clone()));
argVars.push_back(var.clone());
}
std::reverse(std::begin(binds), std::end(binds));
@@ -1060,7 +1054,132 @@ public:
}
void visit(ExpressionConcatArrays* expr) final {
- unsupportedExpression(expr->getOpName());
+ // Pop eval frames pushed by pre and in visitors off the stack.
+ std::vector<EvalExprStagePair> branches;
+ auto numChildren = expr->getChildren().size();
+ branches.reserve(numChildren);
+ for (size_t idx = 0; idx < numChildren; ++idx) {
+ auto [branchExpr, branchEvalStage] = _context->popFrame();
+ branches.emplace_back(std::move(branchExpr), std::move(branchEvalStage));
+ }
+ std::reverse(branches.begin(), branches.end());
+
+ auto getUnionOutputSlot = [](EvalExpr& unionEvalExpr) {
+ auto slot = *(unionEvalExpr.getSlot());
+ invariant(slot);
+ return slot;
+ };
+
+ auto makeNullLimitCoscanTree = [&]() {
+ auto outputSlot = _context->slotIdGenerator->generate();
+ auto nullEvalStage =
+ makeProject({makeLimitCoScanTree(_context->planNodeId), sbe::makeSV()},
+ _context->planNodeId,
+ outputSlot,
+ sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0));
+ return EvalExprStagePair{outputSlot, std::move(nullEvalStage)};
+ };
+
+ // Build a union stage to consolidate array input branches into a stream.
+ auto [unionEvalExpr, unionEvalStage] =
+ generateUnion(std::move(branches), {}, _context->planNodeId, _context->slotIdGenerator);
+ auto unionSlot = getUnionOutputSlot(unionEvalExpr);
+ sbe::EVariable unionVar{unionSlot};
+
+ // Filter stage to EFail if an element is not an array, null, or missing, and EOF if an
+ // element is null or missing: not(isNullOrMissing) && (isArray || EFail).
+ auto filterExpr = sbe::makeE<sbe::EPrimBinary>(
+ sbe::EPrimBinary::logicAnd,
+ makeNot(generateNullOrMissing(unionVar)),
+ sbe::makeE<sbe::EPrimBinary>(
+ sbe::EPrimBinary::logicOr,
+ makeFunction("isArray", unionVar.clone()),
+ sbe::makeE<sbe::EFail>(ErrorCodes::Error{5153400},
+ "$concatArrays only supports arrays")));
+ auto filter = makeFilter<false, true>(
+ std::move(unionEvalStage), std::move(filterExpr), _context->planNodeId);
+
+ // Create a union stage to replace any values filtered out by the previous stage with null.
+ // For example, [a, b, null, c, d] would become [a, b, null].
+ std::vector<EvalExprStagePair> unionWithNullBranches;
+ unionWithNullBranches.emplace_back(sbe::makeE<sbe::EVariable>(unionSlot),
+ std::move(filter));
+ unionWithNullBranches.emplace_back(makeNullLimitCoscanTree());
+ auto [unionWithNullExpr, unionWithNullStage] = generateUnion(
+ std::move(unionWithNullBranches), {}, _context->planNodeId, _context->slotIdGenerator);
+ auto unionWithNullSlot = getUnionOutputSlot(unionWithNullExpr);
+
+ // Create a limit stage to EOF once numChildren results have been obtained.
+ auto limitNumChildren =
+ makeLimitTree(std::move(unionWithNullStage.stage), _context->planNodeId, numChildren);
+
+ // Create a group stage to aggregate elements into a single array.
+ auto addToArrayExpr =
+ makeFunction("addToArray", sbe::makeE<sbe::EVariable>(unionWithNullSlot));
+ auto groupSlot = _context->slotIdGenerator->generate();
+ auto groupStage =
+ sbe::makeS<sbe::HashAggStage>(std::move(limitNumChildren),
+ sbe::makeSV(),
+ sbe::makeEM(groupSlot, std::move(addToArrayExpr)),
+ _context->planNodeId);
+ EvalStage groupEvalStage = {std::move(groupStage), sbe::makeSV(groupSlot)};
+
+ // Build subtree to handle nulls. If an input is null, return null. Otherwise, unwind the
+ // input twice, and concatenate it into an array using addToArray. This is necessary to
+ // implement the MQL behavior where one null or missing input results in a null output.
+
+ // Create two unwind stages to unwind the array that was built from inputs
+ // and unwind each input array into its constituent elements. We need a limit 1/coscan stage
+ // here to call getNext() on, but we use the output slot of groupStage to obtain the array
+ // of inputs.
+ auto unwindEvalStage = makeUnwind(
+ makeUnwind({makeLimitCoScanStage(_context->planNodeId).stage, sbe::makeSV(groupSlot)},
+ _context->slotIdGenerator,
+ _context->planNodeId),
+ _context->slotIdGenerator,
+ _context->planNodeId);
+ auto unwindSlot = unwindEvalStage.outSlots.front();
+
+ // Create a group stage to append all streamed elements into one array. This is the final
+ // output when the input consists entirely of arrays.
+ auto finalAddToArrayExpr =
+ makeFunction("addToArray", sbe::makeE<sbe::EVariable>(unwindSlot));
+ auto finalGroupSlot = _context->slotIdGenerator->generate();
+ auto finalGroupStage = sbe::makeS<sbe::HashAggStage>(
+ std::move(unwindEvalStage.stage),
+ sbe::makeSV(),
+ sbe::makeEM(finalGroupSlot, std::move(finalAddToArrayExpr)),
+ _context->planNodeId);
+
+ // Create a branch stage to select between the branch that produces one null if any elements
+ // in the original input were null or missing, or otherwise select the branch that unwinds
+ // and concatenates elements into the output array.
+ auto [nullExpr, nullStage] = makeNullLimitCoscanTree();
+ auto nullIsMemberExpr =
+ makeFunction("isMember",
+ sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0),
+ sbe::makeE<sbe::EVariable>(groupSlot));
+ auto branchNullEvalStage =
+ makeBranch(std::move(nullIsMemberExpr),
+ std::move(nullStage),
+ {std::move(finalGroupStage), sbe::makeSV(finalGroupSlot)},
+ _context->slotIdGenerator,
+ _context->planNodeId);
+ auto branchSlot = branchNullEvalStage.outSlots.front();
+
+ // Create nlj to connect outer group with inner branch that handles null input.
+ auto nljStage = makeLoopJoin(std::move(groupEvalStage),
+ std::move(branchNullEvalStage),
+ _context->planNodeId,
+ _context->getLexicalEnvironment());
+
+ // Top level nlj to inject input slots.
+ auto finalNljStage = makeLoopJoin(_context->extractCurrentEvalStage(),
+ std::move(nljStage),
+ _context->planNodeId,
+ _context->getLexicalEnvironment());
+
+ _context->pushExpr(sbe::makeE<sbe::EVariable>(branchSlot), std::move(finalNljStage));
}
void visit(ExpressionCond* expr) final {
visitConditionalExpression(expr);
@@ -1174,19 +1293,16 @@ public:
sbe::makeE<sbe::EIf>(
sbe::makeE<sbe::EPrimBinary>(
sbe::EPrimBinary::logicOr,
- sbe::makeE<sbe::EPrimUnary>(
- sbe::EPrimUnary::logicNot,
- sbe::makeE<sbe::EFunction>("exists",
- sbe::makeEs(outerSlotRef.clone()))),
- sbe::makeE<sbe::EFunction>("isNull", sbe::makeEs(outerSlotRef.clone()))),
+ sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot,
+ makeFunction("exists", outerSlotRef.clone())),
+ makeFunction("isNull", outerSlotRef.clone())),
sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0),
sbe::makeE<sbe::ELocalBind>(
innerFrameId,
sbe::makeEs(sbe::makeE<sbe::ENumericConvert>(
outerSlotRef.clone(), sbe::value::TypeTags::NumberInt64)),
sbe::makeE<sbe::EIf>(
- sbe::makeE<sbe::EFunction>("exists",
- sbe::makeEs(convertedFieldRef.clone())),
+ makeFunction("exists", convertedFieldRef.clone()),
convertedFieldRef.clone(),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4848979},
str::stream()
@@ -1307,7 +1423,7 @@ public:
tzFrameId,
sbe::makeEs(std::move(eTimezone)),
sbe::makeE<sbe::EIf>(
- sbe::makeE<sbe::EFunction>("isString", sbe::makeEs(timeZoneRef.clone())),
+ makeFunction("isString", timeZoneRef.clone()),
timezoneRef.clone(),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4848980},
str::stream()
@@ -1342,17 +1458,16 @@ public:
// runtime environment so we pass the corresponding slot to the datePartsWeekYear and
// dateParts functions as a variable.
auto timeZoneDBSlot = _context->runtimeEnvironment->getSlot("timeZoneDB"_sd);
- auto computeDate =
- sbe::makeE<sbe::EFunction>(isIsoWeekYear ? "datePartsWeekYear" : "dateParts",
- sbe::makeEs(sbe::makeE<sbe::EVariable>(timeZoneDBSlot),
- yearRef.clone(),
- monthRef.clone(),
- dayRef.clone(),
- hourRef.clone(),
- minRef.clone(),
- secRef.clone(),
- millisecRef.clone(),
- timeZoneRef.clone()));
+ auto computeDate = makeFunction(isIsoWeekYear ? "datePartsWeekYear" : "dateParts",
+ sbe::makeE<sbe::EVariable>(timeZoneDBSlot),
+ yearRef.clone(),
+ monthRef.clone(),
+ dayRef.clone(),
+ hourRef.clone(),
+ minRef.clone(),
+ secRef.clone(),
+ millisecRef.clone(),
+ timeZoneRef.clone());
using iterPair_t = std::vector<std::pair<std::unique_ptr<sbe::EExpression>,
std::unique_ptr<sbe::EExpression>>>::iterator;
@@ -1449,19 +1564,17 @@ public:
CaseValuePair{generateNullOrMissing(frameId, 1),
sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)},
CaseValuePair{
- sbe::makeE<sbe::EPrimUnary>(
- sbe::EPrimUnary::logicNot,
- sbe::makeE<sbe::EFunction>("isString", sbe::makeEs(timezoneRef.clone()))),
+ sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot,
+ makeFunction("isString", timezoneRef.clone())),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4997701},
"$dateToParts timezone must be a string")},
CaseValuePair{
sbe::makeE<sbe::EPrimUnary>(
sbe::EPrimUnary::logicNot,
- sbe::makeE<sbe::EFunction>(
- "isTimezone",
- sbe::makeEs(sbe::makeE<sbe::EVariable>(
- _context->runtimeEnvironment->getSlot("timeZoneDB"_sd)),
- timezoneRef.clone()))),
+ makeFunction("isTimezone",
+ sbe::makeE<sbe::EVariable>(
+ _context->runtimeEnvironment->getSlot("timeZoneDB"_sd)),
+ timezoneRef.clone())),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4997704},
"$dateToParts timezone must be a valid timezone")},
CaseValuePair{generateNullOrMissing(frameId, 2),
@@ -1497,10 +1610,9 @@ public:
sbe::EVariable lhsRef{frameId, 0};
sbe::EVariable rhsRef{frameId, 1};
- auto checkIsNumber = sbe::makeE<sbe::EPrimBinary>(
- sbe::EPrimBinary::logicAnd,
- sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(lhsRef.clone())),
- sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(rhsRef.clone())));
+ auto checkIsNumber = sbe::makeE<sbe::EPrimBinary>(sbe::EPrimBinary::logicAnd,
+ makeFunction("isNumber", lhsRef.clone()),
+ makeFunction("isNumber", rhsRef.clone()));
auto checkIsNullOrMissing = sbe::makeE<sbe::EPrimBinary>(sbe::EPrimBinary::logicOr,
generateNullOrMissing(lhsRef),
@@ -1529,7 +1641,7 @@ public:
CaseValuePair{generateNonNumericCheck(inputRef),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903703},
"$exp only supports numeric types")},
- sbe::makeE<sbe::EFunction>("exp", sbe::makeEs(inputRef.clone())));
+ makeFunction("exp", inputRef.clone()));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(expExpr)));
@@ -1605,10 +1717,10 @@ public:
auto binds = sbe::makeEs(std::move(input));
sbe::EVariable inputRef(frameId, 0);
- auto inputIsArrayOrNullish = sbe::makeE<sbe::EPrimBinary>(
- sbe::EPrimBinary::logicOr,
- generateNullOrMissing(inputRef),
- sbe::makeE<sbe::EFunction>("isArray", sbe::makeEs(inputRef.clone())));
+ auto inputIsArrayOrNullish =
+ sbe::makeE<sbe::EPrimBinary>(sbe::EPrimBinary::logicOr,
+ generateNullOrMissing(inputRef),
+ makeFunction("isArray", inputRef.clone()));
auto checkInputArrayType =
sbe::makeE<sbe::EIf>(std::move(inputIsArrayOrNullish),
inputRef.clone(),
@@ -1673,10 +1785,9 @@ public:
// If input array is null or missing, 'in' stage of traverse will return EOF. In this case
// traverse sets output slot (filteredArraySlot) to Nothing. We replace it with Null to
// match $filter expression behaviour.
- auto result = sbe::makeE<sbe::EFunction>(
- "fillEmpty",
- sbe::makeEs(sbe::makeE<sbe::EVariable>(filteredArraySlot),
- sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)));
+ auto result = makeFunction("fillEmpty",
+ sbe::makeE<sbe::EVariable>(filteredArraySlot),
+ sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0));
_context->pushExpr(std::move(result), std::move(traverseStage));
}
@@ -1691,7 +1802,7 @@ public:
CaseValuePair{generateNonNumericCheck(inputRef),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903704},
"$floor only supports numeric types")},
- sbe::makeE<sbe::EFunction>("floor", sbe::makeEs(inputRef.clone())));
+ makeFunction("floor", inputRef.clone()));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(floorExpr)));
@@ -1732,11 +1843,11 @@ public:
auto binds = sbe::makeEs(_context->popExpr());
sbe::EVariable inputRef(frameId, 0);
- auto exprIsNum = sbe::makeE<sbe::EIf>(
- sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(inputRef.clone())),
- sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(inputRef.clone())),
- sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Boolean,
- sbe::value::bitcastFrom<bool>(false)));
+ auto exprIsNum =
+ sbe::makeE<sbe::EIf>(makeFunction("exists", inputRef.clone()),
+ makeFunction("isNumber", inputRef.clone()),
+ sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Boolean,
+ sbe::value::bitcastFrom<bool>(false)));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(exprIsNum)));
@@ -1786,7 +1897,7 @@ public:
CaseValuePair{generateNonPositiveCheck(inputRef),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903706},
"$ln's argument must be a positive number")},
- sbe::makeE<sbe::EFunction>("ln", sbe::makeEs(inputRef.clone())));
+ makeFunction("ln", inputRef.clone()));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(lnExpr)));
@@ -1813,7 +1924,7 @@ public:
CaseValuePair{generateNonPositiveCheck(inputRef),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903708},
"$log10's argument must be a positive number")},
- sbe::makeE<sbe::EFunction>("log10", sbe::makeEs(inputRef.clone())));
+ makeFunction("log10", inputRef.clone()));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(log10Expr)));
@@ -1845,8 +1956,7 @@ public:
sbe::EPrimUnary::logicNot,
sbe::makeE<sbe::ETypeMatch>(
lhsVar.clone(), getBSONTypeMask(sbe::value::TypeTags::NumberDouble)))),
- sbe::makeE<sbe::EFunction>(
- "fillEmpty", sbe::makeEs(std::move(numericConvert32), rhsVar.clone()))},
+ makeFunction("fillEmpty", std::move(numericConvert32), rhsVar.clone())},
rhsVar.clone());
auto modExpr = buildMultiBranchConditional(
@@ -1859,7 +1969,7 @@ public:
generateNonNumericCheck(rhsVar)),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{5154000},
"$mod only supports numeric types")},
- sbe::makeE<sbe::EFunction>("mod", sbe::makeEs(lhsVar.clone(), std::move(rhsExpr))));
+ makeFunction("mod", lhsVar.clone(), std::move(rhsExpr)));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(modExpr)));
@@ -1884,8 +1994,7 @@ public:
variables.push_back(currentVariable.clone());
checkExprsNull.push_back(generateNullOrMissing(currentVariable));
- checkExprsNumber.push_back(
- sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(currentVariable.clone())));
+ checkExprsNumber.push_back(makeFunction("isNumber", currentVariable.clone()));
}
// At this point 'binds' vector contains arguments of $multiply expression in the reversed
@@ -2076,7 +2185,7 @@ public:
generateNegativeCheck(inputRef),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4903710},
"$sqrt's argument must be greater than or equal to 0")},
- sbe::makeE<sbe::EFunction>("sqrt", sbe::makeEs(inputRef.clone())));
+ makeFunction("sqrt", inputRef.clone()));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(lnExpr)));
@@ -2438,15 +2547,13 @@ private:
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4998200},
str::stream() << "$" << exprName.toString()
<< " timezone must be a string")},
- CaseValuePair{sbe::makeE<sbe::EPrimUnary>(
- sbe::EPrimUnary::logicNot,
- sbe::makeE<sbe::EFunction>(
- "isTimezone",
- sbe::makeEs(sbe::makeE<sbe::EVariable>(timeZoneDBSlot),
- timezoneRef.clone()))),
- sbe::makeE<sbe::EFail>(ErrorCodes::Error{4998201},
- str::stream()
- << "$" << exprName.toString()
+ CaseValuePair{
+ sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot,
+ makeFunction("isTimezone",
+ sbe::makeE<sbe::EVariable>(timeZoneDBSlot),
+ timezoneRef.clone())),
+ sbe::makeE<sbe::EFail>(ErrorCodes::Error{4998201},
+ str::stream() << "$" << exprName.toString()
<< " timezone must be a valid timezone")},
CaseValuePair{generateNullOrMissing(dateRef),
sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)},
@@ -2474,12 +2581,12 @@ private:
auto genericTrignomentricExpr = sbe::makeE<sbe::EIf>(
generateNullOrMissing(frameId, 0),
sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0),
- sbe::makeE<sbe::EIf>(
- sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(inputRef.clone())),
- sbe::makeE<sbe::EFunction>(exprName.toString(), sbe::makeEs(inputRef.clone())),
- sbe::makeE<sbe::EFail>(ErrorCodes::Error{4995501},
- str::stream() << "$" << exprName.toString()
- << " supports only numeric types")));
+ sbe::makeE<sbe::EIf>(makeFunction("isNumber", inputRef.clone()),
+ makeFunction(exprName.toString(), inputRef.clone()),
+ sbe::makeE<sbe::EFail>(ErrorCodes::Error{4995501},
+ str::stream()
+ << "$" << exprName.toString()
+ << " supports only numeric types")));
_context->pushExpr(sbe::makeE<sbe::ELocalBind>(
frameId, std::move(binds), std::move(genericTrignomentricExpr)));
@@ -2517,15 +2624,14 @@ private:
generateNullOrMissing(frameId, 0),
sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0),
sbe::makeE<sbe::EIf>(
- sbe::makeE<sbe::EPrimUnary>(
- sbe::EPrimUnary::logicNot,
- sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(inputRef.clone()))),
+ sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot,
+ makeFunction("isNumber", inputRef.clone())),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4995502},
str::stream() << "$" << exprName.toString()
<< " supports only numeric types"),
sbe::makeE<sbe::EIf>(
std::move(checkBounds),
- sbe::makeE<sbe::EFunction>(exprName.toString(), sbe::makeEs(inputRef.clone())),
+ makeFunction(exprName.toString(), inputRef.clone()),
sbe::makeE<sbe::EFail>(ErrorCodes::Error{4995503},
str::stream() << "Cannot apply $" << exprName.toString()
<< ", value must be in "
@@ -2661,7 +2767,7 @@ private:
auto generateNotArray = [frameId](const sbe::value::SlotId slotId) {
sbe::EVariable var{frameId, slotId};
- return makeNot(sbe::makeE<sbe::EFunction>("isArray", sbe::makeEs(var.clone())));
+ return makeNot(makeFunction("isArray", var.clone()));
};
std::vector<std::unique_ptr<sbe::EExpression>> binds;
@@ -2841,10 +2947,9 @@ std::unique_ptr<sbe::EExpression> generateCoerceToBoolExpression(sbe::EVariable
// If any of these are false, the branch is considered false for the purposes of the
// any logical expression.
- auto checkExists = sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(branchRef.clone()));
- auto checkNotNull = sbe::makeE<sbe::EPrimUnary>(
- sbe::EPrimUnary::logicNot,
- sbe::makeE<sbe::EFunction>("isNull", sbe::makeEs(branchRef.clone())));
+ auto checkExists = makeFunction("exists", branchRef.clone());
+ auto checkNotNull = sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot,
+ makeFunction("isNull", branchRef.clone()));
auto checkNotFalse = makeNeqCheck(sbe::makeE<sbe::EConstant>(
sbe::value::TypeTags::Boolean, sbe::value::bitcastFrom<bool>(false)));
auto checkNotZero = makeNeqCheck(sbe::makeE<sbe::EConstant>(
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
index 77a58529d35..7e7c45ded2a 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
@@ -32,6 +32,7 @@
#include "mongo/db/query/sbe_stage_builder_helpers.h"
#include "mongo/db/exec/sbe/expressions/expression.h"
+#include "mongo/db/exec/sbe/stages/branch.h"
#include "mongo/db/exec/sbe/stages/co_scan.h"
#include "mongo/db/exec/sbe/stages/limit_skip.h"
#include "mongo/db/exec/sbe/stages/loop_join.h"
@@ -126,6 +127,12 @@ std::unique_ptr<sbe::EExpression> buildMultiBranchConditional(
return defaultCase;
}
+std::unique_ptr<sbe::PlanStage> makeLimitTree(std::unique_ptr<sbe::PlanStage> inputStage,
+ PlanNodeId planNodeId,
+ long long limit) {
+ return sbe::makeS<sbe::LimitSkipStage>(std::move(inputStage), limit, boost::none, planNodeId);
+}
+
std::unique_ptr<sbe::PlanStage> makeLimitCoScanTree(PlanNodeId planNodeId, long long limit) {
return sbe::makeS<sbe::LimitSkipStage>(
sbe::makeS<sbe::CoScanStage>(planNodeId), limit, boost::none, planNodeId);
@@ -202,6 +209,36 @@ EvalStage makeLoopJoin(EvalStage left,
std::move(outSlots)};
}
+EvalStage makeUnwind(EvalStage inputEvalStage,
+ sbe::value::SlotIdGenerator* slotIdGenerator,
+ PlanNodeId planNodeId,
+ bool preserveNullAndEmptyArrays) {
+ auto unwindSlot = slotIdGenerator->generate();
+ auto unwindStage = sbe::makeS<sbe::UnwindStage>(std::move(inputEvalStage.stage),
+ inputEvalStage.outSlots.front(),
+ unwindSlot,
+ slotIdGenerator->generate(),
+ preserveNullAndEmptyArrays,
+ planNodeId);
+ return {std::move(unwindStage), sbe::makeSV(unwindSlot)};
+}
+
+EvalStage makeBranch(std::unique_ptr<sbe::EExpression> ifExpr,
+ EvalStage thenStage,
+ EvalStage elseStage,
+ sbe::value::SlotIdGenerator* slotIdGenerator,
+ PlanNodeId planNodeId) {
+ auto outSlots = slotIdGenerator->generateMultiple(thenStage.outSlots.size());
+ auto branchStage = sbe::makeS<sbe::BranchStage>(std::move(thenStage.stage),
+ std::move(elseStage.stage),
+ std::move(ifExpr),
+ thenStage.outSlots,
+ elseStage.outSlots,
+ outSlots,
+ planNodeId);
+ return {std::move(branchStage), std::move(outSlots)};
+}
+
EvalStage makeTraverse(EvalStage outer,
EvalStage inner,
sbe::value::SlotId inField,
@@ -238,10 +275,10 @@ EvalStage makeTraverse(EvalStage outer,
std::move(outSlots)};
}
-EvalExprStagePair generateSingleResultUnion(std::vector<EvalExprStagePair> branches,
- BranchFn branchFn,
- PlanNodeId planNodeId,
- sbe::value::SlotIdGenerator* slotIdGenerator) {
+EvalExprStagePair generateUnion(std::vector<EvalExprStagePair> branches,
+ BranchFn branchFn,
+ PlanNodeId planNodeId,
+ sbe::value::SlotIdGenerator* slotIdGenerator) {
std::vector<std::unique_ptr<sbe::PlanStage>> stages;
std::vector<sbe::value::SlotVector> inputs;
stages.reserve(branches.size());
@@ -266,13 +303,22 @@ EvalExprStagePair generateSingleResultUnion(std::vector<EvalExprStagePair> branc
auto outputSlot = slotIdGenerator->generate();
auto unionStage = sbe::makeS<sbe::UnionStage>(
std::move(stages), std::move(inputs), sbe::makeSV(outputSlot), planNodeId);
- EvalStage outputStage = {
- sbe::makeS<sbe::LimitSkipStage>(std::move(unionStage), 1, boost::none, planNodeId),
- sbe::makeSV(outputSlot)};
+ EvalStage outputStage{std::move(unionStage), sbe::makeSV(outputSlot)};
return {outputSlot, std::move(outputStage)};
}
+EvalExprStagePair generateSingleResultUnion(std::vector<EvalExprStagePair> branches,
+ BranchFn branchFn,
+ PlanNodeId planNodeId,
+ sbe::value::SlotIdGenerator* slotIdGenerator) {
+ auto [unionEvalExpr, unionEvalStage] =
+ generateUnion(std::move(branches), std::move(branchFn), planNodeId, slotIdGenerator);
+ return {std::move(unionEvalExpr),
+ EvalStage{makeLimitTree(std::move(unionEvalStage.stage), planNodeId),
+ std::move(unionEvalStage.outSlots)}};
+}
+
EvalExprStagePair generateShortCircuitingLogicalOp(sbe::EPrimBinary::Op logicOp,
std::vector<EvalExprStagePair> branches,
PlanNodeId planNodeId,
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h
index 6ab645883bd..454665e65a7 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.h
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.h
@@ -124,6 +124,13 @@ std::unique_ptr<sbe::EExpression> buildMultiBranchConditional(
std::unique_ptr<sbe::EExpression> defaultCase);
/**
+ * Inserts a limit stage on top of 'inputStage', capped at 'limit' results (1 by default).
+ */
+std::unique_ptr<sbe::PlanStage> makeLimitTree(std::unique_ptr<sbe::PlanStage> inputStage,
+ PlanNodeId planNodeId,
+ long long limit = 1);
+
+/**
* Create tree consisting of coscan stage followed by limit stage.
*/
std::unique_ptr<sbe::PlanStage> makeLimitCoScanTree(PlanNodeId planNodeId, long long limit = 1);
@@ -150,6 +157,14 @@ std::unique_ptr<sbe::EExpression> makeNot(std::unique_ptr<sbe::EExpression> e);
std::unique_ptr<sbe::EExpression> makeFillEmptyFalse(std::unique_ptr<sbe::EExpression> e);
/**
+ * Creates an EFunction expression with the given name and arguments.
+ */
+template <typename... Args>
+inline std::unique_ptr<sbe::EExpression> makeFunction(std::string_view name, Args&&... args) {
+ return sbe::makeE<sbe::EFunction>(name, sbe::makeEs(std::forward<Args>(args)...));
+}
+
+/**
* If given 'EvalExpr' already contains a slot, simply returns it. Otherwise, allocates a new slot
* and creates project stage to assign expression to this new slot. After that, new slot and project
* stage are returned.
@@ -160,13 +175,13 @@ std::pair<sbe::value::SlotId, EvalStage> projectEvalExpr(
PlanNodeId planNodeId,
sbe::value::SlotIdGenerator* slotIdGenerator);
-template <bool IsConst>
+template <bool IsConst, bool IsEof = false>
EvalStage makeFilter(EvalStage stage,
std::unique_ptr<sbe::EExpression> filter,
PlanNodeId planNodeId) {
stage = stageOrLimitCoScan(std::move(stage), planNodeId);
- return {sbe::makeS<sbe::FilterStage<IsConst>>(
+ return {sbe::makeS<sbe::FilterStage<IsConst, IsEof>>(
std::move(stage.stage), std::move(filter), planNodeId),
std::move(stage.outSlots)};
}
@@ -199,6 +214,27 @@ EvalStage makeLoopJoin(EvalStage left,
const sbe::value::SlotVector& lexicalEnvironment = {});
/**
+ * Creates an unwind stage and an output slot for it using the first slot in the outSlots vector of
+ * the inputEvalStage as the input slot to the new stage. The preserveNullAndEmptyArrays is passed
+ * to the UnwindStage constructor to specify the treatment of null or missing inputs.
+ */
+EvalStage makeUnwind(EvalStage inputEvalStage,
+ sbe::value::SlotIdGenerator* slotIdGenerator,
+ PlanNodeId planNodeId,
+ bool preserveNullAndEmptyArrays = true);
+
+/**
+ * Creates a branch stage with the specified condition ifExpr and creates output slots for the
+ * branch stage. This forwards the outputs of the thenStage to the output slots of the branchStage
+ * if the condition evaluates to true, and forwards the elseStage outputs if the condition is false.
+ */
+EvalStage makeBranch(std::unique_ptr<sbe::EExpression> ifExpr,
+ EvalStage thenStage,
+ EvalStage elseStage,
+ sbe::value::SlotIdGenerator* slotIdGenerator,
+ PlanNodeId planNodeId);
+
+/**
* Creates traverse stage. All 'outSlots' from 'outer' argument (except for 'inField') along with
* slots from the 'lexicalEnvironment' argument are passed as correlated.
*/
@@ -220,6 +256,14 @@ using BranchFn = std::function<std::pair<sbe::value::SlotId, EvalStage>(
sbe::value::SlotIdGenerator* slotIdGenerator)>;
/**
+ * Creates a union stage with specified branches. Each branch is passed to 'branchFn' first. If
+ * 'branchFn' is not set, expression from branch is simply projected to a slot.
+ */
+EvalExprStagePair generateUnion(std::vector<EvalExprStagePair> branches,
+ BranchFn branchFn,
+ PlanNodeId planNodeId,
+ sbe::value::SlotIdGenerator* slotIdGenerator);
+/**
* Creates limit-1/union stage with specified branches. Each branch is passed to 'branchFn' first.
* If 'branchFn' is not set, expression from branch is simply projected to a slot.
*/
diff --git a/src/mongo/db/query/sbe_stage_builder_projection.cpp b/src/mongo/db/query/sbe_stage_builder_projection.cpp
index e01de77c953..78308a7a809 100644
--- a/src/mongo/db/query/sbe_stage_builder_projection.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_projection.cpp
@@ -233,7 +233,7 @@ public:
// existing 'fieldPathExpressionsTraverseStage' sub-tree.
auto [outputSlot, expr, stage] =
generateExpression(_context->opCtx,
- node->expressionRaw(),
+ node->expression()->optimize().get(),
std::move(_context->topLevel().fieldPathExpressionsTraverseStage),
_context->slotIdGenerator,
_context->frameIdGenerator,