summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--buildscripts/tests/test_burn_in_tags.py2
-rw-r--r--jstests/aggregation/bugs/ifnull.js69
-rw-r--r--jstests/aggregation/ifnull.js84
-rw-r--r--src/mongo/db/pipeline/expression.cpp19
-rw-r--r--src/mongo/db/pipeline/expression.h5
-rw-r--r--src/mongo/db/query/sbe_stage_builder_expression.cpp55
6 files changed, 143 insertions, 91 deletions
diff --git a/buildscripts/tests/test_burn_in_tags.py b/buildscripts/tests/test_burn_in_tags.py
index 698f88178f4..5e0a68411fd 100644
--- a/buildscripts/tests/test_burn_in_tags.py
+++ b/buildscripts/tests/test_burn_in_tags.py
@@ -125,7 +125,7 @@ class TestGenerateEvgTasks(unittest.TestCase):
"display_task_name": "aggregation_mongos_passthrough",
"resmoke_args":
"--suites=aggregation_mongos_passthrough --storageEngine=wiredTiger",
- "tests": ["jstests/aggregation/bugs/ifnull.js"],
+ "tests": ["jstests/aggregation/ifnull.js"],
"use_multiversion": None
}
} # yapf: disable
diff --git a/jstests/aggregation/bugs/ifnull.js b/jstests/aggregation/bugs/ifnull.js
deleted file mode 100644
index d74d4f0c7dd..00000000000
--- a/jstests/aggregation/bugs/ifnull.js
+++ /dev/null
@@ -1,69 +0,0 @@
-// Cannot implicitly shard accessed collections because of following errmsg: A single
-// update/delete on a sharded collection must contain an exact match on _id or contain the shard
-// key.
-// @tags: [
-// assumes_unsharded_collection,
-// ]
-
-// $ifNull returns the result of the first expression if not null or undefined, otherwise of the
-// second expression.
-load('jstests/aggregation/extras/utils.js');
-
-t = db.jstests_aggregation_ifnull;
-t.drop();
-
-t.save({});
-
-function assertError(expectedErrorCode, ifNullSpec) {
- assertErrorCode(t, {$project: {a: {$ifNull: ifNullSpec}}}, expectedErrorCode);
-}
-
-function assertResult(expectedResult, arg0, arg1) {
- var res = t.aggregate({$project: {a: {$ifNull: [arg0, arg1]}}}).toArray()[0];
- assert.eq(expectedResult, res.a);
-}
-
-// Wrong number of args.
-assertError(16020, []);
-assertError(16020, [1]);
-assertError(16020, [null]);
-assertError(16020, [1, 1, 1]);
-assertError(16020, [1, 1, null]);
-assertError(16020, [1, 1, undefined]);
-
-// First arg non null.
-assertResult(1, 1, 2);
-assertResult(2, 2, 1);
-assertResult(false, false, 1);
-assertResult('', '', 1);
-assertResult([], [], 1);
-assertResult({}, {}, 1);
-assertResult(1, 1, null);
-assertResult(2, 2, undefined);
-
-// First arg null.
-assertResult(2, null, 2);
-assertResult(1, null, 1);
-assertResult(null, null, null);
-assertResult(undefined, null, undefined);
-
-// First arg undefined.
-assertResult(2, undefined, 2);
-assertResult(1, undefined, 1);
-assertResult(null, undefined, null);
-assertResult(undefined, undefined, undefined);
-
-// Computed expression.
-assertResult(3, {$add: [1, 2]}, 5);
-assertResult(20, '$missingField', {$multiply: [4, 5]});
-
-// Divide/mod by 0.
-assertError(16608, [{$divide: [1, 0]}, 0]);
-assertError(16610, [{$mod: [1, 0]}, 0]);
-
-// Nested.
-t.drop();
-t.save({d: 'foo'});
-assertResult('foo', '$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]});
-t.update({}, {$set: {b: 'bar'}});
-assertResult('bar', '$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]});
diff --git a/jstests/aggregation/ifnull.js b/jstests/aggregation/ifnull.js
new file mode 100644
index 00000000000..fba99901345
--- /dev/null
+++ b/jstests/aggregation/ifnull.js
@@ -0,0 +1,84 @@
+(function() {
+"use strict";
+
+load('jstests/aggregation/extras/utils.js');
+
+const t = db.jstests_aggregation_ifnull;
+t.drop();
+assert.commandWorked(t.insertOne({
+ zero: 0,
+ one: 1,
+ two: 2,
+ three: 3,
+ my_false: false,
+ my_str: '',
+ my_null: null,
+ my_undefined: undefined,
+ my_obj: {},
+ my_list: []
+}));
+
+function assertError(expectedErrorCode, ifNullSpec) {
+ assertErrorCode(t, {$project: {a: {$ifNull: ifNullSpec}}}, expectedErrorCode);
+}
+
+function assertResult(expectedResult, ifNullSpec) {
+ const res = t.aggregate({$project: {_id: 0, a: {$ifNull: ifNullSpec}}}).toArray()[0];
+ assert.docEq({a: expectedResult}, res);
+}
+
+// Wrong number of args.
+assertError(1257300, []);
+assertError(1257300, ['$one']);
+assertError(1257300, ['$my_null']);
+
+// First arg non null.
+assertResult(1, ['$one', '$two']);
+assertResult(2, ['$two', '$one']);
+assertResult(false, ['$my_false', '$one']);
+assertResult('', ['$my_str', '$one']);
+assertResult([], ['$my_list', '$one']);
+assertResult({}, ['$my_obj', '$one']);
+assertResult(1, ['$one', '$my_null']);
+assertResult(2, ['$two', '$my_undefined']);
+assertResult(1, ['$one', '$two', '$three']);
+assertResult(1, ['$one', '$two', '$my_null']);
+assertResult(1, ['$one', '$my_null', '$two']);
+
+// First arg null.
+assertResult(2, ['$my_null', '$two']);
+assertResult(1, ['$my_null', '$one']);
+assertResult(null, ['$my_null', '$my_null']);
+assertResult(false, ['$my_null', '$my_false', '$one']);
+assertResult(false, ['$my_null', '$my_null', '$my_false']);
+assertResult(null, ['$my_null', '$my_null', '$my_null', '$my_null']);
+
+// First arg undefined.
+assertResult(2, ['$my_undefined', '$two']);
+assertResult(1, ['$my_undefined', '$one']);
+assertResult(null, ['$my_undefined', '$my_null']);
+assertResult(false, ['$my_undefined', '$my_false', '$one']);
+assertResult('', ['$my_undefined', '$my_null', '$missingField', '$my_str', '$two']);
+
+// Computed expression.
+assertResult(2, [{$add: ['$one', '$one']}, '$three']);
+assertResult(6, ['$missingField', {$multiply: ['$two', '$three']}]);
+assertResult(2, [{$add: ['$one', '$one']}, '$three', '$zero']);
+
+// Divide/mod by 0.
+assertError([16608, 4848401], [{$divide: ['$one', '$zero']}, '$zero']);
+assertError([16610, 4848403], [{$mod: ['$one', '$zero']}, '$zero']);
+
+// Nested.
+assert(t.drop());
+assert.commandWorked(t.insertOne({d: 'foo'}));
+assertResult('foo', ['$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}]);
+assert.commandWorked(t.updateMany({}, {$set: {b: 'bar'}}));
+assertResult('bar', ['$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}]);
+assertResult('bar', ['$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}, '$e']);
+
+// Return undefined.
+// TODO SERVER-52703: Commented out for now; these tests return an empty doc in SBE
+// assertResult(undefined, ['$my_null', '$my_undefined']);
+// assertResult(undefined, ['$my_undefined', '$my_undefined']);
+}()); \ No newline at end of file
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 6d95102b5e4..06481c5d953 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -2956,13 +2956,20 @@ const char* ExpressionMultiply::getOpName() const {
/* ----------------------- ExpressionIfNull ---------------------------- */
-Value ExpressionIfNull::evaluate(const Document& root, Variables* variables) const {
- Value pLeft(_children[0]->evaluate(root, variables));
- if (!pLeft.nullish())
- return pLeft;
+void ExpressionIfNull::validateArguments(const ExpressionVector& args) const {
+ uassert(1257300,
+ str::stream() << "$ifNull needs at least two arguments, had: " << args.size(),
+ args.size() >= 2);
+}
- Value pRight(_children[1]->evaluate(root, variables));
- return pRight;
+Value ExpressionIfNull::evaluate(const Document& root, Variables* variables) const {
+ const size_t n = _children.size();
+ for (size_t i = 0; i < n; ++i) {
+ Value pValue(_children[i]->evaluate(root, variables));
+ if (!pValue.nullish() || i == n - 1)
+ return pValue;
+ }
+ return Value();
}
REGISTER_EXPRESSION(ifNull, ExpressionIfNull::parse);
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index 8c326db6360..cbcf74dd618 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -1601,13 +1601,14 @@ public:
};
-class ExpressionIfNull final : public ExpressionFixedArity<ExpressionIfNull, 2> {
+class ExpressionIfNull final : public ExpressionVariadic<ExpressionIfNull> {
public:
explicit ExpressionIfNull(ExpressionContext* const expCtx)
- : ExpressionFixedArity<ExpressionIfNull, 2>(expCtx) {}
+ : ExpressionVariadic<ExpressionIfNull>(expCtx) {}
Value evaluate(const Document& root, Variables* variables) const final;
const char* getOpName() const final;
+ void validateArguments(const ExpressionVector& args) const final;
void acceptVisitor(ExpressionVisitor* visitor) final {
return visitor->visit(this);
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 938982b2fd2..ddecb47bb38 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -376,7 +376,9 @@ public:
void visit(ExpressionFieldPath* expr) final {}
void visit(ExpressionFilter* expr) final {}
void visit(ExpressionFloor* expr) final {}
- void visit(ExpressionIfNull* expr) final {}
+ void visit(ExpressionIfNull* expr) final {
+ _context->evalStack.emplaceFrame(EvalStage{});
+ }
void visit(ExpressionIn* expr) final {}
void visit(ExpressionIndexOfArray* expr) final {}
void visit(ExpressionIndexOfBytes* expr) final {}
@@ -545,7 +547,9 @@ public:
_context->evalStack.emplaceFrame(EvalStage{});
}
void visit(ExpressionFloor* expr) final {}
- void visit(ExpressionIfNull* expr) final {}
+ void visit(ExpressionIfNull* expr) final {
+ _context->evalStack.emplaceFrame(EvalStage{});
+ }
void visit(ExpressionIn* expr) final {}
void visit(ExpressionIndexOfArray* expr) final {}
void visit(ExpressionIndexOfBytes* expr) final {}
@@ -1813,21 +1817,46 @@ public:
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(floorExpr)));
}
void visit(ExpressionIfNull* expr) final {
- _context->ensureArity(2);
+ auto numChildren = expr->getChildren().size();
+ invariant(numChildren >= 2);
- auto replacementIfNull = _context->popExpr();
- auto input = _context->popExpr();
+ std::vector<EvalExprStagePair> branches;
+ branches.reserve(numChildren);
+ for (size_t i = 0; i < numChildren; ++i) {
+ auto [expr, stage] = _context->popFrame();
+ branches.emplace_back(std::move(expr), std::move(stage));
+ }
+ std::reverse(branches.begin(), branches.end());
- auto frameId = _context->frameIdGenerator->generate();
- auto binds = sbe::makeEs(std::move(input));
- sbe::EVariable inputRef(frameId, 0);
+ // Prepare to create limit-1/union with N branches (where N is the number of operands). Each
+ // branch will be evaluated from left to right until one of the branches produces a value.
+ auto branchFn = [](EvalExpr evalExpr,
+ EvalStage stage,
+ PlanNodeId planNodeId,
+ sbe::value::SlotIdGenerator* slotIdGenerator) {
+ auto slot = slotIdGenerator->generate();
+ stage = makeProject(std::move(stage), planNodeId, slot, evalExpr.extractExpr());
+
+ // Create a FilterStage for each branch (except the last one). If a branch's filter
+ // condition is true, it will "short-circuit" the evaluation process. For ifNull,
+ // short-circuiting should happen if the current variable is not null or missing.
+ auto filterExpr = makeNot(generateNullOrMissing(slot));
+ auto filterStage =
+ makeFilter<false>(std::move(stage), std::move(filterExpr), planNodeId);
+
+ // Set the current expression as the output to be returned if short-circuiting occurs.
+ return std::make_pair(slot, std::move(filterStage));
+ };
- // If input is null or missing, return replacement expression. Otherwise, return input.
- auto ifNullExpr = sbe::makeE<sbe::EIf>(
- generateNullOrMissing(frameId, 0), std::move(replacementIfNull), inputRef.clone());
+ auto [resultExpr, opStage] = generateSingleResultUnion(
+ std::move(branches), branchFn, _context->planNodeId, _context->slotIdGenerator);
- _context->pushExpr(
- sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(ifNullExpr)));
+ auto loopJoinStage = makeLoopJoin(_context->extractCurrentEvalStage(),
+ std::move(opStage),
+ _context->planNodeId,
+ _context->getLexicalEnvironment());
+
+ _context->pushExpr(resultExpr.extractExpr(), std::move(loopJoinStage));
}
void visit(ExpressionIn* expr) final {
unsupportedExpression(expr->getOpName());