diff options
17 files changed, 494 insertions, 128 deletions
diff --git a/buildscripts/resmokeconfig/suites/multi_shard_local_read_write_multi_stmt_txn_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/multi_shard_local_read_write_multi_stmt_txn_jscore_passthrough.yml index aa9e84dbc1b..098952c993e 100644 --- a/buildscripts/resmokeconfig/suites/multi_shard_local_read_write_multi_stmt_txn_jscore_passthrough.yml +++ b/buildscripts/resmokeconfig/suites/multi_shard_local_read_write_multi_stmt_txn_jscore_passthrough.yml @@ -159,6 +159,7 @@ selector: - jstests/core/updatel.js - jstests/core/write_result.js - jstests/core/positional_projection.js + - jstests/core/projection_expr_mod.js # Trick for bypassing mongo shell validation in the test doesn't work because txn_override # retry logic will hit the shell validation. diff --git a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_jscore_passthrough.yml index a2d14686ceb..649bc2b3891 100644 --- a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_jscore_passthrough.yml +++ b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_jscore_passthrough.yml @@ -174,6 +174,7 @@ selector: - jstests/core/updatel.js - jstests/core/write_result.js - jstests/core/positional_projection.js + - jstests/core/projection_expr_mod.js # Trick for bypassing mongo shell validation in the test doesn't work because txn_override # retry logic will hit the shell validation. 
diff --git a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_kill_primary_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_kill_primary_jscore_passthrough.yml index b61aed35737..afc436e4803 100644 --- a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_kill_primary_jscore_passthrough.yml +++ b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_kill_primary_jscore_passthrough.yml @@ -169,6 +169,7 @@ selector: - jstests/core/updatel.js - jstests/core/write_result.js - jstests/core/positional_projection.js + - jstests/core/projection_expr_mod.js # Trick for bypassing mongo shell validation in the test doesn't work because txn_override # retry logic will hit the shell validation. diff --git a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_stepdown_primary_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_stepdown_primary_jscore_passthrough.yml index 4fe4a9b5343..5846ec7162a 100644 --- a/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_stepdown_primary_jscore_passthrough.yml +++ b/buildscripts/resmokeconfig/suites/multi_shard_multi_stmt_txn_stepdown_primary_jscore_passthrough.yml @@ -170,6 +170,7 @@ selector: - jstests/core/updatel.js - jstests/core/write_result.js - jstests/core/positional_projection.js + - jstests/core/projection_expr_mod.js # Trick for bypassing mongo shell validation in the test doesn't work because txn_override # retry logic will hit the shell validation. 
diff --git a/buildscripts/resmokeconfig/suites/multi_stmt_txn_jscore_passthrough_with_migration.yml b/buildscripts/resmokeconfig/suites/multi_stmt_txn_jscore_passthrough_with_migration.yml index 1cb8898a421..623cbe2d10b 100644 --- a/buildscripts/resmokeconfig/suites/multi_stmt_txn_jscore_passthrough_with_migration.yml +++ b/buildscripts/resmokeconfig/suites/multi_stmt_txn_jscore_passthrough_with_migration.yml @@ -182,6 +182,7 @@ selector: - jstests/core/updatel.js - jstests/core/write_result.js - jstests/core/positional_projection.js + - jstests/core/projection_expr_mod.js # Trick for bypassing mongo shell validation in the test doesn't work because txn_override # retry logic will hit the shell validation. diff --git a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_jscore_passthrough.yml index 19b3237d99b..28276078610 100644 --- a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_jscore_passthrough.yml +++ b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_jscore_passthrough.yml @@ -122,6 +122,7 @@ selector: - jstests/core/updatel.js - jstests/core/write_result.js - jstests/core/positional_projection.js + - jstests/core/projection_expr_mod.js # Trick for bypassing mongo shell validation in the test doesn't work because txn_override # retry logic will hit the shell validation. 
diff --git a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_kill_primary_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_kill_primary_jscore_passthrough.yml index 658df1cca30..519acfb6f2e 100644 --- a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_kill_primary_jscore_passthrough.yml +++ b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_kill_primary_jscore_passthrough.yml @@ -110,6 +110,7 @@ selector: - jstests/core/updatel.js - jstests/core/write_result.js - jstests/core/positional_projection.js + - jstests/core/projection_expr_mod.js # Trick for bypassing mongo shell validation in the test doesn't work because txn_override # retry logic will hit the shell validation. diff --git a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_stepdown_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_stepdown_jscore_passthrough.yml index 13227f2f413..25b4e55b041 100644 --- a/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_stepdown_jscore_passthrough.yml +++ b/buildscripts/resmokeconfig/suites/replica_sets_multi_stmt_txn_stepdown_jscore_passthrough.yml @@ -109,6 +109,7 @@ selector: - jstests/core/updatel.js - jstests/core/write_result.js - jstests/core/positional_projection.js + - jstests/core/projection_expr_mod.js # Trick for bypassing mongo shell validation in the test doesn't work because txn_override # retry logic will hit the shell validation. 
diff --git a/buildscripts/resmokeconfig/suites/sharded_multi_stmt_txn_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/sharded_multi_stmt_txn_jscore_passthrough.yml index 10436896401..812df6897f8 100644 --- a/buildscripts/resmokeconfig/suites/sharded_multi_stmt_txn_jscore_passthrough.yml +++ b/buildscripts/resmokeconfig/suites/sharded_multi_stmt_txn_jscore_passthrough.yml @@ -146,6 +146,7 @@ selector: - jstests/core/updatel.js - jstests/core/write_result.js - jstests/core/positional_projection.js + - jstests/core/projection_expr_mod.js # Trick for bypassing mongo shell validation in the test doesn't work because txn_override # retry logic will hit the shell validation. diff --git a/buildscripts/resmokeconfig/suites/tenant_migration_multi_stmt_txn_jscore_passthrough.yml b/buildscripts/resmokeconfig/suites/tenant_migration_multi_stmt_txn_jscore_passthrough.yml index 5c813fd7a53..86761f02edf 100644 --- a/buildscripts/resmokeconfig/suites/tenant_migration_multi_stmt_txn_jscore_passthrough.yml +++ b/buildscripts/resmokeconfig/suites/tenant_migration_multi_stmt_txn_jscore_passthrough.yml @@ -227,6 +227,7 @@ selector: - jstests/core/field_name_validation.js - jstests/core/insert_illegal_doc.js - jstests/core/positional_projection.js + - jstests/core/projection_expr_mod.js - jstests/core/push_sort.js - jstests/core/update_dbref.js diff --git a/jstests/aggregation/expressions/add.js b/jstests/aggregation/expressions/add.js new file mode 100644 index 00000000000..cb9b146c84f --- /dev/null +++ b/jstests/aggregation/expressions/add.js @@ -0,0 +1,31 @@ +// Confirm correctness of $add evaluation in find projection. +(function() { +"use strict"; + +load("jstests/aggregation/extras/utils.js"); // For assertArrayEq. +load("jstests/libs/sbe_util.js"); // For checkSBEEnabled. + +const isSBEEnabled = checkSBEEnabled(db); +if (isSBEEnabled) { + // Override error-code-checking APIs. 
We only load this when SBE is explicitly enabled, because + // it causes failures in the parallel suites. + load("jstests/libs/sbe_assert_error_override.js"); +} + +const coll = db.expression_add; +coll.drop(); +assert.commandWorked(coll.insert({a: NumberInt(2), b: NumberLong(3), c: 3.5})); + +const testCases = [ + [["$a", "$b", "$c"], 8.5], + [["$a", "$b", null], null], + [["$a", "$b", 5], 10], + [[5, "$a", "$b"], 10], + [["$a", 5, "$b"], 10], + [["$a", 20, "$c", 10], 35.5], +]; +for (const testCase of testCases) { + const [addExpr, expected] = testCase; + assert.eq(coll.findOne({}, {sum: {$add: addExpr}, _id: 0}), {sum: expected}, testCase); +} +})(); diff --git a/jstests/aggregation/expressions/split.js b/jstests/aggregation/expressions/split.js index 3425e81ecc2..1c03a5c1e2a 100644 --- a/jstests/aggregation/expressions/split.js +++ b/jstests/aggregation/expressions/split.js @@ -6,7 +6,13 @@ load("jstests/aggregation/extras/utils.js"); // For assertErrorCode and testExpression. load("jstests/libs/sbe_assert_error_override.js"); +load("jstests/libs/sbe_util.js"); // For checkSBEEnabled. +// TODO SERVER-58095: When the classic engine is used, it will eagerly return null values, even +// if some of its arguments are invalid. This is not the case when SBE is enabled because of the +// order in which arguments are evaluated. In certain cases, errors will be thrown or the empty +// string will be returned instead of null. +const sbeEnabled = checkSBEEnabled(db); const coll = db.split; coll.drop(); @@ -42,8 +48,18 @@ testExpression(coll, {$split: [null, "abc"]}, null); // Ensure that $split produces null when given missing fields as input. 
testExpression(coll, {$split: ["$a", "a"]}, null); testExpression(coll, {$split: ["a", "$a"]}, null); +testExpression(coll, {$split: ["$a", null]}, null); +testExpression(coll, {$split: [null, "$a"]}, null); testExpression(coll, {$split: ["$missing", {$toLower: "$missing"}]}, null); +// SBE expression translation will detect our empty string, whereas the classic engine will +// detect that "$a" is missing and return null. +if (sbeEnabled) { + testExpression(coll, {$split: ["", "$a"]}, [""]); +} else { + testExpression(coll, {$split: ["", "$a"]}, null); +} + // // Error Code tests with constant-folding optimization. // @@ -78,4 +94,34 @@ pipeline = { $project: {split: {$split: ["abc", ""]}} }; assertErrorCode(coll, pipeline, 40087); + +const stringNumericArg = { + $split: [1, "$a"] +}; +if (sbeEnabled) { + pipeline = {$project: {split: stringNumericArg}}; + assertErrorCode(coll, pipeline, 40085); +} else { + testExpression(coll, stringNumericArg, null); +} + +const splitNumArg = { + $split: ["$b", 1] +}; +if (sbeEnabled) { + pipeline = {$project: {split: splitNumArg}}; + assertErrorCode(coll, pipeline, 40086); +} else { + testExpression(coll, splitNumArg, null); +} + +const emptyStringDelim = { + $split: ["$abc", ""] +}; +if (sbeEnabled) { + pipeline = {$project: {split: emptyStringDelim}}; + assertErrorCode(coll, pipeline, 40087); +} else { + testExpression(coll, emptyStringDelim, null); +} })(); diff --git a/jstests/core/projection_expr_mod.js b/jstests/core/projection_expr_mod.js index 859e99edf1b..64b7fb29757 100644 --- a/jstests/core/projection_expr_mod.js +++ b/jstests/core/projection_expr_mod.js @@ -72,6 +72,22 @@ error = assert.throws(() => coll.find({}, {f: {$mod: ["$a", NumberLong(0)]}, _id: 0, n: 1}).toArray()); assert.commandFailedWithCode(error, 16610); +// Confirm that $mod doesn't accept non-numeric input. 
+error = assert.throws( + () => coll.find({}, {f: {$mod: ["$a", "don't accept strings!"]}, _id: 0, n: 1}).toArray()); +assert.commandFailedWithCode(error, 16611); + +error = assert.throws( + () => coll.find({}, {f: {$mod: ["don't accept strings!", "$a"]}, _id: 0, n: 1}).toArray()); +assert.commandFailedWithCode(error, 16611); + +error = assert.throws(() => coll.find({}, {f: {$mod: ["$a", [1, 2, 3]]}, _id: 0, n: 1}).toArray()); +assert.commandFailedWithCode(error, 16611); + +error = assert.throws( + () => coll.find({}, {f: {$mod: [{a: 1, b: 2, c: 3}, "$a"]}, _id: 0, n: 1}).toArray()); +assert.commandFailedWithCode(error, 16611); + // Clear collection again and reset. assert(coll.drop()); assert.commandWorked(coll.insert({a: 10})); @@ -83,4 +99,10 @@ assert.eq(coll.findOne({}, {f: {$mod: ["$a", -Infinity]}, _id: 0}), {f: 10}); assert.eq(coll.findOne({}, {f: {$mod: [Infinity, "$a"]}, _id: 0}), {f: NaN}); assert.eq(coll.findOne({}, {f: {$mod: [-Infinity, "$a"]}, _id: 0}), {f: NaN}); assert.eq(coll.findOne({}, {f: {$mod: [NaN, "$a"]}, _id: 0}), {f: NaN}); + +// Confirm expected behavior for null and missing values. 
+assert.eq(coll.findOne({}, {f: {$mod: ["$a", 2]}, _id: 0}), {f: 0}); +assert.eq(coll.findOne({}, {f: {$mod: [11, "$a"]}, _id: 0}), {f: 1}); +assert.eq(coll.findOne({}, {f: {$mod: [null, "$a"]}, _id: 0}), {f: null}); +assert.eq(coll.findOne({}, {f: {$mod: ["$a", null]}, _id: 0}), {f: null}); })(); diff --git a/jstests/libs/sbe_assert_error_override.js b/jstests/libs/sbe_assert_error_override.js index 80eee699e44..65dc04f35d7 100644 --- a/jstests/libs/sbe_assert_error_override.js +++ b/jstests/libs/sbe_assert_error_override.js @@ -30,7 +30,7 @@ const equivalentErrorCodesList = [ [16608, 4848401], [16609, 5073101], [16610, 4848403], - [16611, 5154000], + [16611, 5154000, 5412904, 5412905], [16612, 4974202], [16702, 5073001], [28651, 5073201], @@ -66,9 +66,9 @@ const equivalentErrorCodesList = [ [34448, 5154305], [34449, 5154306], [40066, 4934200], - [40085, 5155402], - [40086, 5155400], - [40087, 5155401], + [40085, 5155402, 5412901], + [40086, 5155400, 5412902], + [40087, 5155401, 5412903], [40091, 5075300], [40092, 5075301, 5075302], [40093, 5075300], diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 23674a2ef3c..f323f6b0f58 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -124,6 +124,14 @@ struct ExpressionVisitorContext { return {std::move(expr), stageOrLimitCoScan(std::move(stage), planNodeId)}; } + // Function which pops 'n' arguments from the stack. This is useful when an expression can be + // simplified and its remaining arguments are no longer needed. 
+ void clearExprs(size_t n) { + for (size_t idx = 0; idx < n; ++idx) { + [[maybe_unused]] auto expr = popExpr(); + } + }; + StageBuilderState& state; EvalStack<> evalStack; @@ -306,6 +314,32 @@ std::unique_ptr<sbe::EExpression> generateRegexNullResponse(StringData exprName) return sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0); } +// Return true if 'expr' is nullish, false otherwise. +bool compileTimeNullCheck(const ExpressionConstant& expr) { + return expr.getValue().nullish(); +} + +// Return true if 'expr' is a string, false otherwise. +bool compileTimeStringCheck(const ExpressionConstant& expr) { + return expr.getValue().getType() == BSONType::String; +} + +// Return true if 'expr' is an empty string, false otherwise. Callers must ensure 'expr' is a +// string. +bool compileTimeEmptyStringCheck(const ExpressionConstant& expr) { + return expr.getValue().getString().empty(); +} + +// Return true if 'expr' is a number, false otherwise. +bool compileTimeNumericCheck(const ExpressionConstant& expr) { + return isNumericBSONType(expr.getValue().getType()); +} + +// Return true if 'expr' is either a number or a date, false otherwise. 
+bool compileTimeNumberDateCheck(const ExpressionConstant& expr) { + return compileTimeNumericCheck(expr) || expr.getValue().getType() == BSONType::Date; +} + class ExpressionPreVisitor final : public ExpressionConstVisitor { public: ExpressionPreVisitor(ExpressionVisitorContext* context) : _context{context} {} @@ -732,12 +766,40 @@ public: _context->ensureArity(arity); auto frameId = _context->state.frameId(); - auto generateNotNumberOrDate = [frameId](const sbe::value::SlotId slotId) { - sbe::EVariable var{frameId, slotId}; + auto generateNotNumberOrDateVar = [](const sbe::EVariable& var) { return makeBinaryOp(sbe::EPrimBinary::logicAnd, makeNot(makeFunction("isNumber", var.clone())), makeNot(makeFunction("isDate", var.clone()))); }; + auto generateNotNumberOrDate = + [frameId, generateNotNumberOrDateVar](const sbe::value::SlotId slotId) { + sbe::EVariable var{frameId, slotId}; + return generateNotNumberOrDateVar(var); + }; + + const auto notNumberOrDateError = + "only numbers and dates are allowed in an $add expression"; + + auto makeNullAndTypeCases = [&](std::vector<std::unique_ptr<sbe::EExpression>> nullChecks, + std::vector<std::unique_ptr<sbe::EExpression>> typeChecks, + ErrorCodes::Error code) -> std::vector<CaseValuePair> { + auto checkNullAllArguments = + accumulateChecks(std::move(nullChecks), sbe::EPrimBinary::logicOr); + auto checkNotNumberOrDateAllArguments = + accumulateChecks(std::move(typeChecks), sbe::EPrimBinary::logicOr); + std::vector<CaseValuePair> cases; + cases.reserve(arity); + if (checkNullAllArguments) { + cases.emplace_back(std::move(checkNullAllArguments), + makeConstant(sbe::value::TypeTags::Null, 0)); + } + + if (checkNotNumberOrDateAllArguments) { + cases.emplace_back(std::move(checkNotNumberOrDateAllArguments), + sbe::makeE<sbe::EFail>(code, notNumberOrDateError)); + } + return cases; + }; if (arity == 2) { auto rhs = _context->popExpr(); @@ -746,42 +808,65 @@ public: sbe::EVariable lhsVar{frameId, 0}; sbe::EVariable 
rhsVar{frameId, 1}; - auto addExpr = makeLocalBind( - _context->state.frameIdGenerator, - [&](sbe::EVariable lhsIsDate, sbe::EVariable rhsIsDate) { - return buildMultiBranchConditional( - CaseValuePair{makeBinaryOp(sbe::EPrimBinary::logicOr, - generateNullOrMissing(frameId, 0), - generateNullOrMissing(frameId, 1)), - sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)}, - CaseValuePair{ - makeBinaryOp(sbe::EPrimBinary::logicOr, - generateNotNumberOrDate(0), - generateNotNumberOrDate(1)), - sbe::makeE<sbe::EFail>( - ErrorCodes::Error{4974201}, - "only numbers and dates are allowed in an $add expression")}, - CaseValuePair{ - makeBinaryOp( - sbe::EPrimBinary::logicAnd, lhsIsDate.clone(), rhsIsDate.clone()), - sbe::makeE<sbe::EFail>(ErrorCodes::Error{4974202}, - "only one date allowed in an $add expression")}, - // An EPrimBinary::add expression, which compiles directly into an "add" - // instruction, efficiently handles the general case for for $add with - // exactly two operands, but when one of the operands is a date, we need to - // use the "doubleDoubleSum" function to perform the required conversions. 
- CaseValuePair{ - makeBinaryOp( - sbe::EPrimBinary::logicOr, lhsIsDate.clone(), rhsIsDate.clone()), - makeFunction("doubleDoubleSum", lhsVar.clone(), rhsVar.clone())}, - makeBinaryOp(sbe::EPrimBinary::add, lhsVar.clone(), rhsVar.clone())); - }, - makeFunction("isDate", lhsVar.clone()), - makeFunction("isDate", rhsVar.clone())); + const auto& children = expr->getChildren(); + auto lhsExpr = children[0]; + auto rhsExpr = children[1]; + std::vector<std::unique_ptr<sbe::EExpression>> checkExprsNull; + std::vector<std::unique_ptr<sbe::EExpression>> checkExprsNotNumberOrDate; + + auto constCallback = [&](const ExpressionConstant& val) { + if (compileTimeNullCheck(val)) { + _context->pushExpr(makeConstant(sbe::value::TypeTags::Null, 0)); + return true; + } + uassert(5412906, notNumberOrDateError, compileTimeNumberDateCheck(val)); + return false; + }; + + auto nonConstCallback = [&](size_t idx) -> void { + checkExprsNull.push_back(generateNullOrMissing(frameId, idx)); + checkExprsNotNumberOrDate.push_back(generateNotNumberOrDate(idx)); + }; + if (generateExpressionArgument( + lhsExpr.get(), constCallback, std::bind(nonConstCallback, 0)) || + generateExpressionArgument( + rhsExpr.get(), constCallback, std::bind(nonConstCallback, 1))) { + return; + } + + auto cases = makeNullAndTypeCases(std::move(checkExprsNull), + std::move(checkExprsNotNumberOrDate), + ErrorCodes::Error{4974201}); + + binds.emplace_back(makeFunction("isDate", lhsVar.clone())); + binds.emplace_back(makeFunction("isDate", rhsVar.clone())); + sbe::EVariable lhsIsDateVar{frameId, 2}; + sbe::EVariable rhsIsDateVar{frameId, 3}; + + auto addExpr = [&] { + cases.emplace_back(CaseValuePair{ + makeBinaryOp( + sbe::EPrimBinary::logicAnd, lhsIsDateVar.clone(), rhsIsDateVar.clone()), + sbe::makeE<sbe::EFail>(ErrorCodes::Error{4974202}, + "only one date allowed in an $add expression")}); + + // An EPrimBinary::add expression, which compiles directly into an "add" + // instruction, efficiently handles the general case 
for $add with + // exactly two operands, but when one of the operands is a date, we need to + // use the "doubleDoubleSum" function to perform the required conversions. + cases.emplace_back(CaseValuePair{ + makeBinaryOp( + sbe::EPrimBinary::logicOr, lhsIsDateVar.clone(), rhsIsDateVar.clone()), + makeFunction("doubleDoubleSum", lhsVar.clone(), rhsVar.clone())}); + return buildMultiBranchConditionalFromCaseValuePairs( + std::move(cases), + makeBinaryOp(sbe::EPrimBinary::add, lhsVar.clone(), rhsVar.clone())); + }(); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(addExpr))); } else { + const auto& children = expr->getChildren(); std::vector<std::unique_ptr<sbe::EExpression>> binds; std::vector<std::unique_ptr<sbe::EExpression>> argVars; std::vector<std::unique_ptr<sbe::EExpression>> checkExprsNull; @@ -790,44 +875,44 @@ public: argVars.reserve(arity); checkExprsNull.reserve(arity); checkExprsNotNumberOrDate.reserve(arity); - for (size_t idx = 0; idx < arity; ++idx) { - binds.push_back(_context->popExpr()); - argVars.push_back(sbe::makeE<sbe::EVariable>(frameId, idx)); - - checkExprsNull.push_back(generateNullOrMissing(frameId, idx)); - checkExprsNotNumberOrDate.push_back(generateNotNumberOrDate(idx)); + size_t nonConstArgs = 0; + for (size_t idx = arity; idx > 0; --idx) { + if (generateExpressionArgument( + children[idx - 1].get(), + [&](const ExpressionConstant& val) { + if (compileTimeNullCheck(val)) { + _context->clearExprs(idx); + _context->pushExpr(makeConstant(sbe::value::TypeTags::Null, 0)); + return true; + } + uassert(5412900, notNumberOrDateError, compileTimeNumberDateCheck(val)); + argVars.push_back(_context->popExpr()); + return false; + }, + [&]() { + binds.push_back(_context->popExpr()); + argVars.push_back(makeVariable(frameId, nonConstArgs)); + checkExprsNull.push_back(generateNullOrMissing(frameId, nonConstArgs)); + checkExprsNotNumberOrDate.push_back( + generateNotNumberOrDate(nonConstArgs)); + ++nonConstArgs; 
+ })) { + return; + } } - // At this point 'binds' vector contains arguments of $add expression in the reversed + // At this point 'argVars' vector contains arguments of $add expression in the reversed // order. We need to reverse it back to perform summation in the right order below. // Summation in different order can lead to different result because of accumulated // precision errors from floating point types. - std::reverse(std::begin(binds), std::end(binds)); - - using iter_t = std::vector<std::unique_ptr<sbe::EExpression>>::iterator; - auto checkNullAllArguments = std::accumulate( - std::move_iterator<iter_t>(checkExprsNull.begin() + 1), - std::move_iterator<iter_t>(checkExprsNull.end()), - std::move(checkExprsNull.front()), - [](auto&& acc, auto&& ex) { - return makeBinaryOp(sbe::EPrimBinary::logicOr, std::move(acc), std::move(ex)); - }); - auto checkNotNumberOrDateAllArguments = std::accumulate( - std::move_iterator<iter_t>(checkExprsNotNumberOrDate.begin() + 1), - std::move_iterator<iter_t>(checkExprsNotNumberOrDate.end()), - std::move(checkExprsNotNumberOrDate.front()), - [](auto&& acc, auto&& ex) { - return makeBinaryOp(sbe::EPrimBinary::logicOr, std::move(acc), std::move(ex)); - }); - auto addExpr = sbe::makeE<sbe::EIf>( - std::move(checkNullAllArguments), - sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0), - sbe::makeE<sbe::EIf>( - std::move(checkNotNumberOrDateAllArguments), - sbe::makeE<sbe::EFail>( - ErrorCodes::Error{4974203}, - "only numbers and dates are allowed in an $add expression"), - sbe::makeE<sbe::EFunction>("doubleDoubleSum", std::move(argVars)))); + std::reverse(std::begin(argVars), std::end(argVars)); + + auto cases = makeNullAndTypeCases(std::move(checkExprsNull), + std::move(checkExprsNotNumberOrDate), + ErrorCodes::Error{4974203}); + auto addExpr = buildMultiBranchConditionalFromCaseValuePairs( + std::move(cases), + sbe::makeE<sbe::EFunction>("doubleDoubleSum", std::move(argVars))); _context->pushExpr( 
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(addExpr))); } @@ -2104,6 +2189,11 @@ public: auto frameId = _context->state.frameId(); auto rhs = _context->popExpr(); auto lhs = _context->popExpr(); + + const auto& children = expr->getChildren(); + auto lhsPtr = children[0].get(); + auto rhsPtr = children[1].get(); + auto binds = sbe::makeEs(std::move(lhs), std::move(rhs)); sbe::EVariable lhsVar{frameId, 0}; sbe::EVariable rhsVar{frameId, 1}; @@ -2111,29 +2201,98 @@ public: // If the rhs is a small integral double, convert it to int32 to match $mod MQL semantics. auto numericConvert32 = sbe::makeE<sbe::ENumericConvert>(rhsVar.clone(), sbe::value::TypeTags::NumberInt32); - auto rhsExpr = buildMultiBranchConditional( - CaseValuePair{ - makeBinaryOp( - sbe::EPrimBinary::logicAnd, - sbe::makeE<sbe::ETypeMatch>( - rhsVar.clone(), getBSONTypeMask(sbe::value::TypeTags::NumberDouble)), - makeNot(sbe::makeE<sbe::ETypeMatch>( - lhsVar.clone(), getBSONTypeMask(sbe::value::TypeTags::NumberDouble)))), - makeFunction("fillEmpty", std::move(numericConvert32), rhsVar.clone())}, - rhsVar.clone()); - - auto modExpr = buildMultiBranchConditional( - CaseValuePair{makeBinaryOp(sbe::EPrimBinary::logicOr, - generateNullOrMissing(lhsVar), - generateNullOrMissing(rhsVar)), - sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)}, - CaseValuePair{makeBinaryOp(sbe::EPrimBinary::logicOr, - generateNonNumericCheck(lhsVar), - generateNonNumericCheck(rhsVar)), - sbe::makeE<sbe::EFail>(ErrorCodes::Error{5154000}, - "$mod only supports numeric types")}, - makeFunction("mod", lhsVar.clone(), std::move(rhsExpr))); + const auto errorMsg = "$mod only supports numeric types"_sd; + + std::vector<CaseValuePair> modCases; + std::vector<std::unique_ptr<sbe::EExpression>> nullChecks; + std::vector<std::unique_ptr<sbe::EExpression>> numericChecks; + std::vector<std::unique_ptr<sbe::EExpression>> numericConversionChecks; + + bool shouldDoTypeConversion = true; + auto constCallback = + 
[&](const ExpressionConstant& constant, ErrorCodes::Error code, bool isLhs) { + if (compileTimeNullCheck(constant)) { + _context->pushExpr(makeConstant(sbe::value::TypeTags::Null, 0)); + return true; + } + uassert(code, errorMsg, compileTimeNumericCheck(constant)); + auto doubleCmp = constant.getValue().getType() == BSONType::NumberDouble; + + // We are checking whether the LHS is NOT a double. + if (isLhs) { + doubleCmp = !doubleCmp; + } + + shouldDoTypeConversion = shouldDoTypeConversion && doubleCmp; + return false; + }; + + auto nonConstCallback = [&](const sbe::EVariable& var, bool isLhs) { + nullChecks.emplace_back(generateNullOrMissing(var)); + numericChecks.emplace_back(generateNonNumericCheck(var)); + auto typeExpr = sbe::makeE<sbe::ETypeMatch>( + var.clone(), getBSONTypeMask(sbe::value::TypeTags::NumberDouble)); + + // We want to know if the LHS is NOT a double. + if (isLhs) { + typeExpr = makeNot(std::move(typeExpr)); + } + numericConversionChecks.emplace_back(std::move(typeExpr)); + }; + + using namespace std::placeholders; + if (generateExpressionArgument( + rhsPtr, + std::bind(constCallback, _1, ErrorCodes::Error{5412904}, false), + std::bind(nonConstCallback, std::cref(rhsVar), false)) || + generateExpressionArgument( + lhsPtr, + std::bind(constCallback, _1, ErrorCodes::Error{5412905}, true), + std::bind(nonConstCallback, std::cref(lhsVar), true))) { + return; + } + + auto rhsExpr = [&]() { + if (!shouldDoTypeConversion) { + return rhsVar.clone(); + } + + if (auto typeConversionCheck = accumulateChecks(std::move(numericConversionChecks), + sbe::EPrimBinary::logicAnd)) { + return buildMultiBranchConditional( + CaseValuePair{ + std::move(typeConversionCheck), + makeFunction("fillEmpty", std::move(numericConvert32), rhsVar.clone())}, + rhsVar.clone()); + } else { + // Both values are constants, so perform the type conversion on rhs manually. 
+ auto rhsConstPtr = dynamic_cast<ExpressionConstant*>(rhsPtr); + tassert(5412907, "Right hand side for $mod should be a constant", rhsConstPtr); + auto lhsConstPtr = dynamic_cast<ExpressionConstant*>(lhsPtr); + tassert(5412908, "Left hand side for $mod should be a constant", lhsConstPtr); + return lhsConstPtr->getValue().getType() == NumberLong + ? makeConstant(sbe::value::TypeTags::NumberInt64, + rhsConstPtr->getValue().coerceToLong()) + : makeConstant(sbe::value::TypeTags::NumberInt32, + rhsConstPtr->getValue().coerceToInt()); + } + }(); + + if (auto checkIsNullOrMissing = + accumulateChecks(std::move(nullChecks), sbe::EPrimBinary::logicOr)) { + modCases.emplace_back(std::move(checkIsNullOrMissing), + makeConstant(sbe::value::TypeTags::Null, 0)); + } + + if (auto checkNumeric = + accumulateChecks(std::move(numericChecks), sbe::EPrimBinary::logicOr)) { + modCases.emplace_back(std::move(checkNumeric), + sbe::makeE<sbe::EFail>(ErrorCodes::Error{5154000}, errorMsg)); + } + + auto modExpr = buildMultiBranchConditionalFromCaseValuePairs( + std::move(modCases), makeFunction("mod", lhsVar.clone(), std::move(rhsExpr))); _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(modExpr))); } @@ -2431,23 +2590,17 @@ public: auto frameId = _context->state.frameId(); std::vector<std::unique_ptr<sbe::EExpression>> args; std::vector<std::unique_ptr<sbe::EExpression>> binds; - sbe::EVariable stringExpressionRef(frameId, 0); - sbe::EVariable delimiterRef(frameId, 1); - invariant(expr->getChildren().size() == 2); + auto children = expr->getChildren(); + tassert(5412919, "$split must have exactly 2 children", children.size() == 2); _context->ensureArity(2); + auto delimExpr = children[1].get(); + auto stringExpr = children[0].get(); + auto delimiter = _context->popExpr(); auto stringExpression = _context->popExpr(); - // Add stringExpression to arguments. 
- binds.push_back(std::move(stringExpression)); - args.push_back(stringExpressionRef.clone()); - - // Add delimiter to arguments. - binds.push_back(std::move(delimiter)); - args.push_back(delimiterRef.clone()); - auto [emptyStrTag, emptyStrVal] = sbe::value::makeNewString(""); auto [arrayWithEmptyStringTag, arrayWithEmptyStringVal] = sbe::value::makeNewArray(); sbe::value::ValueGuard arrayWithEmptyStringGuard{arrayWithEmptyStringTag, @@ -2460,41 +2613,108 @@ public: const sbe::EVariable& var) { return makeBinaryOp(sbe::EPrimBinary::eq, var.clone(), - sbe::makeE<sbe::EConstant>(emptyStrTag, emptyStrVal), + makeConstant(emptyStrTag, emptyStrVal), _context->state.env); }; - auto checkIsNullOrMissing = makeBinaryOp(sbe::EPrimBinary::logicOr, - generateNullOrMissing(stringExpressionRef), - generateNullOrMissing(delimiterRef)); + std::vector<CaseValuePair> cases; + std::vector<std::unique_ptr<sbe::EExpression>> nullChecks; // In order to maintain MQL semantics, first check both the string expression - // (first agument), and delimiter string (second argument) for null, undefined, or + // (first argument), and delimiter string (second argument) for null, undefined, or // missing, and if either is nullish make the entire expression return null. Only // then make further validity checks against the input. Fail if the delimiter is an empty // string. Return [""] if the string expression is an empty string. 
- auto totalSplitFunc = buildMultiBranchConditional( - CaseValuePair{std::move(checkIsNullOrMissing), - sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0)}, - CaseValuePair{generateNonStringCheck(stringExpressionRef), - sbe::makeE<sbe::EFail>( - ErrorCodes::Error{5155402}, - str::stream() << "$split string expression must be a string")}, - CaseValuePair{ + sbe::value::SlotId varId = 0; + const auto splitStringError = "$split string expression must be a string"; + auto splitStringConstCallback = [&, + arrayWithEmptyStringTag = arrayWithEmptyStringTag, + arrayWithEmptyStringVal = arrayWithEmptyStringVal]( + const ExpressionConstant& val) { + if (compileTimeNullCheck(val)) { + _context->pushExpr(makeConstant(sbe::value::TypeTags::Null, 0)); + return true; + } + uassert(5412901, splitStringError, compileTimeStringCheck(val)); + if (compileTimeEmptyStringCheck(val)) { + _context->pushExpr(makeConstant(arrayWithEmptyStringTag, arrayWithEmptyStringVal)); + return true; + } + args.push_back(std::move(stringExpression)); + return false; + }; + + auto splitStringNonConstCallback = [&, + arrayWithEmptyStringTag = arrayWithEmptyStringTag, + arrayWithEmptyStringVal = arrayWithEmptyStringVal]() { + sbe::EVariable stringExpressionRef(frameId, varId); + varId++; + + // Add string expression to arguments. + binds.push_back(std::move(stringExpression)); + args.push_back(stringExpressionRef.clone()); + + // Add error checks. + nullChecks.push_back(generateNullOrMissing(stringExpressionRef)); + cases.emplace_back( + generateNonStringCheck(stringExpressionRef), + sbe::makeE<sbe::EFail>(ErrorCodes::Error{5155402}, splitStringError)); + + // Empty string check. 
+ cases.emplace_back(generateIsEmptyString(stringExpressionRef), + makeConstant(arrayWithEmptyStringTag, arrayWithEmptyStringVal)); + }; + + if (generateExpressionArgument( + stringExpr, splitStringConstCallback, splitStringNonConstCallback)) { + return; + } + + const auto delimStringError = "$split delimiter must be a string"; + const auto delimNonEmptyStringError = "$split delimiter must not be an empty string"; + auto delimConstCallback = [&](const ExpressionConstant& val) { + if (compileTimeNullCheck(val)) { + _context->pushExpr(makeConstant(sbe::value::TypeTags::Null, 0)); + return true; + } + uassert(5412902, delimStringError, compileTimeStringCheck(val)); + uassert(5412903, delimNonEmptyStringError, !compileTimeEmptyStringCheck(val)); + args.push_back(std::move(delimiter)); + return false; + }; + + auto delimNonConstCallback = [&]() { + sbe::EVariable delimiterRef(frameId, varId); + + // Add delimiter to arguments. + binds.push_back(std::move(delimiter)); + args.push_back(delimiterRef.clone()); + + // Add error checks. 
+ nullChecks.push_back(generateNullOrMissing(delimiterRef)); + cases.emplace_back( generateNonStringCheck(delimiterRef), - sbe::makeE<sbe::EFail>(ErrorCodes::Error{5155400}, - str::stream() << "$split delimiter must be a string")}, - CaseValuePair{generateIsEmptyString(delimiterRef), - sbe::makeE<sbe::EFail>( - ErrorCodes::Error{5155401}, - str::stream() << "$split delimiter must not be an empty string")}, - sbe::makeE<sbe::EIf>( - generateIsEmptyString(stringExpressionRef), - sbe::makeE<sbe::EConstant>(arrayWithEmptyStringTag, arrayWithEmptyStringVal), - sbe::makeE<sbe::EFunction>("split", std::move(args)))); + sbe::makeE<sbe::EFail>(ErrorCodes::Error{5155400}, delimStringError)); + cases.emplace_back( + generateIsEmptyString(delimiterRef), + sbe::makeE<sbe::EFail>(ErrorCodes::Error{5155401}, delimNonEmptyStringError)); + }; + + if (generateExpressionArgument(delimExpr, delimConstCallback, delimNonConstCallback)) { + return; + } + + if (auto checkIsNullOrMissing = + accumulateChecks(std::move(nullChecks), sbe::EPrimBinary::logicOr)) { + cases.emplace(cases.begin(), + std::move(checkIsNullOrMissing), + makeConstant(sbe::value::TypeTags::Null, 0)); + } + auto splitFunc = buildMultiBranchConditionalFromCaseValuePairs( + std::move(cases), sbe::makeE<sbe::EFunction>("split", std::move(args))); _context->pushExpr( - sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(totalSplitFunc))); + sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(splitFunc))); } void visit(const ExpressionSqrt* expr) final { auto frameId = _context->state.frameId(); @@ -3606,6 +3826,24 @@ private: sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(dateAddExpr))); } + + /** + * Generates an expression argument depending on whether 'expr' is a constant or not. 
Returns + * 'true' if 'constCallback' generates a terminal value and no further generation is required + * (for instance, if 'expr' is determined to be null, we can simply push 'null'), false + * otherwise. + */ + bool generateExpressionArgument( + const Expression* expr, + const std::function<bool(const ExpressionConstant& val)>& constCallback, + const std::function<void()>& nonConstCallback) { + if (auto constExpr = dynamic_cast<const ExpressionConstant*>(expr)) { + return constCallback(*constExpr); + } + nonConstCallback(); + return false; + } + void unsupportedExpression(const char* op) const { // We're guaranteed to not fire this assertion by implementing a mechanism in the upper // layer which directs the query to the classic engine when an unsupported expression diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp index add7be68dee..d4bc4755295 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp +++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp @@ -183,6 +183,19 @@ std::unique_ptr<sbe::EExpression> buildMultiBranchConditional( return defaultCase; } +std::unique_ptr<sbe::EExpression> accumulateChecks( + std::vector<std::unique_ptr<sbe::EExpression>> checks, sbe::EPrimBinary::Op op) { + using iter_t = std::vector<std::unique_ptr<sbe::EExpression>>::iterator; + if (checks.empty()) { + return nullptr; + } + return std::accumulate( + std::move_iterator<iter_t>(checks.begin() + 1), + std::move_iterator<iter_t>(checks.end()), + std::move(checks.front()), + [&op](auto&& acc, auto&& ex) { return makeBinaryOp(op, std::move(acc), std::move(ex)); }); +} + std::unique_ptr<sbe::EExpression> buildMultiBranchConditionalFromCaseValuePairs( std::vector<CaseValuePair> caseValuePairs, std::unique_ptr<sbe::EExpression> defaultValue) { return std::accumulate( diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h index f79ff0911c9..005689741a9 100644 
--- a/src/mongo/db/query/sbe_stage_builder_helpers.h +++ b/src/mongo/db/query/sbe_stage_builder_helpers.h @@ -165,6 +165,12 @@ std::unique_ptr<sbe::EExpression> buildMultiBranchConditionalFromCaseValuePairs( std::vector<CaseValuePair> caseValuePairs, std::unique_ptr<sbe::EExpression> defaultValue); /** + * Given a vector of 'checks', all of which return true/false, accumulate them using 'op'. + */ +std::unique_ptr<sbe::EExpression> accumulateChecks( + std::vector<std::unique_ptr<sbe::EExpression>> checks, sbe::EPrimBinary::Op op); + +/** * Insert a limit stage on top of the 'input' stage. */ std::unique_ptr<sbe::PlanStage> makeLimitTree(std::unique_ptr<sbe::PlanStage> inputStage, |