summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: David Storch <david.storch@mongodb.com> 2022-10-13 15:30:22 +0000
committer: Evergreen Agent <no-reply@evergreen.mongodb.com> 2022-10-14 19:05:42 +0000
commit: 52799a318105b7cfb22ece697d617b5363a0f5e6 (patch)
tree: 14d1868f4874e2d41a58f271333a1110cfff2792
parent: 5525e1d93f081a8cf7cf0f39c055fea34c18f40a (diff)
download: mongo-52799a318105b7cfb22ece697d617b5363a0f5e6.tar.gz
SERVER-70190 Fix ExpressionSwitch::optimize() to manage _children vector correctly
(cherry picked from commit 2cc7da28b9ddd1eb91516a2ba25f7dd67db88ca8)
-rw-r--r-- etc/backports_required_for_multiversion_tests.yml | 4
-rw-r--r-- jstests/aggregation/expressions/switch_errors.js | 19
-rw-r--r-- src/mongo/db/pipeline/expression.cpp | 110
-rw-r--r-- src/mongo/db/pipeline/expression.h | 39
-rw-r--r-- src/mongo/db/pipeline/expression_test.cpp | 60
5 files changed, 167 insertions, 65 deletions
diff --git a/etc/backports_required_for_multiversion_tests.yml b/etc/backports_required_for_multiversion_tests.yml
index 9c1813862dc..a7ab4ad9eca 100644
--- a/etc/backports_required_for_multiversion_tests.yml
+++ b/etc/backports_required_for_multiversion_tests.yml
@@ -294,6 +294,8 @@ last-continuous:
test_file: jstests/sharding/read_write_concern_defaults_application.js
- ticket: SERVER-68932
test_file: jstests/sharding/resharding_critical_section_metrics.js
+ - ticket: SERVER-70190
+ test_file: jstests/aggregation/expressions/switch_errors.js
# Tests that should only be excluded from particular suites should be listed under that suite.
suites:
@@ -719,6 +721,8 @@ last-lts:
test_file: jstests/replsets/tenant_migration_concurrent_writes_on_donor_util.js
- ticket: SERVER-69348
test_file: jstests/sharding/read_write_concern_defaults_application.js
+ - ticket: SERVER-70190
+ test_file: jstests/aggregation/expressions/switch_errors.js
# Tests that should only be excluded from particular suites should be listed under that suite.
suites:
diff --git a/jstests/aggregation/expressions/switch_errors.js b/jstests/aggregation/expressions/switch_errors.js
index 9701bc4b019..17ccd53f9d5 100644
--- a/jstests/aggregation/expressions/switch_errors.js
+++ b/jstests/aggregation/expressions/switch_errors.js
@@ -60,9 +60,26 @@ pipeline = {
};
assertErrorCode(coll, pipeline, 40068, "$switch requires at least one branch");
-coll.insert({x: 1});
+assert.commandWorked(coll.insert({x: 1}));
pipeline = {
"$project": {"output": {"$switch": {"branches": [{"case": {"$eq": ["$x", 0]}, "then": 1}]}}}
};
assertErrorCode(coll, pipeline, 40066, "$switch has no default and an input matched no case");
+
+// This query was designed to reproduce SERVER-70190. The first branch of the $switch can be
+// optimized away and the $ifNull can be optimized to 2. If the field "x" exists in the input
+// document and is truthy, then the expression should return 2. Otherwise it should throw because no
+// case statement matched and there is no "default" expression.
+pipeline = [{
+ $sortByCount: {
+ $switch: {
+ branches:
+ [{case: {$literal: false}, then: 1}, {case: "$x", then: {$ifNull: [2, "$y"]}}]
+ }
+ }
+}];
+assert.eq([{"_id": 2, "count": 1}], coll.aggregate(pipeline).toArray());
+assert.commandWorked(coll.remove({x: 1}));
+assert.commandWorked(coll.insert({z: 1}));
+assertErrorCode(coll, pipeline, 40066, "$switch has no default and an input matched no case");
}());
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index d4b7e29f9eb..8132796a3e0 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -5602,19 +5602,20 @@ const char* ExpressionSubtract::getOpName() const {
REGISTER_STABLE_EXPRESSION(switch, ExpressionSwitch::parse);
Value ExpressionSwitch::evaluate(const Document& root, Variables* variables) const {
- for (auto&& branch : _branches) {
- Value caseExpression(branch.first->evaluate(root, variables));
+ for (int i = 0; i < numBranches(); ++i) {
+ auto [caseExpr, thenExpr] = getBranch(i);
+ Value caseResult = caseExpr->evaluate(root, variables);
- if (caseExpression.coerceToBool()) {
- return branch.second->evaluate(root, variables);
+ if (caseResult.coerceToBool()) {
+ return thenExpr->evaluate(root, variables);
}
}
uassert(40066,
"$switch could not find a matching branch for an input, and no default was specified.",
- _default);
+ defaultExpr());
- return _default->evaluate(root, variables);
+ return defaultExpr()->evaluate(root, variables);
}
boost::intrusive_ptr<Expression> ExpressionSwitch::parse(ExpressionContext* const expCtx,
@@ -5623,7 +5624,7 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::parse(ExpressionContext* cons
uassert(40060,
str::stream() << "$switch requires an object as an argument, found: "
<< typeName(expr.type()),
- expr.type() == Object);
+ expr.type() == BSONType::Object);
boost::intrusive_ptr<Expression> expDefault;
std::vector<boost::intrusive_ptr<Expression>> children;
@@ -5635,13 +5636,13 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::parse(ExpressionContext* cons
uassert(40061,
str::stream() << "$switch expected an array for 'branches', found: "
<< typeName(elem.type()),
- elem.type() == Array);
+ elem.type() == BSONType::Array);
for (auto&& branch : elem.Array()) {
uassert(40062,
str::stream() << "$switch expected each branch to be an object, found: "
<< typeName(branch.type()),
- branch.type() == Object);
+ branch.type() == BSONType::Object);
boost::intrusive_ptr<Expression> switchCase, switchThen;
@@ -5673,69 +5674,67 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::parse(ExpressionContext* cons
uasserted(40067, str::stream() << "$switch found an unknown argument: " << field);
}
}
+
+ // The 'default' expression is always the final child. If no 'default' expression is
+ // provided, then the final child is nullptr.
children.push_back(std::move(expDefault));
- // Obtain references to the case and branch expressions two-by-two from the children vector,
- // ignore the last.
- std::vector<ExpressionPair> branches;
- boost::optional<boost::intrusive_ptr<Expression>&> first;
- for (auto&& child : children) {
- if (first) {
- branches.emplace_back(*first, child);
- first = boost::none;
- } else {
- first = child;
- }
- }
- uassert(40068, "$switch requires at least one branch.", !branches.empty());
+ return new ExpressionSwitch(expCtx, std::move(children));
+}
- return new ExpressionSwitch(expCtx, std::move(children), std::move(branches));
+void ExpressionSwitch::deleteBranch(int i) {
+ invariant(i >= 0);
+ invariant(i < numBranches());
+ // Delete the two elements corresponding to this branch at positions 2i and 2i + 1.
+ _children.erase(std::next(_children.begin(), i * 2), std::next(_children.begin(), i * 2 + 2));
}
boost::intrusive_ptr<Expression> ExpressionSwitch::optimize() {
- if (_default) {
- _default = _default->optimize();
+ if (defaultExpr()) {
+ _children.back() = _children.back()->optimize();
}
- std::vector<ExpressionPair>::iterator it = _branches.begin();
- bool true_const = false;
+ bool trueConst = false;
- while (!true_const && it != _branches.end()) {
- (it->first) = (it->first)->optimize();
+ int i = 0;
+ while (!trueConst && i < numBranches()) {
+ boost::intrusive_ptr<Expression>& caseExpr = _children[i * 2];
+ boost::intrusive_ptr<Expression>& thenExpr = _children[i * 2 + 1];
+ caseExpr = caseExpr->optimize();
- if (auto* val = dynamic_cast<ExpressionConstant*>((it->first).get())) {
- if (!((val->getValue()).coerceToBool())) {
+ if (auto* val = dynamic_cast<ExpressionConstant*>(caseExpr.get())) {
+ if (!val->getValue().coerceToBool()) {
// Case is constant and evaluates to false, so it is removed.
- it = _branches.erase(it);
+ deleteBranch(i);
} else {
- // Case is constant and true so it is set to default and then removed.
- true_const = true;
-
- // Optimizing this case's then, so that default will remain optimized.
- (it->second) = (it->second)->optimize();
- _default = it->second;
- it = _branches.erase(it);
+ // Case optimized to a constant true value. Set the optimized version of the
+ // corresponding 'then' expression as the new 'default'. Break out of the loop and
+ // fall through to the logic to remove this and all subsequent branches.
+ trueConst = true;
+ _children.back() = thenExpr->optimize();
+ break;
}
} else {
// Since case is not removed from the switch, its then is now optimized.
- (it->second) = (it->second)->optimize();
- ++it;
+ thenExpr = thenExpr->optimize();
+ ++i;
}
}
// Erasing the rest of the cases because found a default true value.
- if (true_const) {
- _branches.erase(it, _branches.end());
+ if (trueConst) {
+ while (i < numBranches()) {
+ deleteBranch(i);
+ }
}
// If there are no cases, make the switch its default.
- if (_branches.size() == 0 && _default) {
- return _default;
- } else if (_branches.size() == 0) {
+ if (numBranches() == 0) {
uassert(40069,
- "One cannot execute a switch statement where all the cases evaluate to false "
- "without a default.",
- _branches.size());
+ "Cannot execute a switch statement where all the cases evaluate to false "
+ "without a default",
+ defaultExpr());
+ return _children.back();
}
return this;
@@ -5743,17 +5742,18 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::optimize() {
Value ExpressionSwitch::serialize(bool explain) const {
std::vector<Value> serializedBranches;
- serializedBranches.reserve(_branches.size());
+ serializedBranches.reserve(numBranches());
- for (auto&& branch : _branches) {
- serializedBranches.push_back(Value(Document{{"case", branch.first->serialize(explain)},
- {"then", branch.second->serialize(explain)}}));
+ for (int i = 0; i < numBranches(); ++i) {
+ auto [caseExpr, thenExpr] = getBranch(i);
+ serializedBranches.push_back(Value(Document{{"case", caseExpr->serialize(explain)},
+ {"then", thenExpr->serialize(explain)}}));
}
- if (_default) {
+ if (defaultExpr()) {
return Value(Document{{"$switch",
Document{{"branches", Value(serializedBranches)},
- {"default", _default->serialize(explain)}}}});
+ {"default", defaultExpr()->serialize(explain)}}}});
}
return Value(Document{{"$switch", Document{{"branches", Value(serializedBranches)}}}});
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index 4680181b295..0c4e077511d 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -3283,11 +3283,10 @@ public:
std::pair<boost::intrusive_ptr<Expression>&, boost::intrusive_ptr<Expression>&>;
ExpressionSwitch(ExpressionContext* const expCtx,
- std::vector<boost::intrusive_ptr<Expression>> children,
- std::vector<ExpressionPair> branches)
- : Expression(expCtx, std::move(children)),
- _default(_children.back()),
- _branches(std::move(branches)) {}
+ std::vector<boost::intrusive_ptr<Expression>> children)
+ : Expression(expCtx, std::move(children)) {
+ uassert(40068, "$switch requires at least one branch", numBranches() >= 1);
+ }
Value evaluate(const Document& root, Variables* variables) const final;
boost::intrusive_ptr<Expression> optimize() final;
@@ -3304,9 +3303,35 @@ public:
return visitor->visit(this);
}
+ /**
+ * Returns the number of branches in the switch expression. Each branch is made up of two
+ * expressions ('case' and 'then').
+ */
+ int numBranches() const {
+ return _children.size() / 2;
+ }
+
+ /**
+ * Returns a pair of expression pointers representing the 'case' and 'then' expressions for the
+ * i-th branch of the switch.
+ */
+ std::pair<const Expression*, const Expression*> getBranch(int i) const {
+ invariant(i >= 0);
+ invariant(i < numBranches());
+ return {_children[i * 2].get(), _children[i * 2 + 1].get()};
+ }
+
+ /**
+ * Returns the 'default' expression, or nullptr if there is no 'default'.
+ */
+ const Expression* defaultExpr() const {
+ return _children.back().get();
+ }
+
private:
- boost::intrusive_ptr<Expression>& _default;
- std::vector<ExpressionPair> _branches;
+ // Helper for 'optimize()'. Deletes the 'case' and 'then' children associated with the i-th
+ // branch of the switch.
+ void deleteBranch(int i);
};
diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp
index 2a9978e1bba..3e1592b12e1 100644
--- a/src/mongo/db/pipeline/expression_test.cpp
+++ b/src/mongo/db/pipeline/expression_test.cpp
@@ -1122,7 +1122,7 @@ TEST(ExpressionSwitch, ExpressionSwitchWithAllConstantFalsesAndNoDefaultErrors)
ASSERT_THROWS_CODE(switchExp->optimize(), AssertionException, 40069);
}
-TEST(ExpressionSwitch, ExpressionSwitchWithZeroAsConstantFalsesAndNoDefaulErrors) {
+TEST(ExpressionSwitch, ExpressionSwitchWithZeroAsConstantFalseAndNoDefaultErrors) {
auto expCtx = ExpressionContextForTest{};
VariablesParseState vps = expCtx.variablesParseState;
@@ -1234,6 +1234,62 @@ TEST(ExpressionSwitch, ExpressionSwitchWithNoConstantsShouldStayTheSame) {
ASSERT_BSONOBJ_BINARY_EQ(switchQ, expressionToBson(optimizedStaySame));
}
+// This test was designed to provide coverage for SERVER-70190, a bug in which optimizing a $switch
+// expression could leave its children vector in a bad state. By walking the tree after optimizing
+// we make sure that the expected children are found.
+TEST(ExpressionSwitch, CaseEliminationShouldLeaveTreeInWalkableState) {
+ auto expCtx = ExpressionContextForTest{};
+ VariablesParseState vps = expCtx.variablesParseState;
+
+ BSONObj switchQ = fromjson(R"(
+ {$switch: {
+ branches: [
+ {case: false, then: {$const: 0}},
+ {case: "$z", then: {$const: 1}},
+ {case: "$y", then: {$const: 3}},
+ {case: true, then: {$const: 4}},
+ {case: "$a", then: {$const: 5}},
+ {case: "$b", then: {$const: 6}},
+ {case: "$c", then: {$const: 7}}
+ ],
+ default: {$const: 8}
+ }}
+ )");
+ auto switchExp = ExpressionSwitch::parse(&expCtx, switchQ.firstElement(), vps);
+ auto optimizedExpr = switchExp->optimize();
+
+ BSONObj optimizedQ = fromjson(R"(
+ {$switch: {
+ branches: [
+ {case: "$z", then: {$const: 1}},
+ {case: "$y", then: {$const: 3}}
+ ],
+ default: {$const: 4}
+ }}
+ )");
+
+ ASSERT_BSONOBJ_BINARY_EQ(optimizedQ, expressionToBson(optimizedExpr));
+
+ // Make sure that the expression tree appears as expected when the children are traversed using
+ // a for-each loop.
+ int childNum = 0;
+ int numConstants = 0;
+ for (auto&& child : optimizedExpr->getChildren()) {
+ // Children 0 and 2 are field path expressions, whereas 1, 3, and 4 are constants.
+ auto constExpr = dynamic_cast<ExpressionConstant*>(child.get());
+ if (constExpr) {
+ ASSERT_VALUE_EQ(constExpr->getValue(), Value{childNum});
+ ++numConstants;
+ } else {
+ ASSERT(dynamic_cast<ExpressionFieldPath*>(child.get()));
+ }
+ ++childNum;
+ }
+ // We should have seen 5 children total, 3 of which are constants.
+ ASSERT_EQ(childNum, 5);
+ ASSERT_EQ(numConstants, 3);
+}
+
TEST(ExpressionArray, ExpressionArrayShouldOptimizeSubExpressionToExpressionConstant) {
auto expCtx = ExpressionContextForTest{};
VariablesParseState vps = expCtx.variablesParseState;
@@ -4175,7 +4231,7 @@ TEST(ExpressionEncryptedBetweenTest, ParseRoundTrip) {
auto serializedExpr = expr->serialize(false);
auto expectedExpr = fromjson(R"({$encryptedBetween: [
{$const: "age"},
- { $const:
+ { $const:
{$binary: {
"base64": "ZW5jcnlwdGVkIHBheWxvYWQ=",
"subType": "6"