summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--buildscripts/resmokeconfig/suites/multi_shard_local_read_write_multi_stmt_txn_jscore_passthrough.yml1
-rw-r--r--buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_jscore_passthrough.yml1
-rw-r--r--buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_kill_primary_jscore_passthrough.yml1
-rw-r--r--buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_stepdown_primary_jscore_passthrough.yml1
-rw-r--r--buildscripts/resmokeconfig/suites/multi_stmt_txn_jscore_passthrough_with_migration.yml1
-rw-r--r--buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_jscore_passthrough.yml1
-rw-r--r--buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_kill_primary_jscore_passthrough.yml1
-rw-r--r--buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_stepdown_jscore_passthrough.yml1
-rw-r--r--buildscripts/resmokeconfig/suites/sharded_multi_stmt_txn_jscore_passthrough.yml1
-rw-r--r--buildscripts/resmokeconfig/suites/tenant_migration_multi_stmt_txn_jscore_passthrough.yml1
-rw-r--r--jstests/aggregation/expressions/add.js31
-rw-r--r--jstests/aggregation/expressions/split.js46
-rw-r--r--jstests/core/projection_expr_mod.js22
-rw-r--r--jstests/libs/sbe_assert_error_override.js8
-rw-r--r--src/mongo/db/query/sbe_stage_builder_expression.cpp486
-rw-r--r--src/mongo/db/query/sbe_stage_builder_helpers.cpp13
-rw-r--r--src/mongo/db/query/sbe_stage_builder_helpers.h6
17 files changed, 494 insertions, 128 deletions
diff --git a/buildscripts/resmokeconfig/suites/multi_shard_local_read_write_multi_stmt_txn_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/multi_shard_local_read_write_multi_stmt_txn_jscore_passthrough.yml
index aa9e84dbc1b..098952c993e 100644
--- a/buildscripts/resmokeconfig/suites/multi_shard_local_read_write_multi_stmt_txn_jscore_passthrough.yml
+++ b/buildscripts/resmokeconfig/suites/multi_shard_local_read_write_multi_stmt_txn_jscore_passthrough.yml
@@ -159,6 +159,7 @@ selector:
- jstests/core/updatel.js
- jstests/core/write_result.js
- jstests/core/positional_projection.js
+ - jstests/core/projection_expr_mod.js
# Trick for bypassing mongo shell validation in the test doesn't work because txn_override
# retry logic will hit the shell validation.
diff --git a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_jscore_passthrough.yml
index a2d14686ceb..649bc2b3891 100644
--- a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_jscore_passthrough.yml
+++ b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_jscore_passthrough.yml
@@ -174,6 +174,7 @@ selector:
- jstests/core/updatel.js
- jstests/core/write_result.js
- jstests/core/positional_projection.js
+ - jstests/core/projection_expr_mod.js
# Trick for bypassing mongo shell validation in the test doesn't work because txn_override
# retry logic will hit the shell validation.
diff --git a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_kill_primary_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_kill_primary_jscore_passthrough.yml
index b61aed35737..afc436e4803 100644
--- a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_kill_primary_jscore_passthrough.yml
+++ b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_kill_primary_jscore_passthrough.yml
@@ -169,6 +169,7 @@ selector:
- jstests/core/updatel.js
- jstests/core/write_result.js
- jstests/core/positional_projection.js
+ - jstests/core/projection_expr_mod.js
# Trick for bypassing mongo shell validation in the test doesn't work because txn_override
# retry logic will hit the shell validation.
diff --git a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_stepdown_primary_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_stepdown_primary_jscore_passthrough.yml
index 4fe4a9b5343..5846ec7162a 100644
--- a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_stepdown_primary_jscore_passthrough.yml
+++ b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_stepdown_primary_jscore_passthrough.yml
@@ -170,6 +170,7 @@ selector:
- jstests/core/updatel.js
- jstests/core/write_result.js
- jstests/core/positional_projection.js
+ - jstests/core/projection_expr_mod.js
# Trick for bypassing mongo shell validation in the test doesn't work because txn_override
# retry logic will hit the shell validation.
diff --git a/buildscripts/resmokeconfig/suites/multi_stmt_txn_jscore_passthrough_with_migration.yml b/buildscripts/resmokeconfig/suites/multi_stmt_txn_jscore_passthrough_with_migration.yml
index 1cb8898a421..623cbe2d10b 100644
--- a/buildscripts/resmokeconfig/suites/multi_stmt_txn_jscore_passthrough_with_migration.yml
+++ b/buildscripts/resmokeconfig/suites/multi_stmt_txn_jscore_passthrough_with_migration.yml
@@ -182,6 +182,7 @@ selector:
- jstests/core/updatel.js
- jstests/core/write_result.js
- jstests/core/positional_projection.js
+ - jstests/core/projection_expr_mod.js
# Trick for bypassing mongo shell validation in the test doesn't work because txn_override
# retry logic will hit the shell validation.
diff --git a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_jscore_passthrough.yml
index 19b3237d99b..28276078610 100644
--- a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_jscore_passthrough.yml
+++ b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_jscore_passthrough.yml
@@ -122,6 +122,7 @@ selector:
- jstests/core/updatel.js
- jstests/core/write_result.js
- jstests/core/positional_projection.js
+ - jstests/core/projection_expr_mod.js
# Trick for bypassing mongo shell validation in the test doesn't work because txn_override
# retry logic will hit the shell validation.
diff --git a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_kill_primary_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_kill_primary_jscore_passthrough.yml
index 658df1cca30..519acfb6f2e 100644
--- a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_kill_primary_jscore_passthrough.yml
+++ b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_kill_primary_jscore_passthrough.yml
@@ -110,6 +110,7 @@ selector:
- jstests/core/updatel.js
- jstests/core/write_result.js
- jstests/core/positional_projection.js
+ - jstests/core/projection_expr_mod.js
# Trick for bypassing mongo shell validation in the test doesn't work because txn_override
# retry logic will hit the shell validation.
diff --git a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_stepdown_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_stepdown_jscore_passthrough.yml
index 13227f2f413..25b4e55b041 100644
--- a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_stepdown_jscore_passthrough.yml
+++ b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_stepdown_jscore_passthrough.yml
@@ -109,6 +109,7 @@ selector:
- jstests/core/updatel.js
- jstests/core/write_result.js
- jstests/core/positional_projection.js
+ - jstests/core/projection_expr_mod.js
# Trick for bypassing mongo shell validation in the test doesn't work because txn_override
# retry logic will hit the shell validation.
diff --git a/buildscripts/resmokeconfig/suites/sharded_multi_stmt_txn_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/sharded_multi_stmt_txn_jscore_passthrough.yml
index 10436896401..812df6897f8 100644
--- a/buildscripts/resmokeconfig/suites/sharded_multi_stmt_txn_jscore_passthrough.yml
+++ b/buildscripts/resmokeconfig/suites/sharded_multi_stmt_txn_jscore_passthrough.yml
@@ -146,6 +146,7 @@ selector:
- jstests/core/updatel.js
- jstests/core/write_result.js
- jstests/core/positional_projection.js
+ - jstests/core/projection_expr_mod.js
# Trick for bypassing mongo shell validation in the test doesn't work because txn_override
# retry logic will hit the shell validation.
diff --git a/buildscripts/resmokeconfig/suites/tenant_migration_multi_stmt_txn_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/tenant_migration_multi_stmt_txn_jscore_passthrough.yml
index 5c813fd7a53..86761f02edf 100644
--- a/buildscripts/resmokeconfig/suites/tenant_migration_multi_stmt_txn_jscore_passthrough.yml
+++ b/buildscripts/resmokeconfig/suites/tenant_migration_multi_stmt_txn_jscore_passthrough.yml
@@ -227,6 +227,7 @@ selector:
- jstests/core/field_name_validation.js
- jstests/core/insert_illegal_doc.js
- jstests/core/positional_projection.js
+ - jstests/core/projection_expr_mod.js
- jstests/core/push_sort.js
- jstests/core/update_dbref.js
diff --git a/jstests/aggregation/expressions/add.js b/jstests/aggregation/expressions/add.js
new file mode 100644
index 00000000000..cb9b146c84f
--- /dev/null
+++ b/jstests/aggregation/expressions/add.js
@@ -0,0 +1,31 @@
+// Confirm correctness of $add evaluation in find projection.
+(function() {
+"use strict";
+
+load("jstests/aggregation/extras/utils.js"); // For assertArrayEq.
+load("jstests/libs/sbe_util.js"); // For checkSBEEnabled.
+
+const isSBEEnabled = checkSBEEnabled(db);
+if (isSBEEnabled) {
+ // Override error-code-checking APIs. We only load this when SBE is explicitly enabled, because
+ // it causes failures in the parallel suites.
+ load("jstests/libs/sbe_assert_error_override.js");
+}
+
+const coll = db.expression_add;
+coll.drop();
+assert.commandWorked(coll.insert({a: NumberInt(2), b: NumberLong(3), c: 3.5}));
+
+const testCases = [
+ [["$a", "$b", "$c"], 8.5],
+ [["$a", "$b", null], null],
+ [["$a", "$b", 5], 10],
+ [[5, "$a", "$b"], 10],
+ [["$a", 5, "$b"], 10],
+ [["$a", 20, "$c", 10], 35.5],
+];
+for (const testCase of testCases) {
+ const [addExpr, expected] = testCase;
+ assert.eq(coll.findOne({}, {sum: {$add: addExpr}, _id: 0}), {sum: expected}, testCase);
+}
+})();
diff --git a/jstests/aggregation/expressions/split.js b/jstests/aggregation/expressions/split.js
index 3425e81ecc2..1c03a5c1e2a 100644
--- a/jstests/aggregation/expressions/split.js
+++ b/jstests/aggregation/expressions/split.js
@@ -6,7 +6,13 @@
load("jstests/aggregation/extras/utils.js"); // For assertErrorCode and testExpression.
load("jstests/libs/sbe_assert_error_override.js");
+load("jstests/libs/sbe_util.js"); // For checkSBEEnabled.
+// TODO SERVER-58095: When the classic engine is used, it will eagerly return null values, even
+// if some of its arguments are invalid. This is not the case when SBE is enabled because of the
+// order in which arguments are evaluated. In certain cases, errors will be thrown or the empty
+// string will be returned instead of null.
+const sbeEnabled = checkSBEEnabled(db);
const coll = db.split;
coll.drop();
@@ -42,8 +48,18 @@ testExpression(coll, {$split: [null, "abc"]}, null);
// Ensure that $split produces null when given missing fields as input.
testExpression(coll, {$split: ["$a", "a"]}, null);
testExpression(coll, {$split: ["a", "$a"]}, null);
+testExpression(coll, {$split: ["$a", null]}, null);
+testExpression(coll, {$split: [null, "$a"]}, null);
testExpression(coll, {$split: ["$missing", {$toLower: "$missing"}]}, null);
+// SBE expression translation will detect our empty string, whereas the classic engine will
+// detect that "$a" is missing and return null.
+if (sbeEnabled) {
+ testExpression(coll, {$split: ["", "$a"]}, [""]);
+} else {
+ testExpression(coll, {$split: ["", "$a"]}, null);
+}
+
//
// Error Code tests with constant-folding optimization.
//
@@ -78,4 +94,34 @@ pipeline = {
$project: {split: {$split: ["abc", ""]}}
};
assertErrorCode(coll, pipeline, 40087);
+
+const stringNumericArg = {
+ $split: [1, "$a"]
+};
+if (sbeEnabled) {
+ pipeline = {$project: {split: stringNumericArg}};
+ assertErrorCode(coll, pipeline, 40085);
+} else {
+ testExpression(coll, stringNumericArg, null);
+}
+
+const splitNumArg = {
+ $split: ["$b", 1]
+};
+if (sbeEnabled) {
+ pipeline = {$project: {split: splitNumArg}};
+ assertErrorCode(coll, pipeline, 40086);
+} else {
+ testExpression(coll, splitNumArg, null);
+}
+
+const emptyStringDelim = {
+ $split: ["$abc", ""]
+};
+if (sbeEnabled) {
+ pipeline = {$project: {split: emptyStringDelim}};
+ assertErrorCode(coll, pipeline, 40087);
+} else {
+ testExpression(coll, emptyStringDelim, null);
+}
})();
diff --git a/jstests/core/projection_expr_mod.js b/jstests/core/projection_expr_mod.js
index 859e99edf1b..64b7fb29757 100644
--- a/jstests/core/projection_expr_mod.js
+++ b/jstests/core/projection_expr_mod.js
@@ -72,6 +72,22 @@ error =
assert.throws(() => coll.find({}, {f: {$mod: ["$a", NumberLong(0)]}, _id: 0, n: 1}).toArray());
assert.commandFailedWithCode(error, 16610);
+// Confirm that $mod doesn't accept non-numeric input.
+error = assert.throws(
+ () => coll.find({}, {f: {$mod: ["$a", "don't accept strings!"]}, _id: 0, n: 1}).toArray());
+assert.commandFailedWithCode(error, 16611);
+
+error = assert.throws(
+ () => coll.find({}, {f: {$mod: ["don't accept strings!", "$a"]}, _id: 0, n: 1}).toArray());
+assert.commandFailedWithCode(error, 16611);
+
+error = assert.throws(() => coll.find({}, {f: {$mod: ["$a", [1, 2, 3]]}, _id: 0, n: 1}).toArray());
+assert.commandFailedWithCode(error, 16611);
+
+error = assert.throws(
+ () => coll.find({}, {f: {$mod: [{a: 1, b: 2, c: 3}, "$a"]}, _id: 0, n: 1}).toArray());
+assert.commandFailedWithCode(error, 16611);
+
// Clear collection again and reset.
assert(coll.drop());
assert.commandWorked(coll.insert({a: 10}));
@@ -83,4 +99,10 @@ assert.eq(coll.findOne({}, {f: {$mod: ["$a", -Infinity]}, _id: 0}), {f: 10});
assert.eq(coll.findOne({}, {f: {$mod: [Infinity, "$a"]}, _id: 0}), {f: NaN});
assert.eq(coll.findOne({}, {f: {$mod: [-Infinity, "$a"]}, _id: 0}), {f: NaN});
assert.eq(coll.findOne({}, {f: {$mod: [NaN, "$a"]}, _id: 0}), {f: NaN});
+
+// Confirm expected behavior for null and missing values.
+assert.eq(coll.findOne({}, {f: {$mod: ["$a", 2]}, _id: 0}), {f: 0});
+assert.eq(coll.findOne({}, {f: {$mod: [11, "$a"]}, _id: 0}), {f: 1});
+assert.eq(coll.findOne({}, {f: {$mod: [null, "$a"]}, _id: 0}), {f: null});
+assert.eq(coll.findOne({}, {f: {$mod: ["$a", null]}, _id: 0}), {f: null});
})();
diff --git a/jstests/libs/sbe_assert_error_override.js b/jstests/libs/sbe_assert_error_override.js
index 80eee699e44..65dc04f35d7 100644
--- a/jstests/libs/sbe_assert_error_override.js
+++ b/jstests/libs/sbe_assert_error_override.js
@@ -30,7 +30,7 @@ const equivalentErrorCodesList = [
[16608, 4848401],
[16609, 5073101],
[16610, 4848403],
- [16611, 5154000],
+ [16611, 5154000, 5412904, 5412905],
[16612, 4974202],
[16702, 5073001],
[28651, 5073201],
@@ -66,9 +66,9 @@ const equivalentErrorCodesList = [
[34448, 5154305],
[34449, 5154306],
[40066, 4934200],
- [40085, 5155402],
- [40086, 5155400],
- [40087, 5155401],
+ [40085, 5155402, 5412901],
+ [40086, 5155400, 5412902],
+ [40087, 5155401, 5412903],
[40091, 5075300],
[40092, 5075301, 5075302],
[40093, 5075300],
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 23674a2ef3c..f323f6b0f58 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -124,6 +124,14 @@ struct ExpressionVisitorContext {
return {std::move(expr), stageOrLimitCoScan(std::move(stage), planNodeId)};
}
+ // Function which pops 'n' arguments from the stack. This is useful when an expression can be
+ // simplified and its remaining arguments are no longer needed.
+ void clearExprs(size_t n) {
+ for (size_t idx = 0; idx < n; ++idx) {
+ [[maybe_unused]] auto expr = popExpr();
+ }
+ };
+
StageBuilderState& state;
EvalStack<> evalStack;
@@ -306,6 +314,32 @@ std::unique_ptr<sbe::EExpression> generateRegexNullResponse(StringData exprName)
return sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0);
}
+// Return true if 'expr' is nullish, false otherwise.
+bool compileTimeNullCheck(const ExpressionConstant& expr) {
+ return expr.getValue().nullish();
+}
+
+// Return true if 'expr' is a string, false otherwise.
+bool compileTimeStringCheck(const ExpressionConstant& expr) {
+ return expr.getValue().getType() == BSONType::String;
+}
+
+// Return true if 'expr' is an empty string, false otherwise. Callers must ensure 'expr' is a
+// string.
+bool compileTimeEmptyStringCheck(const ExpressionConstant& expr) {
+ return expr.getValue().getString().empty();
+}
+
+// Return true if 'expr' is either a number, false otherwise.
+bool compileTimeNumericCheck(const ExpressionConstant& expr) {
+ return isNumericBSONType(expr.getValue().getType());
+}
+
+// Return true if 'expr' is either a number or a date, false otherwise.
+bool compileTimeNumberDateCheck(const ExpressionConstant& expr) {
+ return compileTimeNumericCheck(expr) || expr.getValue().getType() == BSONType::Date;
+}
+
class ExpressionPreVisitor final : public ExpressionConstVisitor {
public:
ExpressionPreVisitor(ExpressionVisitorContext* context) : _context{context} {}
@@ -732,12 +766,40 @@ public:
_context->ensureArity(arity);
auto frameId = _context->state.frameId();
- auto generateNotNumberOrDate = [frameId](const sbe::value::SlotId slotId) {
- sbe::EVariable var{frameId, slotId};
+ auto generateNotNumberOrDateVar = [](const sbe::EVariable& var) {
return makeBinaryOp(sbe::EPrimBinary::logicAnd,
makeNot(makeFunction("isNumber", var.clone())),
makeNot(makeFunction("isDate", var.clone())));
};
+ auto generateNotNumberOrDate =
+ [frameId, generateNotNumberOrDateVar](const sbe::value::SlotId slotId) {
+ sbe::EVariable var{frameId, slotId};
+ return generateNotNumberOrDateVar(var);
+ };
+
+ const auto notNumberOrDateError =
+ "only numbers and dates are allowed in an $add expression";
+
+ auto makeNullAndTypeCases = [&](std::vector<std::unique_ptr<sbe::EExpression>> nullChecks,
+ std::vector<std::unique_ptr<sbe::EExpression>> typeChecks,
+ ErrorCodes::Error code) -> std::vector<CaseValuePair> {
+ auto checkNullAllArguments =
+ accumulateChecks(std::move(nullChecks), sbe::EPrimBinary::logicOr);
+ auto checkNotNumberOrDateAllArguments =
+ accumulateChecks(std::move(typeChecks), sbe::EPrimBinary::logicOr);
+ std::vector<CaseValuePair> cases;
+ cases.reserve(arity);
+ if (checkNullAllArguments) {
+ cases.emplace_back(std::move(checkNullAllArguments),
+ makeConstant(sbe::value::TypeTags::Null, 0));
+ }
+
+ if (checkNotNumberOrDateAllArguments) {
+ cases.emplace_back(std::move(checkNotNumberOrDateAllArguments),
+ sbe::makeE<sbe::EFail>(code, notNumberOrDateError));
+ }
+ return cases;
+ };
if (arity == 2) {
auto rhs = _context->popExpr();
@@ -746,42 +808,65 @@ public:
sbe::EVariable lhsVar{frameId, 0};
sbe::EVariable rhsVar{frameId, 1};
- auto addExpr = makeLocalBind(
- _context->state.frameIdGenerator,
- [&](sbe::EVariable lhsIsDate, sbe::EVariable rhsIsDate) {
- return buildMultiBranchConditional(
- CaseValuePair{makeBinaryOp(sbe::EPrimBinary::logicOr,
- generateNullOrMissing(frameId, 0),
- generateNullOrMissing(frameId, 1)),
- sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)},
- CaseValuePair{
- makeBinaryOp(sbe::EPrimBinary::logicOr,
- generateNotNumberOrDate(0),
- generateNotNumberOrDate(1)),
- sbe::makeE<sbe::EFail>(
- ErrorCodes::Error{4974201},
- "only numbers and dates are allowed in an $add expression")},
- CaseValuePair{
- makeBinaryOp(
- sbe::EPrimBinary::logicAnd, lhsIsDate.clone(), rhsIsDate.clone()),
- sbe::makeE<sbe::EFail>(ErrorCodes::Error{4974202},
- "only one date allowed in an $add expression")},
- // An EPrimBinary::add expression, which compiles directly into an "add"
- // instruction, efficiently handles the general case for for $add with
- // exactly two operands, but when one of the operands is a date, we need to
- // use the "doubleDoubleSum" function to perform the required conversions.
- CaseValuePair{
- makeBinaryOp(
- sbe::EPrimBinary::logicOr, lhsIsDate.clone(), rhsIsDate.clone()),
- makeFunction("doubleDoubleSum", lhsVar.clone(), rhsVar.clone())},
- makeBinaryOp(sbe::EPrimBinary::add, lhsVar.clone(), rhsVar.clone()));
- },
- makeFunction("isDate", lhsVar.clone()),
- makeFunction("isDate", rhsVar.clone()));
+ const auto& children = expr->getChildren();
+ auto lhsExpr = children[0];
+ auto rhsExpr = children[1];
+ std::vector<std::unique_ptr<sbe::EExpression>> checkExprsNull;
+ std::vector<std::unique_ptr<sbe::EExpression>> checkExprsNotNumberOrDate;
+
+ auto constCallback = [&](const ExpressionConstant& val) {
+ if (compileTimeNullCheck(val)) {
+ _context->pushExpr(makeConstant(sbe::value::TypeTags::Null, 0));
+ return true;
+ }
+ uassert(5412906, notNumberOrDateError, compileTimeNumberDateCheck(val));
+ return false;
+ };
+
+ auto nonConstCallback = [&](size_t idx) -> void {
+ checkExprsNull.push_back(generateNullOrMissing(frameId, idx));
+ checkExprsNotNumberOrDate.push_back(generateNotNumberOrDate(idx));
+ };
+ if (generateExpressionArgument(
+ lhsExpr.get(), constCallback, std::bind(nonConstCallback, 0)) ||
+ generateExpressionArgument(
+ rhsExpr.get(), constCallback, std::bind(nonConstCallback, 1))) {
+ return;
+ }
+
+ auto cases = makeNullAndTypeCases(std::move(checkExprsNull),
+ std::move(checkExprsNotNumberOrDate),
+ ErrorCodes::Error{4974201});
+
+ binds.emplace_back(makeFunction("isDate", lhsVar.clone()));
+ binds.emplace_back(makeFunction("isDate", rhsVar.clone()));
+ sbe::EVariable lhsIsDateVar{frameId, 2};
+ sbe::EVariable rhsIsDateVar{frameId, 3};
+
+ auto addExpr = [&] {
+ cases.emplace_back(CaseValuePair{
+ makeBinaryOp(
+ sbe::EPrimBinary::logicAnd, lhsIsDateVar.clone(), rhsIsDateVar.clone()),
+ sbe::makeE<sbe::EFail>(ErrorCodes::Error{4974202},
+ "only one date allowed in an $add expression")});
+
+ // An EPrimBinary::add expression, which compiles directly into an "add"
+ // instruction, efficiently handles the general case for for $add with
+ // exactly two operands, but when one of the operands is a date, we need to
+ // use the "doubleDoubleSum" function to perform the required conversions.
+ cases.emplace_back(CaseValuePair{
+ makeBinaryOp(
+ sbe::EPrimBinary::logicOr, lhsIsDateVar.clone(), rhsIsDateVar.clone()),
+ makeFunction("doubleDoubleSum", lhsVar.clone(), rhsVar.clone())});
+ return buildMultiBranchConditionalFromCaseValuePairs(
+ std::move(cases),
+ makeBinaryOp(sbe::EPrimBinary::add, lhsVar.clone(), rhsVar.clone()));
+ }();
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(addExpr)));
} else {
+ const auto& children = expr->getChildren();
std::vector<std::unique_ptr<sbe::EExpression>> binds;
std::vector<std::unique_ptr<sbe::EExpression>> argVars;
std::vector<std::unique_ptr<sbe::EExpression>> checkExprsNull;
@@ -790,44 +875,44 @@ public:
argVars.reserve(arity);
checkExprsNull.reserve(arity);
checkExprsNotNumberOrDate.reserve(arity);
- for (size_t idx = 0; idx < arity; ++idx) {
- binds.push_back(_context->popExpr());
- argVars.push_back(sbe::makeE<sbe::EVariable>(frameId, idx));
-
- checkExprsNull.push_back(generateNullOrMissing(frameId, idx));
- checkExprsNotNumberOrDate.push_back(generateNotNumberOrDate(idx));
+ size_t nonConstArgs = 0;
+ for (size_t idx = arity; idx > 0; --idx) {
+ if (generateExpressionArgument(
+ children[idx - 1].get(),
+ [&](const ExpressionConstant& val) {
+ if (compileTimeNullCheck(val)) {
+ _context->clearExprs(idx);
+ _context->pushExpr(makeConstant(sbe::value::TypeTags::Null, 0));
+ return true;
+ }
+ uassert(5412900, notNumberOrDateError, compileTimeNumberDateCheck(val));
+ argVars.push_back(_context->popExpr());
+ return false;
+ },
+ [&]() {
+ binds.push_back(_context->popExpr());
+ argVars.push_back(makeVariable(frameId, nonConstArgs));
+ checkExprsNull.push_back(generateNullOrMissing(frameId, nonConstArgs));
+ checkExprsNotNumberOrDate.push_back(
+ generateNotNumberOrDate(nonConstArgs));
+ ++nonConstArgs;
+ })) {
+ return;
+ }
}
- // At this point 'binds' vector contains arguments of $add expression in the reversed
+ // At this point 'argVars' vector contains arguments of $add expression in the reversed
// order. We need to reverse it back to perform summation in the right order below.
// Summation in different order can lead to different result because of accumulated
// precision errors from floating point types.
- std::reverse(std::begin(binds), std::end(binds));
-
- using iter_t = std::vector<std::unique_ptr<sbe::EExpression>>::iterator;
- auto checkNullAllArguments = std::accumulate(
- std::move_iterator<iter_t>(checkExprsNull.begin() + 1),
- std::move_iterator<iter_t>(checkExprsNull.end()),
- std::move(checkExprsNull.front()),
- [](auto&& acc, auto&& ex) {
- return makeBinaryOp(sbe::EPrimBinary::logicOr, std::move(acc), std::move(ex));
- });
- auto checkNotNumberOrDateAllArguments = std::accumulate(
- std::move_iterator<iter_t>(checkExprsNotNumberOrDate.begin() + 1),
- std::move_iterator<iter_t>(checkExprsNotNumberOrDate.end()),
- std::move(checkExprsNotNumberOrDate.front()),
- [](auto&& acc, auto&& ex) {
- return makeBinaryOp(sbe::EPrimBinary::logicOr, std::move(acc), std::move(ex));
- });
- auto addExpr = sbe::makeE<sbe::EIf>(
- std::move(checkNullAllArguments),
- sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0),
- sbe::makeE<sbe::EIf>(
- std::move(checkNotNumberOrDateAllArguments),
- sbe::makeE<sbe::EFail>(
- ErrorCodes::Error{4974203},
- "only numbers and dates are allowed in an $add expression"),
- sbe::makeE<sbe::EFunction>("doubleDoubleSum", std::move(argVars))));
+ std::reverse(std::begin(argVars), std::end(argVars));
+
+ auto cases = makeNullAndTypeCases(std::move(checkExprsNull),
+ std::move(checkExprsNotNumberOrDate),
+ ErrorCodes::Error{4974203});
+ auto addExpr = buildMultiBranchConditionalFromCaseValuePairs(
+ std::move(cases),
+ sbe::makeE<sbe::EFunction>("doubleDoubleSum", std::move(argVars)));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(addExpr)));
}
@@ -2104,6 +2189,11 @@ public:
auto frameId = _context->state.frameId();
auto rhs = _context->popExpr();
auto lhs = _context->popExpr();
+
+ const auto& children = expr->getChildren();
+ auto lhsPtr = children[0].get();
+ auto rhsPtr = children[1].get();
+
auto binds = sbe::makeEs(std::move(lhs), std::move(rhs));
sbe::EVariable lhsVar{frameId, 0};
sbe::EVariable rhsVar{frameId, 1};
@@ -2111,29 +2201,98 @@ public:
// If the rhs is a small integral double, convert it to int32 to match $mod MQL semantics.
auto numericConvert32 =
sbe::makeE<sbe::ENumericConvert>(rhsVar.clone(), sbe::value::TypeTags::NumberInt32);
- auto rhsExpr = buildMultiBranchConditional(
- CaseValuePair{
- makeBinaryOp(
- sbe::EPrimBinary::logicAnd,
- sbe::makeE<sbe::ETypeMatch>(
- rhsVar.clone(), getBSONTypeMask(sbe::value::TypeTags::NumberDouble)),
- makeNot(sbe::makeE<sbe::ETypeMatch>(
- lhsVar.clone(), getBSONTypeMask(sbe::value::TypeTags::NumberDouble)))),
- makeFunction("fillEmpty", std::move(numericConvert32), rhsVar.clone())},
- rhsVar.clone());
-
- auto modExpr = buildMultiBranchConditional(
- CaseValuePair{makeBinaryOp(sbe::EPrimBinary::logicOr,
- generateNullOrMissing(lhsVar),
- generateNullOrMissing(rhsVar)),
- sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)},
- CaseValuePair{makeBinaryOp(sbe::EPrimBinary::logicOr,
- generateNonNumericCheck(lhsVar),
- generateNonNumericCheck(rhsVar)),
- sbe::makeE<sbe::EFail>(ErrorCodes::Error{5154000},
- "$mod only supports numeric types")},
- makeFunction("mod", lhsVar.clone(), std::move(rhsExpr)));
+ const auto errorMsg = "$mod only supports numeric types"_sd;
+
+ std::vector<CaseValuePair> modCases;
+ std::vector<std::unique_ptr<sbe::EExpression>> nullChecks;
+ std::vector<std::unique_ptr<sbe::EExpression>> numericChecks;
+ std::vector<std::unique_ptr<sbe::EExpression>> numericConversionChecks;
+
+ bool shouldDoTypeConversion = true;
+ auto constCallback =
+ [&](const ExpressionConstant& constant, ErrorCodes::Error code, bool isLhs) {
+ if (compileTimeNullCheck(constant)) {
+ _context->pushExpr(makeConstant(sbe::value::TypeTags::Null, 0));
+ return true;
+ }
+ uassert(code, errorMsg, compileTimeNumericCheck(constant));
+ auto doubleCmp = constant.getValue().getType() == BSONType::NumberDouble;
+
+ // We are checking whether the LHS is NOT a double.
+ if (isLhs) {
+ doubleCmp = !doubleCmp;
+ }
+
+ shouldDoTypeConversion = shouldDoTypeConversion && doubleCmp;
+ return false;
+ };
+
+ auto nonConstCallback = [&](const sbe::EVariable& var, bool isLhs) {
+ nullChecks.emplace_back(generateNullOrMissing(var));
+ numericChecks.emplace_back(generateNonNumericCheck(var));
+ auto typeExpr = sbe::makeE<sbe::ETypeMatch>(
+ var.clone(), getBSONTypeMask(sbe::value::TypeTags::NumberDouble));
+
+ // We want to know if the LHS is NOT a double.
+ if (isLhs) {
+ typeExpr = makeNot(std::move(typeExpr));
+ }
+ numericConversionChecks.emplace_back(std::move(typeExpr));
+ };
+
+ using namespace std::placeholders;
+ if (generateExpressionArgument(
+ rhsPtr,
+ std::bind(constCallback, _1, ErrorCodes::Error{5412904}, false),
+ std::bind(nonConstCallback, std::cref(rhsVar), false)) ||
+ generateExpressionArgument(
+ lhsPtr,
+ std::bind(constCallback, _1, ErrorCodes::Error{5412905}, true),
+ std::bind(nonConstCallback, std::cref(lhsVar), true))) {
+ return;
+ }
+
+ auto rhsExpr = [&]() {
+ if (!shouldDoTypeConversion) {
+ return rhsVar.clone();
+ }
+
+ if (auto typeConversionCheck = accumulateChecks(std::move(numericConversionChecks),
+ sbe::EPrimBinary::logicAnd)) {
+ return buildMultiBranchConditional(
+ CaseValuePair{
+ std::move(typeConversionCheck),
+ makeFunction("fillEmpty", std::move(numericConvert32), rhsVar.clone())},
+ rhsVar.clone());
+ } else {
+ // Both values are constants, so perform the type conversion on rhs manually.
+ auto rhsConstPtr = dynamic_cast<ExpressionConstant*>(rhsPtr);
+ tassert(5412907, "Right hand side for $mod should be a constant", rhsConstPtr);
+ auto lhsConstPtr = dynamic_cast<ExpressionConstant*>(lhsPtr);
+ tassert(5412908, "Left hand side for $mod should be a constant", lhsConstPtr);
+ return lhsConstPtr->getValue().getType() == NumberLong
+ ? makeConstant(sbe::value::TypeTags::NumberInt64,
+ rhsConstPtr->getValue().coerceToLong())
+ : makeConstant(sbe::value::TypeTags::NumberInt32,
+ rhsConstPtr->getValue().coerceToInt());
+ }
+ }();
+
+ if (auto checkIsNullOrMissing =
+ accumulateChecks(std::move(nullChecks), sbe::EPrimBinary::logicOr)) {
+ modCases.emplace_back(std::move(checkIsNullOrMissing),
+ makeConstant(sbe::value::TypeTags::Null, 0));
+ }
+
+ if (auto checkNumeric =
+ accumulateChecks(std::move(numericChecks), sbe::EPrimBinary::logicOr)) {
+ modCases.emplace_back(std::move(checkNumeric),
+ sbe::makeE<sbe::EFail>(ErrorCodes::Error{5154000}, errorMsg));
+ }
+
+ auto modExpr = buildMultiBranchConditionalFromCaseValuePairs(
+ std::move(modCases), makeFunction("mod", lhsVar.clone(), std::move(rhsExpr)));
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(modExpr)));
}
@@ -2431,23 +2590,17 @@ public:
auto frameId = _context->state.frameId();
std::vector<std::unique_ptr<sbe::EExpression>> args;
std::vector<std::unique_ptr<sbe::EExpression>> binds;
- sbe::EVariable stringExpressionRef(frameId, 0);
- sbe::EVariable delimiterRef(frameId, 1);
- invariant(expr->getChildren().size() == 2);
+ auto children = expr->getChildren();
+ tassert(5412919, "$split must have exactly 2 children", children.size() == 2);
_context->ensureArity(2);
+ auto delimExpr = children[1].get();
+ auto stringExpr = children[0].get();
+
auto delimiter = _context->popExpr();
auto stringExpression = _context->popExpr();
- // Add stringExpression to arguments.
- binds.push_back(std::move(stringExpression));
- args.push_back(stringExpressionRef.clone());
-
- // Add delimiter to arguments.
- binds.push_back(std::move(delimiter));
- args.push_back(delimiterRef.clone());
-
auto [emptyStrTag, emptyStrVal] = sbe::value::makeNewString("");
auto [arrayWithEmptyStringTag, arrayWithEmptyStringVal] = sbe::value::makeNewArray();
sbe::value::ValueGuard arrayWithEmptyStringGuard{arrayWithEmptyStringTag,
@@ -2460,41 +2613,108 @@ public:
const sbe::EVariable& var) {
return makeBinaryOp(sbe::EPrimBinary::eq,
var.clone(),
- sbe::makeE<sbe::EConstant>(emptyStrTag, emptyStrVal),
+ makeConstant(emptyStrTag, emptyStrVal),
_context->state.env);
};
- auto checkIsNullOrMissing = makeBinaryOp(sbe::EPrimBinary::logicOr,
- generateNullOrMissing(stringExpressionRef),
- generateNullOrMissing(delimiterRef));
+ std::vector<CaseValuePair> cases;
+ std::vector<std::unique_ptr<sbe::EExpression>> nullChecks;
// In order to maintain MQL semantics, first check both the string expression
- // (first agument), and delimiter string (second argument) for null, undefined, or
+ // (first argument), and delimiter string (second argument) for null, undefined, or
// missing, and if either is nullish make the entire expression return null. Only
// then make further validity checks against the input. Fail if the delimiter is an empty
// string. Return [""] if the string expression is an empty string.
- auto totalSplitFunc = buildMultiBranchConditional(
- CaseValuePair{std::move(checkIsNullOrMissing),
- sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)},
- CaseValuePair{generateNonStringCheck(stringExpressionRef),
- sbe::makeE<sbe::EFail>(
- ErrorCodes::Error{5155402},
- str::stream() << "$split string expression must be a string")},
- CaseValuePair{
+ sbe::value::SlotId varId = 0;
+ const auto splitStringError = "$split string expression must be a string";
+ auto splitStringConstCallback = [&,
+ arrayWithEmptyStringTag = arrayWithEmptyStringTag,
+ arrayWithEmptyStringVal = arrayWithEmptyStringVal](
+ const ExpressionConstant& val) {
+ if (compileTimeNullCheck(val)) {
+ _context->pushExpr(makeConstant(sbe::value::TypeTags::Null, 0));
+ return true;
+ }
+ uassert(5412901, splitStringError, compileTimeStringCheck(val));
+ if (compileTimeEmptyStringCheck(val)) {
+ _context->pushExpr(makeConstant(arrayWithEmptyStringTag, arrayWithEmptyStringVal));
+ return true;
+ }
+ args.push_back(std::move(stringExpression));
+ return false;
+ };
+
+ auto splitStringNonConstCallback = [&,
+ arrayWithEmptyStringTag = arrayWithEmptyStringTag,
+ arrayWithEmptyStringVal = arrayWithEmptyStringVal]() {
+ sbe::EVariable stringExpressionRef(frameId, varId);
+ varId++;
+
+ // Add string expression to arguments.
+ binds.push_back(std::move(stringExpression));
+ args.push_back(stringExpressionRef.clone());
+
+ // Add error checks.
+ nullChecks.push_back(generateNullOrMissing(stringExpressionRef));
+ cases.emplace_back(
+ generateNonStringCheck(stringExpressionRef),
+ sbe::makeE<sbe::EFail>(ErrorCodes::Error{5155402}, splitStringError));
+
+ // Empty string check.
+ cases.emplace_back(generateIsEmptyString(stringExpressionRef),
+ makeConstant(arrayWithEmptyStringTag, arrayWithEmptyStringVal));
+ };
+
+ if (generateExpressionArgument(
+ stringExpr, splitStringConstCallback, splitStringNonConstCallback)) {
+ return;
+ }
+
+ const auto delimStringError = "$split delimiter must be a string";
+ const auto delimNonEmptyStringError = "$split delimiter must not be an empty string";
+ auto delimConstCallback = [&](const ExpressionConstant& val) {
+ if (compileTimeNullCheck(val)) {
+ _context->pushExpr(makeConstant(sbe::value::TypeTags::Null, 0));
+ return true;
+ }
+ uassert(5412902, delimStringError, compileTimeStringCheck(val));
+ uassert(5412903, delimNonEmptyStringError, !compileTimeEmptyStringCheck(val));
+ args.push_back(std::move(delimiter));
+ return false;
+ };
+
+ auto delimNonConstCallback = [&]() {
+ sbe::EVariable delimiterRef(frameId, varId);
+
+ // Add delimiter to arguments.
+ binds.push_back(std::move(delimiter));
+ args.push_back(delimiterRef.clone());
+
+ // Add error checks.
+ nullChecks.push_back(generateNullOrMissing(delimiterRef));
+ cases.emplace_back(
generateNonStringCheck(delimiterRef),
- sbe::makeE<sbe::EFail>(ErrorCodes::Error{5155400},
- str::stream() << "$split delimiter must be a string")},
- CaseValuePair{generateIsEmptyString(delimiterRef),
- sbe::makeE<sbe::EFail>(
- ErrorCodes::Error{5155401},
- str::stream() << "$split delimiter must not be an empty string")},
- sbe::makeE<sbe::EIf>(
- generateIsEmptyString(stringExpressionRef),
- sbe::makeE<sbe::EConstant>(arrayWithEmptyStringTag, arrayWithEmptyStringVal),
- sbe::makeE<sbe::EFunction>("split", std::move(args))));
+ sbe::makeE<sbe::EFail>(ErrorCodes::Error{5155400}, delimStringError));
+ cases.emplace_back(
+ generateIsEmptyString(delimiterRef),
+ sbe::makeE<sbe::EFail>(ErrorCodes::Error{5155401}, delimNonEmptyStringError));
+ };
+
+ if (generateExpressionArgument(delimExpr, delimConstCallback, delimNonConstCallback)) {
+ return;
+ }
+
+ if (auto checkIsNullOrMissing =
+ accumulateChecks(std::move(nullChecks), sbe::EPrimBinary::logicOr)) {
+ cases.emplace(cases.begin(),
+ std::move(checkIsNullOrMissing),
+ makeConstant(sbe::value::TypeTags::Null, 0));
+ }
+ auto splitFunc = buildMultiBranchConditionalFromCaseValuePairs(
+ std::move(cases), sbe::makeE<sbe::EFunction>("split", std::move(args)));
_context->pushExpr(
- sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(totalSplitFunc)));
+ sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(splitFunc)));
}
void visit(const ExpressionSqrt* expr) final {
auto frameId = _context->state.frameId();
@@ -3606,6 +3826,24 @@ private:
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(dateAddExpr)));
}
+
+ /**
+ * Generates an expression argument depending on whether 'expr' is a constant or not. Returns
+ * 'true' if 'constCallback' generates a terminal value and no further generation is required
+ * (for instance, if 'expr' is determined to be null, we can simply push 'null'), false
+ * otherwise.
+ */
+ bool generateExpressionArgument(
+ const Expression* expr,
+ const std::function<bool(const ExpressionConstant& val)>& constCallback,
+ const std::function<void()>& nonConstCallback) {
+ if (auto constExpr = dynamic_cast<const ExpressionConstant*>(expr)) {
+ return constCallback(*constExpr);
+ }
+ nonConstCallback();
+ return false;
+ }
+
void unsupportedExpression(const char* op) const {
// We're guaranteed to not fire this assertion by implementing a mechanism in the upper
// layer which directs the query to the classic engine when an unsupported expression
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
index add7be68dee..d4bc4755295 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
@@ -183,6 +183,19 @@ std::unique_ptr<sbe::EExpression> buildMultiBranchConditional(
return defaultCase;
}
+std::unique_ptr<sbe::EExpression> accumulateChecks(
+ std::vector<std::unique_ptr<sbe::EExpression>> checks, sbe::EPrimBinary::Op op) {
+ using iter_t = std::vector<std::unique_ptr<sbe::EExpression>>::iterator;
+ if (checks.empty()) {
+ return nullptr;
+ }
+ return std::accumulate(
+ std::move_iterator<iter_t>(checks.begin() + 1),
+ std::move_iterator<iter_t>(checks.end()),
+ std::move(checks.front()),
+ [&op](auto&& acc, auto&& ex) { return makeBinaryOp(op, std::move(acc), std::move(ex)); });
+}
+
std::unique_ptr<sbe::EExpression> buildMultiBranchConditionalFromCaseValuePairs(
std::vector<CaseValuePair> caseValuePairs, std::unique_ptr<sbe::EExpression> defaultValue) {
return std::accumulate(
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h
index f79ff0911c9..005689741a9 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.h
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.h
@@ -165,6 +165,12 @@ std::unique_ptr<sbe::EExpression> buildMultiBranchConditionalFromCaseValuePairs(
std::vector<CaseValuePair> caseValuePairs, std::unique_ptr<sbe::EExpression> defaultValue);
/**
+ * Given a vector of 'checks', all of which return true/false, accumulate them using 'op'.
+ */
+std::unique_ptr<sbe::EExpression> accumulateChecks(
+ std::vector<std::unique_ptr<sbe::EExpression>> checks, sbe::EPrimBinary::Op op);
+
+/**
* Insert a limit stage on top of the 'input' stage.
*/
std::unique_ptr<sbe::PlanStage> makeLimitTree(std::unique_ptr<sbe::PlanStage> inputStage,