diff options
author | David Storch <david.storch@mongodb.com> | 2022-10-13 15:30:22 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-10-14 19:05:42 +0000 |
commit | 52799a318105b7cfb22ece697d617b5363a0f5e6 (patch) | |
tree | 14d1868f4874e2d41a58f271333a1110cfff2792 | |
parent | 5525e1d93f081a8cf7cf0f39c055fea34c18f40a (diff) | |
download | mongo-52799a318105b7cfb22ece697d617b5363a0f5e6.tar.gz |
SERVER-70190 Fix ExpressionSwitch::optimize() to manage _children vector correctly
(cherry picked from commit 2cc7da28b9ddd1eb91516a2ba25f7dd67db88ca8)
-rw-r--r-- | etc/backports_required_for_multiversion_tests.yml | 4 | ||||
-rw-r--r-- | jstests/aggregation/expressions/switch_errors.js | 19 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 110 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.h | 39 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_test.cpp | 60 |
5 files changed, 167 insertions, 65 deletions
diff --git a/etc/backports_required_for_multiversion_tests.yml b/etc/backports_required_for_multiversion_tests.yml index 9c1813862dc..a7ab4ad9eca 100644 --- a/etc/backports_required_for_multiversion_tests.yml +++ b/etc/backports_required_for_multiversion_tests.yml @@ -294,6 +294,8 @@ last-continuous: test_file: jstests/sharding/read_write_concern_defaults_application.js - ticket: SERVER-68932 test_file: jstests/sharding/resharding_critical_section_metrics.js + - ticket: SERVER-70190 + test_file: jstests/aggregation/expressions/switch_errors.js # Tests that should only be excluded from particular suites should be listed under that suite. suites: @@ -719,6 +721,8 @@ last-lts: test_file: jstests/replsets/tenant_migration_concurrent_writes_on_donor_util.js - ticket: SERVER-69348 test_file: jstests/sharding/read_write_concern_defaults_application.js + - ticket: SERVER-70190 + test_file: jstests/aggregation/expressions/switch_errors.js # Tests that should only be excluded from particular suites should be listed under that suite. suites: diff --git a/jstests/aggregation/expressions/switch_errors.js b/jstests/aggregation/expressions/switch_errors.js index 9701bc4b019..17ccd53f9d5 100644 --- a/jstests/aggregation/expressions/switch_errors.js +++ b/jstests/aggregation/expressions/switch_errors.js @@ -60,9 +60,26 @@ pipeline = { }; assertErrorCode(coll, pipeline, 40068, "$switch requires at least one branch"); -coll.insert({x: 1}); +assert.commandWorked(coll.insert({x: 1})); pipeline = { "$project": {"output": {"$switch": {"branches": [{"case": {"$eq": ["$x", 0]}, "then": 1}]}}} }; assertErrorCode(coll, pipeline, 40066, "$switch has no default and an input matched no case"); + +// This query was designed to reproduce SERVER-70190. The first branch of the $switch can be +// optimized away and the $ifNull can be optimized to 2. If the field "x" exists in the input +// document and is truthy, then the expression should return 2. 
Otherwise it should throw because no +// case statement matched and there is no "default" expression. +pipeline = [{ + $sortByCount: { + $switch: { + branches: + [{case: {$literal: false}, then: 1}, {case: "$x", then: {$ifNull: [2, "$y"]}}] + } + } +}]; +assert.eq([{"_id": 2, "count": 1}], coll.aggregate(pipeline).toArray()); +assert.commandWorked(coll.remove({x: 1})); +assert.commandWorked(coll.insert({z: 1})); +assertErrorCode(coll, pipeline, 40066, "$switch has no default and an input matched no case"); }()); diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index d4b7e29f9eb..8132796a3e0 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -5602,19 +5602,20 @@ const char* ExpressionSubtract::getOpName() const { REGISTER_STABLE_EXPRESSION(switch, ExpressionSwitch::parse); Value ExpressionSwitch::evaluate(const Document& root, Variables* variables) const { - for (auto&& branch : _branches) { - Value caseExpression(branch.first->evaluate(root, variables)); + for (int i = 0; i < numBranches(); ++i) { + auto [caseExpr, thenExpr] = getBranch(i); + Value caseResult = caseExpr->evaluate(root, variables); - if (caseExpression.coerceToBool()) { - return branch.second->evaluate(root, variables); + if (caseResult.coerceToBool()) { + return thenExpr->evaluate(root, variables); } } uassert(40066, "$switch could not find a matching branch for an input, and no default was specified.", - _default); + defaultExpr()); - return _default->evaluate(root, variables); + return defaultExpr()->evaluate(root, variables); } boost::intrusive_ptr<Expression> ExpressionSwitch::parse(ExpressionContext* const expCtx, @@ -5623,7 +5624,7 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::parse(ExpressionContext* cons uassert(40060, str::stream() << "$switch requires an object as an argument, found: " << typeName(expr.type()), - expr.type() == Object); + expr.type() == BSONType::Object); 
boost::intrusive_ptr<Expression> expDefault; std::vector<boost::intrusive_ptr<Expression>> children; @@ -5635,13 +5636,13 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::parse(ExpressionContext* cons uassert(40061, str::stream() << "$switch expected an array for 'branches', found: " << typeName(elem.type()), - elem.type() == Array); + elem.type() == BSONType::Array); for (auto&& branch : elem.Array()) { uassert(40062, str::stream() << "$switch expected each branch to be an object, found: " << typeName(branch.type()), - branch.type() == Object); + branch.type() == BSONType::Object); boost::intrusive_ptr<Expression> switchCase, switchThen; @@ -5673,69 +5674,67 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::parse(ExpressionContext* cons uasserted(40067, str::stream() << "$switch found an unknown argument: " << field); } } + + // The 'default' expression is always the final child. If no 'default' expression is + // provided, then the final child is nullptr. children.push_back(std::move(expDefault)); - // Obtain references to the case and branch expressions two-by-two from the children vector, - // ignore the last. - std::vector<ExpressionPair> branches; - boost::optional<boost::intrusive_ptr<Expression>&> first; - for (auto&& child : children) { - if (first) { - branches.emplace_back(*first, child); - first = boost::none; - } else { - first = child; - } - } - uassert(40068, "$switch requires at least one branch.", !branches.empty()); + return new ExpressionSwitch(expCtx, std::move(children)); +} - return new ExpressionSwitch(expCtx, std::move(children), std::move(branches)); +void ExpressionSwitch::deleteBranch(int i) { + invariant(i >= 0); + invariant(i < numBranches()); + // Delete the two elements corresponding to this branch at positions 2i and 2i + 1.
+ _children.erase(std::next(_children.begin(), i * 2), std::next(_children.begin(), i * 2 + 2)); } boost::intrusive_ptr<Expression> ExpressionSwitch::optimize() { - if (_default) { - _default = _default->optimize(); + if (defaultExpr()) { + _children.back() = _children.back()->optimize(); } - std::vector<ExpressionPair>::iterator it = _branches.begin(); - bool true_const = false; + bool trueConst = false; - while (!true_const && it != _branches.end()) { - (it->first) = (it->first)->optimize(); + int i = 0; + while (!trueConst && i < numBranches()) { + boost::intrusive_ptr<Expression>& caseExpr = _children[i * 2]; + boost::intrusive_ptr<Expression>& thenExpr = _children[i * 2 + 1]; + caseExpr = caseExpr->optimize(); - if (auto* val = dynamic_cast<ExpressionConstant*>((it->first).get())) { - if (!((val->getValue()).coerceToBool())) { + if (auto* val = dynamic_cast<ExpressionConstant*>(caseExpr.get())) { + if (!val->getValue().coerceToBool()) { // Case is constant and evaluates to false, so it is removed. - it = _branches.erase(it); + deleteBranch(i); } else { - // Case is constant and true so it is set to default and then removed. - true_const = true; - - // Optimizing this case's then, so that default will remain optimized. - (it->second) = (it->second)->optimize(); - _default = it->second; - it = _branches.erase(it); + // Case optimized to a constant true value. Set the optimized version of the + // corresponding 'then' expression as the new 'default'. Break out of the loop and + // fall through to the logic to remove this and all subsequent branches. + trueConst = true; + _children.back() = thenExpr->optimize(); + break; } } else { // Since case is not removed from the switch, its then is now optimized. - (it->second) = (it->second)->optimize(); - ++it; + thenExpr = thenExpr->optimize(); + ++i; } } // Erasing the rest of the cases because found a default true value. 
- if (true_const) { - _branches.erase(it, _branches.end()); + if (trueConst) { + while (i < numBranches()) { + deleteBranch(i); + } } // If there are no cases, make the switch its default. - if (_branches.size() == 0 && _default) { - return _default; - } else if (_branches.size() == 0) { + if (numBranches() == 0) { uassert(40069, - "One cannot execute a switch statement where all the cases evaluate to false " - "without a default.", - _branches.size()); + "Cannot execute a switch statement where all the cases evaluate to false " + "without a default", + defaultExpr()); + return _children.back(); } return this; @@ -5743,17 +5742,18 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::optimize() { Value ExpressionSwitch::serialize(bool explain) const { std::vector<Value> serializedBranches; - serializedBranches.reserve(_branches.size()); + serializedBranches.reserve(numBranches()); - for (auto&& branch : _branches) { - serializedBranches.push_back(Value(Document{{"case", branch.first->serialize(explain)}, - {"then", branch.second->serialize(explain)}})); + for (int i = 0; i < numBranches(); ++i) { + auto [caseExpr, thenExpr] = getBranch(i); + serializedBranches.push_back(Value(Document{{"case", caseExpr->serialize(explain)}, + {"then", thenExpr->serialize(explain)}})); } - if (_default) { + if (defaultExpr()) { return Value(Document{{"$switch", Document{{"branches", Value(serializedBranches)}, - {"default", _default->serialize(explain)}}}}); + {"default", defaultExpr()->serialize(explain)}}}}); } return Value(Document{{"$switch", Document{{"branches", Value(serializedBranches)}}}}); diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 4680181b295..0c4e077511d 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -3283,11 +3283,10 @@ public: std::pair<boost::intrusive_ptr<Expression>&, boost::intrusive_ptr<Expression>&>; ExpressionSwitch(ExpressionContext* const expCtx, - 
std::vector<boost::intrusive_ptr<Expression>> children, - std::vector<ExpressionPair> branches) - : Expression(expCtx, std::move(children)), - _default(_children.back()), - _branches(std::move(branches)) {} + std::vector<boost::intrusive_ptr<Expression>> children) + : Expression(expCtx, std::move(children)) { + uassert(40068, "$switch requires at least one branch", numBranches() >= 1); + } Value evaluate(const Document& root, Variables* variables) const final; boost::intrusive_ptr<Expression> optimize() final; @@ -3304,9 +3303,35 @@ public: return visitor->visit(this); } + /** + * Returns the number of cases in the switch expression. Each branch is made up of two + * expressions ('case' and 'then'). + */ + int numBranches() const { + return _children.size() / 2; + } + + /** + * Returns a pair of expression pointers representing the 'case' and 'then' expressions for the + * i-th branch of the switch. + */ + std::pair<const Expression*, const Expression*> getBranch(int i) const { + invariant(i >= 0); + invariant(i < numBranches()); + return {_children[i * 2].get(), _children[i * 2 + 1].get()}; + } + + /** + * Returns the 'default' expression, or nullptr if there is no 'default'. + */ + const Expression* defaultExpr() const { + return _children.back().get(); + } + private: - boost::intrusive_ptr<Expression>& _default; - std::vector<ExpressionPair> _branches; + // Helper for 'optimize()'. Deletes the 'case' and 'then' children associated with the i-th + // branch of the switch. 
+ void deleteBranch(int i); }; diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 2a9978e1bba..3e1592b12e1 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -1122,7 +1122,7 @@ TEST(ExpressionSwitch, ExpressionSwitchWithAllConstantFalsesAndNoDefaultErrors) ASSERT_THROWS_CODE(switchExp->optimize(), AssertionException, 40069); } -TEST(ExpressionSwitch, ExpressionSwitchWithZeroAsConstantFalsesAndNoDefaulErrors) { +TEST(ExpressionSwitch, ExpressionSwitchWithZeroAsConstantFalseAndNoDefaultErrors) { auto expCtx = ExpressionContextForTest{}; VariablesParseState vps = expCtx.variablesParseState; @@ -1234,6 +1234,62 @@ TEST(ExpressionSwitch, ExpressionSwitchWithNoConstantsShouldStayTheSame) { ASSERT_BSONOBJ_BINARY_EQ(switchQ, expressionToBson(optimizedStaySame)); } +// This test was designed to provide coverage for SERVER-70190, a bug in which optimizing a $switch +// expression could leave its children vector in a bad state. By walking the tree after optimizing +// we make sure that the expected children are found. 
+TEST(ExpressionSwitch, CaseEliminationShouldLeaveTreeInWalkableState) { + auto expCtx = ExpressionContextForTest{}; + VariablesParseState vps = expCtx.variablesParseState; + + BSONObj switchQ = fromjson(R"( + {$switch: { + branches: [ + {case: false, then: {$const: 0}}, + {case: "$z", then: {$const: 1}}, + {case: "$y", then: {$const: 3}}, + {case: true, then: {$const: 4}}, + {case: "$a", then: {$const: 5}}, + {case: "$b", then: {$const: 6}}, + {case: "$c", then: {$const: 7}} + ], + default: {$const: 8} + }} + )"); + auto switchExp = ExpressionSwitch::parse(&expCtx, switchQ.firstElement(), vps); + auto optimizedExpr = switchExp->optimize(); + + BSONObj optimizedQ = fromjson(R"( + {$switch: { + branches: [ + {case: "$z", then: {$const: 1}}, + {case: "$y", then: {$const: 3}} + ], + default: {$const: 4} + }} + )"); + + ASSERT_BSONOBJ_BINARY_EQ(optimizedQ, expressionToBson(optimizedExpr)); + + // Make sure that the expression tree appears as expected when the children are traversed using + // a for-each loop. + int childNum = 0; + int numConstants = 0; + for (auto&& child : optimizedExpr->getChildren()) { + // Children 0 and 2 are field path expressions, whereas 1, 3, and 4 are constants. + auto constExpr = dynamic_cast<ExpressionConstant*>(child.get()); + if (constExpr) { + ASSERT_VALUE_EQ(constExpr->getValue(), Value{childNum}); + ++numConstants; + } else { + ASSERT(dynamic_cast<ExpressionFieldPath*>(child.get())); + } + ++childNum; + } + // We should have seen 5 children total, 3 of which are constants. 
+ ASSERT_EQ(childNum, 5); + ASSERT_EQ(numConstants, 3); +} + TEST(ExpressionArray, ExpressionArrayShouldOptimizeSubExpressionToExpressionConstant) { auto expCtx = ExpressionContextForTest{}; VariablesParseState vps = expCtx.variablesParseState; @@ -4175,7 +4231,7 @@ TEST(ExpressionEncryptedBetweenTest, ParseRoundTrip) { auto serializedExpr = expr->serialize(false); auto expectedExpr = fromjson(R"({$encryptedBetween: [ {$const: "age"}, - { $const: + { $const: {$binary: { "base64": "ZW5jcnlwdGVkIHBheWxvYWQ=", "subType": "6" |