diff options
author | Neil Shweky <neilshweky@gmail.com> | 2021-10-29 20:49:04 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-11-15 15:45:37 +0000 |
commit | ca9bb0300c804617e936c2e2516b441a9474e355 (patch) | |
tree | 0dbe82b14166d238bb7375f113b073d5b7662c64 | |
parent | fe9052cee98bcff412a06c3e8eaf9e54ea82a14c (diff) | |
download | mongo-ca9bb0300c804617e936c2e2516b441a9474e355.tar.gz |
SERVER-29425 implement $sortArray in classic engine
-rw-r--r-- | jstests/aggregation/expressions/sortArray.js | 193 | ||||
-rw-r--r-- | src/mongo/db/exec/document_value/value.cpp | 6 | ||||
-rw-r--r-- | src/mongo/db/exec/document_value/value.h | 3 | ||||
-rw-r--r-- | src/mongo/db/pipeline/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 102 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.h | 36 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_test.cpp | 85 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_visitor.h | 3 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 7 | ||||
-rw-r--r-- | src/mongo/db/update/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/db/update/pattern_cmp.h | 171 | ||||
-rw-r--r-- | src/mongo/db/update/pattern_cmp_test.cpp (renamed from src/mongo/db/update/push_sorter_test.cpp) | 2 | ||||
-rw-r--r-- | src/mongo/db/update/push_node.cpp | 36 | ||||
-rw-r--r-- | src/mongo/db/update/push_node.h | 2 | ||||
-rw-r--r-- | src/mongo/db/update/push_sorter.h | 75 |
15 files changed, 611 insertions, 113 deletions
diff --git a/jstests/aggregation/expressions/sortArray.js b/jstests/aggregation/expressions/sortArray.js new file mode 100644 index 00000000000..6f1db4449dd --- /dev/null +++ b/jstests/aggregation/expressions/sortArray.js @@ -0,0 +1,193 @@ +// SERVER-29425 added a new expression, $sortArray, which consumes an array or a nullish value +// and produces either the sorted version of that array, or null. In this test file, we check the +// behavior and error cases. +load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs. +load("jstests/aggregation/extras/utils.js"); // For assertErrorCode. + +(function() { +"use strict"; + +let coll = db.sortArray; +coll.drop(); + +assert.commandWorked(coll.insert({ + nullField: null, + undefField: undefined, + embedded: [[1, 2], [3, 4]], + singleElem: [1], + normal: [1, 2, 3], + num: 1, + empty: [], + normalSingleObjs: [{a: 1}, {a: 2}, {a: 3}], + mismatchedSingleObjs: [{a: 1}, {b: 2}, {c: 3}], + normalMultiObjs: [{a: 1, b: 3, c: 1}, {a: 2, b: 2, c: 2}, {a: 3, b: 1, c: 3}], + tiesMultiObjs: [{a: 1, b: 2, c: 1}, {a: 1, b: 3, c: 4}, {a: 1, b: 3, c: 5}], + nestedObjs: [{a: 1, b: {c: 2}}, {a: 1, b: {c: 1}}], + mismatchedTypes: [1, [1], {a: 1}, "1"], + moreMismatchedTypes: [2, 1, "hello", {a: 6}, {a: "hello"}, {a: -1}, null], + mismatchedNumberTypes: [[NumberDecimal(4)], [1, 9, 8]], + +})); + +let assertDBOutputEquals = (expected, output) => { + output = output.toArray(); + assert.eq(1, output.length); + assert.eq(expected, output[0].sorted); +}; + +const isSortArrayEnabled = + db.adminCommand({getParameter: 1, featureFlagSortArray: 1}).featureFlagSortArray.value; + +if (!isSortArrayEnabled) { + // Verify that $sortArray cannot be used if the feature flag is set to false and ignore the + // rest of the test. 
+ assert.commandFailedWithCode(coll.runCommand("aggregate", { + pipeline: [{$project: {sorted: {$sortArray: {sortBy: 1, input: {$literal: [1, 2]}}}}}], + cursor: {} + }), + 31325); + return; +} + +assertErrorCode(coll, [{$project: {sorted: {$sortArray: 1}}}], 2942500); +assertErrorCode(coll, [{$project: {sorted: {$sortArray: "$num"}}}], 2942500); + +assertDBOutputEquals([1, 2, 3], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: {$literal: [1, 2, 3]}, sortBy: 1}}}} +])); + +assertDBOutputEquals( + [3, 2, 1], + coll.aggregate([{$project: {sorted: {$sortArray: {input: "$normal", sortBy: -1}}}}])); + +assertDBOutputEquals([1, 2, 3], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: {$literal: [3, 2, 1]}, sortBy: 1}}}} +])); + +assertDBOutputEquals([3, 2, 1], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: {$literal: [3, 2, 1]}, sortBy: -1}}}} +])); + +assertDBOutputEquals( + null, coll.aggregate([{$project: {sorted: {$sortArray: {input: "$notAField", sortBy: 1}}}}])); + +assertDBOutputEquals([[1, 2], [3, 4]], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: {$literal: [[1, 2], [3, 4]]}, sortBy: 1}}}} +])); + +assertDBOutputEquals( + [[3, 4], [1, 2]], + coll.aggregate([{$project: {sorted: {$sortArray: {input: "$embedded", sortBy: -1}}}}])); + +assertDBOutputEquals( + null, + coll.aggregate([{$project: {sorted: {$sortArray: {input: {$literal: null}, sortBy: 1}}}}])); + +assertDBOutputEquals( + null, coll.aggregate([{$project: {sorted: {$sortArray: {input: "$nullField", sortBy: -1}}}}])); + +assertDBOutputEquals(null, coll.aggregate([ + {$project: {sorted: {$sortArray: {input: {$literal: undefined}, sortBy: 1}}}} +])); + +assertDBOutputEquals( + null, coll.aggregate([{$project: {sorted: {$sortArray: {input: "$undefField", sortBy: -1}}}}])); + +assertDBOutputEquals( + [1], coll.aggregate([{$project: {sorted: {$sortArray: {input: {$literal: [1]}, sortBy: 1}}}}])); + +assertDBOutputEquals( + [1], coll.aggregate([{$project: 
{sorted: {$sortArray: {input: "$singleElem", sortBy: -1}}}}])); + +assertDBOutputEquals( + [], coll.aggregate([{$project: {sorted: {$sortArray: {input: {$literal: []}, sortBy: 1}}}}])); + +assertDBOutputEquals( + [], coll.aggregate([{$project: {sorted: {$sortArray: {input: "$empty", sortBy: -1}}}}])); + +/* ------------------------ Object Array Tests ------------------------ */ + +assertDBOutputEquals([{a: 1}, {a: 2}, {a: 3}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: {a: 1}}}}} +])); + +assertDBOutputEquals([{a: 3}, {a: 2}, {a: 1}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: {a: -1}}}}} +])); + +assertDBOutputEquals([{a: 1}, {a: 2}, {a: 3}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: {b: 1}}}}} +])); + +assertDBOutputEquals([{a: 1}, {a: 2}, {a: 3}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: {b: -1}}}}} +])); + +assertDBOutputEquals( + [{a: 1}, {a: 2}, {a: 3}], + coll.aggregate([{$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: 1}}}}])); + +assertDBOutputEquals( + [{a: 3}, {a: 2}, {a: 1}], + coll.aggregate([{$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: -1}}}}])); + +assertDBOutputEquals([{a: 1}, {b: 2}, {c: 3}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$mismatchedSingleObjs", sortBy: 1}}}} +])); + +assertDBOutputEquals([{c: 3}, {b: 2}, {a: 1}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$mismatchedSingleObjs", sortBy: -1}}}} +])); + +assertDBOutputEquals([{b: 2}, {c: 3}, {a: 1}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$mismatchedSingleObjs", sortBy: {a: 1}}}}} +])); + +assertDBOutputEquals([{a: 1}, {c: 3}, {b: 2}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$mismatchedSingleObjs", sortBy: {b: 1}}}}} +])); + +assertDBOutputEquals([{a: 1}, {b: 2}, {c: 3}], 
coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$mismatchedSingleObjs", sortBy: {c: 1}}}}} +])); + +assertDBOutputEquals([{a: 3, b: 1, c: 3}, {a: 2, b: 2, c: 2}, {a: 1, b: 3, c: 1}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$normalMultiObjs", sortBy: {b: 1, a: 1}}}}} +])); + +assertDBOutputEquals([{a: 1, b: 3, c: 1}, {a: 2, b: 2, c: 2}, {a: 3, b: 1, c: 3}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$normalMultiObjs", sortBy: {b: -1, a: 1}}}}} +])); + +assertDBOutputEquals([{a: 1, b: 2, c: 1}, {a: 1, b: 3, c: 4}, {a: 1, b: 3, c: 5}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$tiesMultiObjs", sortBy: {a: 1, b: 1, c: 1}}}}} +])); + +assertDBOutputEquals([{a: 1, b: 2, c: 1}, {a: 1, b: 3, c: 5}, {a: 1, b: 3, c: 4}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$tiesMultiObjs", sortBy: {a: 1, b: 1, c: -1}}}}} +])); + +/* ------------------------ Nested Objects Tests ------------------------ */ + +assertDBOutputEquals([{a: 1, b: {c: 1}}, {a: 1, b: {c: 2}}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$nestedObjs", sortBy: {"b.c": 1}}}}} +])); + +assertDBOutputEquals([{a: 1, b: {c: 2}}, {a: 1, b: {c: 1}}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$nestedObjs", sortBy: {"b.c": -1}}}}} +])); + +/* ------------------------ Mismatched Types Tests ------------------------ */ + +assertDBOutputEquals( + [1, "1", {a: 1}, [1]], + coll.aggregate([{$project: {sorted: {$sortArray: {input: "$mismatchedTypes", sortBy: 1}}}}])); + +assertDBOutputEquals([null, 1, 2, "hello", {a: -1}, {a: 6}, {a: "hello"}], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$moreMismatchedTypes", sortBy: 1}}}} +])); + +assertDBOutputEquals([[1, 9, 8], [NumberDecimal(4)]], coll.aggregate([ + {$project: {sorted: {$sortArray: {input: "$mismatchedNumberTypes", sortBy: 1}}}} +])); +}()); diff --git a/src/mongo/db/exec/document_value/value.cpp 
b/src/mongo/db/exec/document_value/value.cpp index 08409bd2461..248514180f0 100644 --- a/src/mongo/db/exec/document_value/value.cpp +++ b/src/mongo/db/exec/document_value/value.cpp @@ -1446,4 +1446,10 @@ Value Value::deserializeForIDL(const BSONElement& element) { return Value(element); } +BSONObj Value::wrap(StringData newName) const { + BSONObjBuilder b(getApproximateSize() + 6 + newName.size()); + addToBsonObj(&b, newName); + return b.obj(); +} + } // namespace mongo diff --git a/src/mongo/db/exec/document_value/value.h b/src/mongo/db/exec/document_value/value.h index 42c7b1c0ba9..bf54221326c 100644 --- a/src/mongo/db/exec/document_value/value.h +++ b/src/mongo/db/exec/document_value/value.h @@ -394,6 +394,9 @@ public: void serializeForIDL(BSONArrayBuilder* builder) const; static Value deserializeForIDL(const BSONElement& element); + // Wrap a value in a BSONObj. + BSONObj wrap(StringData newName) const; + private: explicit Value(const ValueStorage& storage) : _storage(storage) {} diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index bbf02c42a2e..426269e759a 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -106,6 +106,7 @@ env.Library( 'variables.cpp', ], LIBDEPS=[ + '$BUILD_DIR/mongo/db/bson/dotted_path_support', '$BUILD_DIR/mongo/db/commands/test_commands_enabled', '$BUILD_DIR/mongo/db/exec/document_value/document_value', '$BUILD_DIR/mongo/db/query/collation/collator_factory_interface', diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index d34784b1b34..f92c74cfbc2 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -4548,6 +4548,108 @@ ValueUnorderedSet arrayToUnorderedSet(const Value& val, const ValueComparator& v } } // namespace +/* ------------------------ ExpressionSortArray ------------------------ */ + +namespace { + +BSONObj createSortSpecObject(const BSONElement& sortClause) { + if 
(sortClause.type() == BSONType::Object) {
+        auto status = pattern_cmp::checkSortClause(sortClause.embeddedObject());
+        uassert(2942505, status.toString(), status.isOK());
+
+        return sortClause.embeddedObject();
+    } else if (sortClause.isNumber()) {
+        double orderVal = sortClause.Number();
+        uassert(2942506,
+                "The $sort element value must be either 1 or -1",
+                orderVal == -1 || orderVal == 1);
+
+        return BSON("" << orderVal);
+    } else {
+        uasserted(2942507,
+                  "The $sort is invalid: use 1/-1 to sort the whole element, or {field:1/-1} to "
+                  "sort embedded fields");
+    }
+}
+
+} // namespace
+
+intrusive_ptr<Expression> ExpressionSortArray::parse(ExpressionContext* const expCtx,
+                                                     BSONElement expr,
+                                                     const VariablesParseState& vps) {
+    uassert(2942500,
+            str::stream() << "$sortArray requires an object as an argument, found: "
+                          << typeName(expr.type()),
+            expr.type() == Object);
+
+    boost::intrusive_ptr<Expression> input;
+    boost::optional<PatternValueCmp> sortBy;
+    for (auto&& elem : expr.Obj()) {
+        auto field = elem.fieldNameStringData();
+
+        if (field == "input") {
+            input = parseOperand(expCtx, elem, vps);
+        } else if (field == "sortBy") {
+            sortBy = PatternValueCmp(createSortSpecObject(elem), elem, expCtx->getCollator());
+        } else {
+            uasserted(2942501, str::stream() << "$sortArray found an unknown argument: " << field);
+        }
+    }
+
+    uassert(2942502, "$sortArray requires 'input' to be specified", input);
+    uassert(2942503, "$sortArray requires 'sortBy' to be specified", sortBy != boost::none);
+
+    return new ExpressionSortArray(expCtx, std::move(input), *sortBy);
+}
+
+Value ExpressionSortArray::evaluate(const Document& root, Variables* variables) const {
+    Value input(_input->evaluate(root, variables));
+
+    if (input.nullish()) {
+        return Value(BSONNULL);
+    }
+
+    uassert(2942504,
+            str::stream() << "The input argument to $sortArray must be an array, but was of type: "
+                          << typeName(input.getType()),
+            input.isArray());
+
+    if (input.getArrayLength() < 2) {
+        return input;
+    }
+
+    std::vector<Value> array = input.getArray();
+    std::sort(array.begin(), array.end(), _sortBy);
+    return Value(std::move(array));  // 'array' is a local copy; move it rather than copy again.
+}
+
+// TODO: SERVER-60207 change this when we enable the feature flag by default.
+REGISTER_EXPRESSION_CONDITIONALLY(sortArray,
+                                  ExpressionSortArray::parse,
+                                  AllowedWithApiStrict::kNeverInVersion1,
+                                  AllowedWithClientType::kAny,
+                                  boost::none,
+                                  feature_flags::gFeatureFlagSortArray.isEnabledAndIgnoreFCV());
+
+const char* ExpressionSortArray::getOpName() const {
+    return kName.rawData();
+}
+
+intrusive_ptr<Expression> ExpressionSortArray::optimize() {
+    _input = _input->optimize();
+    return this;
+}
+
+void ExpressionSortArray::_doAddDependencies(DepsTracker* deps) const {
+    _input->addDependencies(deps);
+}
+
+Value ExpressionSortArray::serialize(bool explain) const {
+    return Value(Document{{kName,
+                           Document{{"input", _input->serialize(explain)},
+                                    {"sortBy", _sortBy.getOriginalElement()}}}});
+}
+
 /* ----------------------- ExpressionSetDifference ---------------------------- */
 
 Value ExpressionSetDifference::evaluate(const Document& root, Variables* variables) const {
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index 9bbca5ce58b..f7d751a79fd 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -54,6 +54,7 @@
 #include "mongo/db/query/query_feature_flags_gen.h"
 #include "mongo/db/query/sort_pattern.h"
 #include "mongo/db/server_options.h"
+#include "mongo/db/update/pattern_cmp.h"
 #include "mongo/util/intrusive_counter.h"
 #include "mongo/util/str.h"
 
@@ -2889,6 +2890,41 @@ public:
     }
 };
 
