diff options
-rw-r--r-- | buildscripts/tests/test_burn_in_tags.py | 2 | ||||
-rw-r--r-- | jstests/aggregation/bugs/ifnull.js | 69 | ||||
-rw-r--r-- | jstests/aggregation/ifnull.js | 84 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 19 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.h | 5 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 55 |
6 files changed, 143 insertions, 91 deletions
diff --git a/buildscripts/tests/test_burn_in_tags.py b/buildscripts/tests/test_burn_in_tags.py index 698f88178f4..5e0a68411fd 100644 --- a/buildscripts/tests/test_burn_in_tags.py +++ b/buildscripts/tests/test_burn_in_tags.py @@ -125,7 +125,7 @@ class TestGenerateEvgTasks(unittest.TestCase): "display_task_name": "aggregation_mongos_passthrough", "resmoke_args": "--suites=aggregation_mongos_passthrough --storageEngine=wiredTiger", - "tests": ["jstests/aggregation/bugs/ifnull.js"], + "tests": ["jstests/aggregation/ifnull.js"], "use_multiversion": None } } # yapf: disable diff --git a/jstests/aggregation/bugs/ifnull.js b/jstests/aggregation/bugs/ifnull.js deleted file mode 100644 index d74d4f0c7dd..00000000000 --- a/jstests/aggregation/bugs/ifnull.js +++ /dev/null @@ -1,69 +0,0 @@ -// Cannot implicitly shard accessed collections because of following errmsg: A single -// update/delete on a sharded collection must contain an exact match on _id or contain the shard -// key. -// @tags: [ -// assumes_unsharded_collection, -// ] - -// $ifNull returns the result of the first expression if not null or undefined, otherwise of the -// second expression. -load('jstests/aggregation/extras/utils.js'); - -t = db.jstests_aggregation_ifnull; -t.drop(); - -t.save({}); - -function assertError(expectedErrorCode, ifNullSpec) { - assertErrorCode(t, {$project: {a: {$ifNull: ifNullSpec}}}, expectedErrorCode); -} - -function assertResult(expectedResult, arg0, arg1) { - var res = t.aggregate({$project: {a: {$ifNull: [arg0, arg1]}}}).toArray()[0]; - assert.eq(expectedResult, res.a); -} - -// Wrong number of args. -assertError(16020, []); -assertError(16020, [1]); -assertError(16020, [null]); -assertError(16020, [1, 1, 1]); -assertError(16020, [1, 1, null]); -assertError(16020, [1, 1, undefined]); - -// First arg non null. -assertResult(1, 1, 2); -assertResult(2, 2, 1); -assertResult(false, false, 1); -assertResult('', '', 1); -assertResult([], [], 1); -assertResult({}, {}, 1); -assertResult(1, 1, null); -assertResult(2, 2, undefined); - -// First arg null. -assertResult(2, null, 2); -assertResult(1, null, 1); -assertResult(null, null, null); -assertResult(undefined, null, undefined); - -// First arg undefined. -assertResult(2, undefined, 2); -assertResult(1, undefined, 1); -assertResult(null, undefined, null); -assertResult(undefined, undefined, undefined); - -// Computed expression. -assertResult(3, {$add: [1, 2]}, 5); -assertResult(20, '$missingField', {$multiply: [4, 5]}); - -// Divide/mod by 0. -assertError(16608, [{$divide: [1, 0]}, 0]); -assertError(16610, [{$mod: [1, 0]}, 0]); - -// Nested. -t.drop(); -t.save({d: 'foo'}); -assertResult('foo', '$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}); -t.update({}, {$set: {b: 'bar'}}); -assertResult('bar', '$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}); diff --git a/jstests/aggregation/ifnull.js b/jstests/aggregation/ifnull.js new file mode 100644 index 00000000000..fba99901345 --- /dev/null +++ b/jstests/aggregation/ifnull.js @@ -0,0 +1,84 @@ +(function() { +"use strict"; + +load('jstests/aggregation/extras/utils.js'); + +const t = db.jstests_aggregation_ifnull; +t.drop(); +assert.commandWorked(t.insertOne({ + zero: 0, + one: 1, + two: 2, + three: 3, + my_false: false, + my_str: '', + my_null: null, + my_undefined: undefined, + my_obj: {}, + my_list: [] +})); + +function assertError(expectedErrorCode, ifNullSpec) { + assertErrorCode(t, {$project: {a: {$ifNull: ifNullSpec}}}, expectedErrorCode); +} + +function assertResult(expectedResult, ifNullSpec) { + const res = t.aggregate({$project: {_id: 0, a: {$ifNull: ifNullSpec}}}).toArray()[0]; + assert.docEq({a: expectedResult}, res); +} + +// Wrong number of args. +assertError(1257300, []); +assertError(1257300, ['$one']); +assertError(1257300, ['$my_null']); + +// First arg non null. +assertResult(1, ['$one', '$two']); +assertResult(2, ['$two', '$one']); +assertResult(false, ['$my_false', '$one']); +assertResult('', ['$my_str', '$one']); +assertResult([], ['$my_list', '$one']); +assertResult({}, ['$my_obj', '$one']); +assertResult(1, ['$one', '$my_null']); +assertResult(2, ['$two', '$my_undefined']); +assertResult(1, ['$one', '$two', '$three']); +assertResult(1, ['$one', '$two', '$my_null']); +assertResult(1, ['$one', '$my_null', '$two']); + +// First arg null. +assertResult(2, ['$my_null', '$two']); +assertResult(1, ['$my_null', '$one']); +assertResult(null, ['$my_null', '$my_null']); +assertResult(false, ['$my_null', '$my_false', '$one']); +assertResult(false, ['$my_null', '$my_null', '$my_false']); +assertResult(null, ['$my_null', '$my_null', '$my_null', '$my_null']); + +// First arg undefined. +assertResult(2, ['$my_undefined', '$two']); +assertResult(1, ['$my_undefined', '$one']); +assertResult(null, ['$my_undefined', '$my_null']); +assertResult(false, ['$my_undefined', '$my_false', '$one']); +assertResult('', ['$my_undefined', '$my_null', '$missingField', '$my_str', '$two']); + +// Computed expression. +assertResult(2, [{$add: ['$one', '$one']}, '$three']); +assertResult(6, ['$missingField', {$multiply: ['$two', '$three']}]); +assertResult(2, [{$add: ['$one', '$one']}, '$three', '$zero']); + +// Divide/mod by 0. +assertError([16608, 4848401], [{$divide: ['$one', '$zero']}, '$zero']); +assertError([16610, 4848403], [{$mod: ['$one', '$zero']}, '$zero']); + +// Nested. +assert(t.drop()); +assert.commandWorked(t.insertOne({d: 'foo'})); +assertResult('foo', ['$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}]); +assert.commandWorked(t.updateMany({}, {$set: {b: 'bar'}})); +assertResult('bar', ['$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}]); +assertResult('bar', ['$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}, '$e']); + +// Return undefined. +// TODO SERVER-52703: Commented out for now; these tests return an empty doc in SBE +// assertResult(undefined, ['$my_null', '$my_undefined']); +// assertResult(undefined, ['$my_undefined', '$my_undefined']); +}());
\ No newline at end of file diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 6d95102b5e4..06481c5d953 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -2956,13 +2956,20 @@ const char* ExpressionMultiply::getOpName() const { /* ----------------------- ExpressionIfNull ---------------------------- */ -Value ExpressionIfNull::evaluate(const Document& root, Variables* variables) const { - Value pLeft(_children[0]->evaluate(root, variables)); - if (!pLeft.nullish()) - return pLeft; +void ExpressionIfNull::validateArguments(const ExpressionVector& args) const { + uassert(1257300, + str::stream() << "$ifNull needs at least two arguments, had: " << args.size(), + args.size() >= 2); +} - Value pRight(_children[1]->evaluate(root, variables)); - return pRight; +Value ExpressionIfNull::evaluate(const Document& root, Variables* variables) const { + const size_t n = _children.size(); + for (size_t i = 0; i < n; ++i) { + Value pValue(_children[i]->evaluate(root, variables)); + if (!pValue.nullish() || i == n - 1) + return pValue; + } + return Value(); } REGISTER_EXPRESSION(ifNull, ExpressionIfNull::parse); diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 8c326db6360..cbcf74dd618 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -1601,13 +1601,14 @@ public: }; -class ExpressionIfNull final : public ExpressionFixedArity<ExpressionIfNull, 2> { +class ExpressionIfNull final : public ExpressionVariadic<ExpressionIfNull> { public: explicit ExpressionIfNull(ExpressionContext* const expCtx) - : ExpressionFixedArity<ExpressionIfNull, 2>(expCtx) {} + : ExpressionVariadic<ExpressionIfNull>(expCtx) {} Value evaluate(const Document& root, Variables* variables) const final; const char* getOpName() const final; + void validateArguments(const ExpressionVector& args) const final; void acceptVisitor(ExpressionVisitor* visitor) final { return visitor->visit(this); diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 938982b2fd2..ddecb47bb38 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -376,7 +376,9 @@ public: void visit(ExpressionFieldPath* expr) final {} void visit(ExpressionFilter* expr) final {} void visit(ExpressionFloor* expr) final {} - void visit(ExpressionIfNull* expr) final {} + void visit(ExpressionIfNull* expr) final { + _context->evalStack.emplaceFrame(EvalStage{}); + } void visit(ExpressionIn* expr) final {} void visit(ExpressionIndexOfArray* expr) final {} void visit(ExpressionIndexOfBytes* expr) final {} @@ -545,7 +547,9 @@ public: _context->evalStack.emplaceFrame(EvalStage{}); } void visit(ExpressionFloor* expr) final {} - void visit(ExpressionIfNull* expr) final {} + void visit(ExpressionIfNull* expr) final { + _context->evalStack.emplaceFrame(EvalStage{}); + } void visit(ExpressionIn* expr) final {} void visit(ExpressionIndexOfArray* expr) final {} void visit(ExpressionIndexOfBytes* expr) final {} @@ -1813,21 +1817,46 @@ public: sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(floorExpr))); } void visit(ExpressionIfNull* expr) final { - _context->ensureArity(2); + auto numChildren = expr->getChildren().size(); + invariant(numChildren >= 2); - auto replacementIfNull = _context->popExpr(); - auto input = _context->popExpr(); + std::vector<EvalExprStagePair> branches; + branches.reserve(numChildren); + for (size_t i = 0; i < numChildren; ++i) { + auto [expr, stage] = _context->popFrame(); + branches.emplace_back(std::move(expr), std::move(stage)); + } + std::reverse(branches.begin(), branches.end()); - auto frameId = _context->frameIdGenerator->generate(); - auto binds = sbe::makeEs(std::move(input)); - sbe::EVariable inputRef(frameId, 0); + // Prepare to create limit-1/union with N branches (where N is the number of operands). Each + // branch will be evaluated from left to right until one of the branches produces a value. + auto branchFn = [](EvalExpr evalExpr, + EvalStage stage, + PlanNodeId planNodeId, + sbe::value::SlotIdGenerator* slotIdGenerator) { + auto slot = slotIdGenerator->generate(); + stage = makeProject(std::move(stage), planNodeId, slot, evalExpr.extractExpr()); + + // Create a FilterStage for each branch (except the last one). If a branch's filter + // condition is true, it will "short-circuit" the evaluation process. For ifNull, + // short-circuiting should happen if the current variable is not null or missing. + auto filterExpr = makeNot(generateNullOrMissing(slot)); + auto filterStage = + makeFilter<false>(std::move(stage), std::move(filterExpr), planNodeId); + + // Set the current expression as the output to be returned if short-circuiting occurs. + return std::make_pair(slot, std::move(filterStage)); + }; - // If input is null or missing, return replacement expression. Otherwise, return input. - auto ifNullExpr = sbe::makeE<sbe::EIf>( - generateNullOrMissing(frameId, 0), std::move(replacementIfNull), inputRef.clone()); + auto [resultExpr, opStage] = generateSingleResultUnion( + std::move(branches), branchFn, _context->planNodeId, _context->slotIdGenerator); - _context->pushExpr( - sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(ifNullExpr))); + auto loopJoinStage = makeLoopJoin(_context->extractCurrentEvalStage(), + std::move(opStage), + _context->planNodeId, + _context->getLexicalEnvironment()); + + _context->pushExpr(resultExpr.extractExpr(), std::move(loopJoinStage)); } void visit(ExpressionIn* expr) final { unsupportedExpression(expr->getOpName()); |