+class ExpressionSortArray final : public Expression {
+public:
+    static constexpr auto kName = "$sortArray"_sd;
+    ExpressionSortArray(ExpressionContext* const expCtx,
+                        boost::intrusive_ptr<Expression> input,
+                        const PatternValueCmp& sortBy)
+        : Expression(expCtx, {std::move(input)}), _input(_children[0]), _sortBy(sortBy) {
+        expCtx->sbeCompatible = false;
+    }
+
+ Value evaluate(const Document& root, Variables* variables) const final; + boost::intrusive_ptr<Expression> optimize() final; + static boost::intrusive_ptr<Expression> parse(ExpressionContext* expCtx, + BSONElement expr, + const VariablesParseState& vps); + Value serialize(bool explain) const final; + + void acceptVisitor(ExpressionMutableVisitor* visitor) final { + return visitor->visit(this); + } + + void acceptVisitor(ExpressionConstVisitor* visitor) const final { + return visitor->visit(this); + } + + const char* getOpName() const; + +protected: + void _doAddDependencies(DepsTracker* deps) const final; + +private: + boost::intrusive_ptr<Expression>& _input; + + PatternValueCmp _sortBy; +}; class ExpressionSlice final : public ExpressionRangedArity<ExpressionSlice, 2, 3> { public: diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 3f6b6c73438..670326be083 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -262,6 +262,91 @@ TEST(ExpressionReverseArrayTest, ReturnsNullWithNullishInput) { {{{Value(BSONNULL)}, Value(BSONNULL)}, {{Value(BSONUndefined)}, Value(BSONNULL)}}); } +/* ------------------------ ExpressionSortArray -------------------- */ + +TEST(ExpressionSortArrayTest, SortsNormalArrayForwards) { + RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true}; + + auto expCtx = ExpressionContextForTest{}; + BSONObj expr = fromjson("{ $sortArray: { input: { $literal: [ 2, 1, 3 ] }, sortBy: 1 } }"); + + auto expressionSortArray = + ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState); + Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables); + + ASSERT_EQ(val.getType(), BSONType::Array); + ASSERT_VALUE_EQ(val, Value(BSON_ARRAY(1 << 2 << 3))); +} + + +TEST(ExpressionSortArrayTest, SortsNormalArrayBackwards) { + RAIIServerParameterControllerForTest 
_controller{"featureFlagSortArray", true}; + + auto expCtx = ExpressionContextForTest{}; + BSONObj expr = fromjson("{ $sortArray: { input: { $literal: [ 2, 1, 3 ] }, sortBy: -1 } }"); + + auto expressionSortArray = + ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState); + Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables); + + ASSERT_EQ(val.getType(), BSONType::Array); + ASSERT_VALUE_EQ(val, Value(BSON_ARRAY(3 << 2 << 1))); +} + +TEST(ExpressionSortArrayTest, SortsEmptyArray) { + RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true}; + + auto expCtx = ExpressionContextForTest{}; + BSONObj expr = fromjson("{ $sortArray: { input: { $literal: [ ] }, sortBy: -1 } }"); + + auto expressionSortArray = + ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState); + Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables); + + ASSERT_EQ(val.getType(), BSONType::Array); + ASSERT_VALUE_EQ(val, Value(std::vector<Value>())); +} + +TEST(ExpressionSortArrayTest, SortsOneElementArray) { + RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true}; + + auto expCtx = ExpressionContextForTest{}; + BSONObj expr = fromjson("{ $sortArray: { input: { $literal: [ 1 ] }, sortBy: -1 } }"); + + auto expressionSortArray = + ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState); + Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables); + + ASSERT_EQ(val.getType(), BSONType::Array); + ASSERT_VALUE_EQ(val, Value(BSON_ARRAY(1))); +} + +TEST(ExpressionSortArrayTest, ReturnsNullWithNullInput) { + RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true}; + + auto expCtx = ExpressionContextForTest{}; + BSONObj expr = fromjson("{ $sortArray: { input: { $literal: null }, sortBy: -1 } }"); + + auto expressionSortArray = + 
ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState); + Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables); + + ASSERT_VALUE_EQ(val, Value(BSONNULL)); +} + +TEST(ExpressionSortArrayTest, ReturnsNullWithUndefinedInput) { + RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true}; + + auto expCtx = ExpressionContextForTest{}; + BSONObj expr = fromjson("{ $sortArray: { input: { $literal: undefined }, sortBy: -1 } }"); + + auto expressionSortArray = + ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState); + Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables); + + ASSERT_VALUE_EQ(val, Value(BSONNULL)); +} + /* ------------------------- Old-style tests -------------------------- */ namespace Add { diff --git a/src/mongo/db/pipeline/expression_visitor.h b/src/mongo/db/pipeline/expression_visitor.h index 7db0d760ee6..7961c549c16 100644 --- a/src/mongo/db/pipeline/expression_visitor.h +++ b/src/mongo/db/pipeline/expression_visitor.h @@ -104,6 +104,7 @@ class ExpressionSetIsSubset; class ExpressionSetUnion; class ExpressionSize; class ExpressionReverseArray; +class ExpressionSortArray; class ExpressionSlice; class ExpressionIsArray; class ExpressionRandom; @@ -262,6 +263,7 @@ public: virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionSetUnion>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionSize>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionReverseArray>) = 0; + virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionSortArray>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionSlice>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionIsArray>) = 0; virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionRandom>) = 0; @@ -438,6 +440,7 @@ struct 
SelectiveConstExpressionVisitorBase : public ExpressionConstVisitor { void visit(const ExpressionSetUnion*) override {} void visit(const ExpressionSize*) override {} void visit(const ExpressionReverseArray*) override {} + void visit(const ExpressionSortArray*) override {} void visit(const ExpressionSlice*) override {} void visit(const ExpressionIsArray*) override {} void visit(const ExpressionRound*) override {} diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 5a7d2384eb5..46554b8c094 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -387,6 +387,7 @@ public: void visit(const ExpressionSetUnion* expr) final {} void visit(const ExpressionSize* expr) final {} void visit(const ExpressionReverseArray* expr) final {} + void visit(const ExpressionSortArray* expr) final {} void visit(const ExpressionSlice* expr) final {} void visit(const ExpressionIsArray* expr) final {} void visit(const ExpressionRound* expr) final {} @@ -617,6 +618,7 @@ public: void visit(const ExpressionSetUnion* expr) final {} void visit(const ExpressionSize* expr) final {} void visit(const ExpressionReverseArray* expr) final {} + void visit(const ExpressionSortArray* expr) final {} void visit(const ExpressionSlice* expr) final {} void visit(const ExpressionIsArray* expr) final {} void visit(const ExpressionRound* expr) final {} @@ -2572,6 +2574,11 @@ public: _context->pushExpr( sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(exprRevArr))); } + + void visit(const ExpressionSortArray* expr) final { + unsupportedExpression(expr->getOpName()); + } + void visit(const ExpressionSlice* expr) final { unsupportedExpression(expr->getOpName()); } diff --git a/src/mongo/db/update/SConscript b/src/mongo/db/update/SConscript index a8f0bd1dc2b..3e0eb55927f 100644 --- a/src/mongo/db/update/SConscript +++ b/src/mongo/db/update/SConscript @@ -133,7 
+133,7 @@ env.CppUnitTest( 'pull_node_test.cpp', 'pullall_node_test.cpp', 'push_node_test.cpp', - 'push_sorter_test.cpp', + 'pattern_cmp_test.cpp', 'rename_node_test.cpp', 'set_node_test.cpp', 'unset_node_test.cpp', diff --git a/src/mongo/db/update/pattern_cmp.h b/src/mongo/db/update/pattern_cmp.h new file mode 100644 index 00000000000..c612bdec29f --- /dev/null +++ b/src/mongo/db/update/pattern_cmp.h @@ -0,0 +1,171 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. 
+ */
+
+#pragma once
+
+#include "mongo/bson/mutable/document.h"
+#include "mongo/bson/mutable/element.h"
+#include "mongo/db/bson/dotted_path_support.h"
+#include "mongo/db/exec/document_value/document.h"
+#include "mongo/db/exec/document_value/value_comparator.h"
+#include "mongo/db/field_ref.h"
+#include "mongo/db/jsobj.h"
+#include "mongo/db/query/collation/collator_interface.h"
+
+namespace mongo {
+
+namespace pattern_cmp {
+
+/**
+ * When we include a sort specification, that object should pass the checks in
+ * this function.
+ *
+ * Checks include:
+ * 1. The sort pattern cannot be empty.
+ * 2. The value of each pattern element is 1 or -1.
+ * 3. The sort field cannot be empty.
+ * 4. If the sort field is a dotted field, it does not have any empty parts.
+ */
+inline Status checkSortClause(const BSONObj& sortObject) {
+    if (sortObject.isEmpty()) {
+        return Status(ErrorCodes::BadValue,
+                      "The sort pattern is empty when it should be a set of fields.");
+    }
+
+    for (auto&& patternElement : sortObject) {
+        double orderVal = patternElement.isNumber() ? patternElement.Number() : 0;
+        if (orderVal != -1 && orderVal != 1) {
+            return Status(ErrorCodes::BadValue, "The sort element value must be either 1 or -1");
+        }
+
+        FieldRef sortField(patternElement.fieldName());
+        if (sortField.numParts() == 0) {
+            return Status(ErrorCodes::BadValue, "The sort field cannot be empty");
+        }
+
+        for (size_t i = 0; i < sortField.numParts(); ++i) {
+            if (sortField.getPart(i).size() == 0) {
+                return Status(ErrorCodes::BadValue,
+                              str::stream() << "The sort field is a dotted field "
+                                               "but has an empty part: "
+                                            << sortField.dottedField());
+            }
+        }
+    }
+
+    return Status::OK();
+}
+
+} // namespace pattern_cmp
+
+// Extracts the value for 'pattern' for both 'lhs' and 'rhs' and returns true if 'lhs' <
+// 'rhs'. We expect that both 'lhs' and 'rhs' be key patterns.
+class PatternElementCmp {
+public:
+    PatternElementCmp() = default;
+
+    PatternElementCmp(const BSONObj& pattern, const CollatorInterface* collator)
+        : sortPattern(pattern.copy()),
+          useWholeValue(sortPattern.hasField("")),
+          collator(collator) {}
+
+    bool operator()(const mutablebson::Element& lhs, const mutablebson::Element& rhs) const {
+        namespace dps = ::mongo::dotted_path_support;
+        if (useWholeValue) {
+            const int comparedValue = lhs.compareWithElement(rhs, collator, false);
+
+            const bool reversed = (sortPattern.firstElement().number() < 0);
+
+            return (reversed ? comparedValue > 0 : comparedValue < 0);
+        } else {
+            BSONObj lhsObj =
+                lhs.getType() == Object ? lhs.getValueObject() : lhs.getValue().wrap("");
+            BSONObj rhsObj =
+                rhs.getType() == Object ? rhs.getValueObject() : rhs.getValue().wrap("");
+
+            BSONObj lhsKey = dps::extractElementsBasedOnTemplate(lhsObj, sortPattern, true);
+            BSONObj rhsKey = dps::extractElementsBasedOnTemplate(rhsObj, sortPattern, true);
+
+            return lhsKey.woCompare(rhsKey, sortPattern, false, collator) < 0;
+        }
+    }
+
+    BSONObj sortPattern;
+    bool useWholeValue = true;
+    const CollatorInterface* collator = nullptr;
+};
+
+class PatternValueCmp {
+public:
+    PatternValueCmp() = default;
+
+    PatternValueCmp(const BSONObj& pattern,
+                    const BSONElement& originalElement,
+                    const CollatorInterface* collator)
+        : sortPattern(pattern.copy()),
+          useWholeValue(sortPattern.hasField("")),
+          originalObj(BSONObj().addField(originalElement).copy()),
+          collator(collator) {}
+
+    bool operator()(const Value& lhs, const Value& rhs) const {
+        namespace dps = ::mongo::dotted_path_support;
+        if (useWholeValue) {
+            // Compare once, honoring the collator, and test the sign as PatternElementCmp does.
+            // Negating a less-than result would order equal values, which std::sort forbids.
+            const int comparedValue = ValueComparator(collator).compare(lhs, rhs);
+            const bool reversed = (sortPattern.firstElement().number() < 0);
+            return (reversed ? comparedValue > 0 : comparedValue < 0);
+        } else {
+            BSONObj lhsObj = lhs.isObject() ? lhs.getDocument().toBson() : lhs.wrap("");
+            BSONObj rhsObj = rhs.isObject() ? rhs.getDocument().toBson() : rhs.wrap("");
+
+            BSONObj lhsKey = dps::extractElementsBasedOnTemplate(lhsObj, sortPattern, true);
+            BSONObj rhsKey = dps::extractElementsBasedOnTemplate(rhsObj, sortPattern, true);
+
+            return lhsKey.woCompare(rhsKey, sortPattern, false, collator) < 0;
+        }
+    }
+
+    // Returns the original element passed into the PatternValueCmp constructor.
+    BSONElement getOriginalElement() const {
+        return originalObj.firstElement();
+    }
+
+    BSONObj sortPattern;
+    bool useWholeValue = true;
+
+    /**
+     * We store the original element as an object so that we can call the copy() method on it.
+     * This way, the PatternValueCmp class can have its own copy of the object.
+     */
+    BSONObj originalObj;
+    const CollatorInterface* collator = nullptr;
+};
+
+} // namespace mongo
diff --git a/src/mongo/db/update/push_sorter_test.cpp b/src/mongo/db/update/pattern_cmp_test.cpp
index 6dfa3219dd8..a6b0f44aadd 100644
--- a/src/mongo/db/update/push_sorter_test.cpp
+++ b/src/mongo/db/update/pattern_cmp_test.cpp
@@ -27,7 +27,7 @@
  * it in the license file.
  */
 
-#include "mongo/db/update/push_sorter.h"
+#include "mongo/db/update/pattern_cmp.h"
 
 #include "mongo/bson/mutable/algorithm.h"
 #include "mongo/bson/mutable/document.h"
diff --git a/src/mongo/db/update/push_node.cpp b/src/mongo/db/update/push_node.cpp
index f428269a924..3a4e7c1705c 100644
--- a/src/mongo/db/update/push_node.cpp
+++ b/src/mongo/db/update/push_node.cpp
@@ -49,40 +49,6 @@ const StringData PushNode::kPositionClauseName = "$position";
 namespace {
 /**
- * When the $sort clause in a $push modifer is an object, that object should pass the checks in
- * this function.
- */
-Status checkSortClause(const BSONObj& sortObject) {
-    if (sortObject.isEmpty()) {
-        return Status(ErrorCodes::BadValue,
-                      "The $sort pattern is empty when it should be a set of fields.");
-    }
-
-    for (auto&& patternElement : sortObject) {
-        double orderVal = patternElement.isNumber() ?
patternElement.Number() : 0; - if (orderVal != -1 && orderVal != 1) { - return Status(ErrorCodes::BadValue, "The $sort element value must be either 1 or -1"); - } - - FieldRef sortField(patternElement.fieldName()); - if (sortField.numParts() == 0) { - return Status(ErrorCodes::BadValue, "The $sort field cannot be empty"); - } - - for (size_t i = 0; i < sortField.numParts(); ++i) { - if (sortField.getPart(i).size() == 0) { - return Status(ErrorCodes::BadValue, - str::stream() << "The $sort field is a dotted field " - "but has an empty part: " - << sortField.dottedField()); - } - } - } - - return Status::OK(); -} - -/** * std::abs(LLONG_MIN) results in undefined behavior on 2's complement systems because the * absolute value of LLONG_MIN cannot be represented in a 'long long'. * @@ -157,7 +123,7 @@ Status PushNode::init(BSONElement modExpr, const boost::intrusive_ptr<Expression auto sortClause = sortIt->second; if (sortClause.type() == BSONType::Object) { - auto status = checkSortClause(sortClause.embeddedObject()); + auto status = pattern_cmp::checkSortClause(sortClause.embeddedObject()); if (status.isOK()) { _sort = PatternElementCmp(sortClause.embeddedObject(), expCtx->getCollator()); diff --git a/src/mongo/db/update/push_node.h b/src/mongo/db/update/push_node.h index 38b4b00bc66..77cc6b5d125 100644 --- a/src/mongo/db/update/push_node.h +++ b/src/mongo/db/update/push_node.h @@ -36,7 +36,7 @@ #include "mongo/base/string_data.h" #include "mongo/db/update/modifier_node.h" -#include "mongo/db/update/push_sorter.h" +#include "mongo/db/update/pattern_cmp.h" namespace mongo { diff --git a/src/mongo/db/update/push_sorter.h b/src/mongo/db/update/push_sorter.h deleted file mode 100644 index 139529533ae..00000000000 --- a/src/mongo/db/update/push_sorter.h +++ /dev/null @@ -1,75 +0,0 @@ -/** - * Copyright (C) 2018-present MongoDB, Inc. 
- * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the Server Side Public License, version 1, - * as published by MongoDB, Inc. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * Server Side Public License for more details. - * - * You should have received a copy of the Server Side Public License - * along with this program. If not, see - * <http://www.mongodb.com/licensing/server-side-public-license>. - * - * As a special exception, the copyright holders give permission to link the - * code of portions of this program with the OpenSSL library under certain - * conditions as described in each individual source file and distribute - * linked combinations including the program with the OpenSSL library. You - * must comply with the Server Side Public License in all respects for - * all of the code used other than as permitted herein. If you modify file(s) - * with this exception, you may extend this exception to your version of the - * file(s), but you are not obligated to do so. If you do not wish to do so, - * delete this exception statement from your version. If you delete this - * exception statement from all source files in the program, then also delete - * it in the license file. - */ - -#pragma once - -#include "mongo/bson/mutable/document.h" -#include "mongo/bson/mutable/element.h" -#include "mongo/db/bson/dotted_path_support.h" -#include "mongo/db/jsobj.h" -#include "mongo/db/query/collation/collator_interface.h" - -namespace mongo { - -// Extracts the value for 'pattern' for both 'lhs' and 'rhs' and return true if 'lhs' < -// 'rhs'. We expect that both 'lhs' and 'rhs' be key patterns. 
-struct PatternElementCmp { - PatternElementCmp() = default; - - PatternElementCmp(const BSONObj& pattern, const CollatorInterface* collator) - : sortPattern(pattern), useWholeValue(pattern.hasField("")), collator(collator) {} - - bool operator()(const mutablebson::Element& lhs, const mutablebson::Element& rhs) const { - namespace dps = ::mongo::dotted_path_support; - if (useWholeValue) { - const int comparedValue = lhs.compareWithElement(rhs, collator, false); - - const bool reversed = (sortPattern.firstElement().number() < 0); - - return (reversed ? comparedValue > 0 : comparedValue < 0); - } else { - // TODO: Push on to mutable in the future, and to support non-contiguous Elements. - BSONObj lhsObj = - lhs.getType() == Object ? lhs.getValueObject() : lhs.getValue().wrap(""); - BSONObj rhsObj = - rhs.getType() == Object ? rhs.getValueObject() : rhs.getValue().wrap(""); - - BSONObj lhsKey = dps::extractElementsBasedOnTemplate(lhsObj, sortPattern, true); - BSONObj rhsKey = dps::extractElementsBasedOnTemplate(rhsObj, sortPattern, true); - - return lhsKey.woCompare(rhsKey, sortPattern, false, collator) < 0; - } - } - - BSONObj sortPattern; - bool useWholeValue = true; - const CollatorInterface* collator = nullptr; -}; - -} // namespace mongo |