summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNeil Shweky <neilshweky@gmail.com>2021-10-29 20:49:04 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-11-15 15:45:37 +0000
commitca9bb0300c804617e936c2e2516b441a9474e355 (patch)
tree0dbe82b14166d238bb7375f113b073d5b7662c64
parentfe9052cee98bcff412a06c3e8eaf9e54ea82a14c (diff)
downloadmongo-ca9bb0300c804617e936c2e2516b441a9474e355.tar.gz
SERVER-29425 implement $sortArray in classic engine
-rw-r--r--jstests/aggregation/expressions/sortArray.js193
-rw-r--r--src/mongo/db/exec/document_value/value.cpp6
-rw-r--r--src/mongo/db/exec/document_value/value.h3
-rw-r--r--src/mongo/db/pipeline/SConscript1
-rw-r--r--src/mongo/db/pipeline/expression.cpp102
-rw-r--r--src/mongo/db/pipeline/expression.h36
-rw-r--r--src/mongo/db/pipeline/expression_test.cpp85
-rw-r--r--src/mongo/db/pipeline/expression_visitor.h3
-rw-r--r--src/mongo/db/query/sbe_stage_builder_expression.cpp7
-rw-r--r--src/mongo/db/update/SConscript2
-rw-r--r--src/mongo/db/update/pattern_cmp.h171
-rw-r--r--src/mongo/db/update/pattern_cmp_test.cpp (renamed from src/mongo/db/update/push_sorter_test.cpp)2
-rw-r--r--src/mongo/db/update/push_node.cpp36
-rw-r--r--src/mongo/db/update/push_node.h2
-rw-r--r--src/mongo/db/update/push_sorter.h75
15 files changed, 611 insertions, 113 deletions
diff --git a/jstests/aggregation/expressions/sortArray.js b/jstests/aggregation/expressions/sortArray.js
new file mode 100644
index 00000000000..6f1db4449dd
--- /dev/null
+++ b/jstests/aggregation/expressions/sortArray.js
@@ -0,0 +1,193 @@
+// SERVER-29425 added a new expression, $sortArray, which consumes an array or a nullish value
+// and produces either the sorted version of that array, or null. In this test file, we check the
+// behavior and error cases.
+load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs.
+load("jstests/aggregation/extras/utils.js"); // For assertErrorCode.
+
+(function() {
+"use strict";
+
+let coll = db.sortArray;
+coll.drop();
+
+assert.commandWorked(coll.insert({
+ nullField: null,
+ undefField: undefined,
+ embedded: [[1, 2], [3, 4]],
+ singleElem: [1],
+ normal: [1, 2, 3],
+ num: 1,
+ empty: [],
+ normalSingleObjs: [{a: 1}, {a: 2}, {a: 3}],
+ mismatchedSingleObjs: [{a: 1}, {b: 2}, {c: 3}],
+ normalMultiObjs: [{a: 1, b: 3, c: 1}, {a: 2, b: 2, c: 2}, {a: 3, b: 1, c: 3}],
+ tiesMultiObjs: [{a: 1, b: 2, c: 1}, {a: 1, b: 3, c: 4}, {a: 1, b: 3, c: 5}],
+ nestedObjs: [{a: 1, b: {c: 2}}, {a: 1, b: {c: 1}}],
+ mismatchedTypes: [1, [1], {a: 1}, "1"],
+ moreMismatchedTypes: [2, 1, "hello", {a: 6}, {a: "hello"}, {a: -1}, null],
+ mismatchedNumberTypes: [[NumberDecimal(4)], [1, 9, 8]],
+
+}));
+
+let assertDBOutputEquals = (expected, output) => {
+ output = output.toArray();
+ assert.eq(1, output.length);
+ assert.eq(expected, output[0].sorted);
+};
+
+const isSortArrayEnabled =
+ db.adminCommand({getParameter: 1, featureFlagSortArray: 1}).featureFlagSortArray.value;
+
+if (!isSortArrayEnabled) {
+ // Verify that $sortArray cannot be used if the feature flag is set to false and ignore the
+ // rest of the test.
+ assert.commandFailedWithCode(coll.runCommand("aggregate", {
+ pipeline: [{$project: {sorted: {$sortArray: {sortBy: 1, input: {$literal: [1, 2]}}}}}],
+ cursor: {}
+ }),
+ 31325);
+ return;
+}
+
+assertErrorCode(coll, [{$project: {sorted: {$sortArray: 1}}}], 2942500);
+assertErrorCode(coll, [{$project: {sorted: {$sortArray: "$num"}}}], 2942500);
+
+assertDBOutputEquals([1, 2, 3], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: {$literal: [1, 2, 3]}, sortBy: 1}}}}
+]));
+
+assertDBOutputEquals(
+ [3, 2, 1],
+ coll.aggregate([{$project: {sorted: {$sortArray: {input: "$normal", sortBy: -1}}}}]));
+
+assertDBOutputEquals([1, 2, 3], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: {$literal: [3, 2, 1]}, sortBy: 1}}}}
+]));
+
+assertDBOutputEquals([3, 2, 1], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: {$literal: [3, 2, 1]}, sortBy: -1}}}}
+]));
+
+assertDBOutputEquals(
+ null, coll.aggregate([{$project: {sorted: {$sortArray: {input: "$notAField", sortBy: 1}}}}]));
+
+assertDBOutputEquals([[1, 2], [3, 4]], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: {$literal: [[1, 2], [3, 4]]}, sortBy: 1}}}}
+]));
+
+assertDBOutputEquals(
+ [[3, 4], [1, 2]],
+ coll.aggregate([{$project: {sorted: {$sortArray: {input: "$embedded", sortBy: -1}}}}]));
+
+assertDBOutputEquals(
+ null,
+ coll.aggregate([{$project: {sorted: {$sortArray: {input: {$literal: null}, sortBy: 1}}}}]));
+
+assertDBOutputEquals(
+ null, coll.aggregate([{$project: {sorted: {$sortArray: {input: "$nullField", sortBy: -1}}}}]));
+
+assertDBOutputEquals(null, coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: {$literal: undefined}, sortBy: 1}}}}
+]));
+
+assertDBOutputEquals(
+ null, coll.aggregate([{$project: {sorted: {$sortArray: {input: "$undefField", sortBy: -1}}}}]));
+
+assertDBOutputEquals(
+ [1], coll.aggregate([{$project: {sorted: {$sortArray: {input: {$literal: [1]}, sortBy: 1}}}}]));
+
+assertDBOutputEquals(
+ [1], coll.aggregate([{$project: {sorted: {$sortArray: {input: "$singleElem", sortBy: -1}}}}]));
+
+assertDBOutputEquals(
+ [], coll.aggregate([{$project: {sorted: {$sortArray: {input: {$literal: []}, sortBy: 1}}}}]));
+
+assertDBOutputEquals(
+ [], coll.aggregate([{$project: {sorted: {$sortArray: {input: "$empty", sortBy: -1}}}}]));
+
+/* ------------------------ Object Array Tests ------------------------ */
+
+assertDBOutputEquals([{a: 1}, {a: 2}, {a: 3}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: {a: 1}}}}}
+]));
+
+assertDBOutputEquals([{a: 3}, {a: 2}, {a: 1}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: {a: -1}}}}}
+]));
+
+assertDBOutputEquals([{a: 1}, {a: 2}, {a: 3}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: {b: 1}}}}}
+]));
+
+assertDBOutputEquals([{a: 1}, {a: 2}, {a: 3}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: {b: -1}}}}}
+]));
+
+assertDBOutputEquals(
+ [{a: 1}, {a: 2}, {a: 3}],
+ coll.aggregate([{$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: 1}}}}]));
+
+assertDBOutputEquals(
+ [{a: 3}, {a: 2}, {a: 1}],
+ coll.aggregate([{$project: {sorted: {$sortArray: {input: "$normalSingleObjs", sortBy: -1}}}}]));
+
+assertDBOutputEquals([{a: 1}, {b: 2}, {c: 3}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$mismatchedSingleObjs", sortBy: 1}}}}
+]));
+
+assertDBOutputEquals([{c: 3}, {b: 2}, {a: 1}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$mismatchedSingleObjs", sortBy: -1}}}}
+]));
+
+assertDBOutputEquals([{b: 2}, {c: 3}, {a: 1}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$mismatchedSingleObjs", sortBy: {a: 1}}}}}
+]));
+
+assertDBOutputEquals([{a: 1}, {c: 3}, {b: 2}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$mismatchedSingleObjs", sortBy: {b: 1}}}}}
+]));
+
+assertDBOutputEquals([{a: 1}, {b: 2}, {c: 3}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$mismatchedSingleObjs", sortBy: {c: 1}}}}}
+]));
+
+assertDBOutputEquals([{a: 3, b: 1, c: 3}, {a: 2, b: 2, c: 2}, {a: 1, b: 3, c: 1}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$normalMultiObjs", sortBy: {b: 1, a: 1}}}}}
+]));
+
+assertDBOutputEquals([{a: 1, b: 3, c: 1}, {a: 2, b: 2, c: 2}, {a: 3, b: 1, c: 3}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$normalMultiObjs", sortBy: {b: -1, a: 1}}}}}
+]));
+
+assertDBOutputEquals([{a: 1, b: 2, c: 1}, {a: 1, b: 3, c: 4}, {a: 1, b: 3, c: 5}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$tiesMultiObjs", sortBy: {a: 1, b: 1, c: 1}}}}}
+]));
+
+assertDBOutputEquals([{a: 1, b: 2, c: 1}, {a: 1, b: 3, c: 5}, {a: 1, b: 3, c: 4}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$tiesMultiObjs", sortBy: {a: 1, b: 1, c: -1}}}}}
+]));
+
+/* ------------------------ Nested Objects Tests ------------------------ */
+
+assertDBOutputEquals([{a: 1, b: {c: 1}}, {a: 1, b: {c: 2}}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$nestedObjs", sortBy: {"b.c": 1}}}}}
+]));
+
+assertDBOutputEquals([{a: 1, b: {c: 2}}, {a: 1, b: {c: 1}}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$nestedObjs", sortBy: {"b.c": -1}}}}}
+]));
+
+/* ------------------------ Mismatched Types Tests ------------------------ */
+
+assertDBOutputEquals(
+ [1, "1", {a: 1}, [1]],
+ coll.aggregate([{$project: {sorted: {$sortArray: {input: "$mismatchedTypes", sortBy: 1}}}}]));
+
+assertDBOutputEquals([null, 1, 2, "hello", {a: -1}, {a: 6}, {a: "hello"}], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$moreMismatchedTypes", sortBy: 1}}}}
+]));
+
+assertDBOutputEquals([[1, 9, 8], [NumberDecimal(4)]], coll.aggregate([
+ {$project: {sorted: {$sortArray: {input: "$mismatchedNumberTypes", sortBy: 1}}}}
+]));
+}());
diff --git a/src/mongo/db/exec/document_value/value.cpp b/src/mongo/db/exec/document_value/value.cpp
index 08409bd2461..248514180f0 100644
--- a/src/mongo/db/exec/document_value/value.cpp
+++ b/src/mongo/db/exec/document_value/value.cpp
@@ -1446,4 +1446,10 @@ Value Value::deserializeForIDL(const BSONElement& element) {
return Value(element);
}
+BSONObj Value::wrap(StringData newName) const {
+ BSONObjBuilder b(getApproximateSize() + 6 + newName.size());
+ addToBsonObj(&b, newName);
+ return b.obj();
+}
+
} // namespace mongo
diff --git a/src/mongo/db/exec/document_value/value.h b/src/mongo/db/exec/document_value/value.h
index 42c7b1c0ba9..bf54221326c 100644
--- a/src/mongo/db/exec/document_value/value.h
+++ b/src/mongo/db/exec/document_value/value.h
@@ -394,6 +394,9 @@ public:
void serializeForIDL(BSONArrayBuilder* builder) const;
static Value deserializeForIDL(const BSONElement& element);
+ // Wrap a value in a BSONObj.
+ BSONObj wrap(StringData newName) const;
+
private:
explicit Value(const ValueStorage& storage) : _storage(storage) {}
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index bbf02c42a2e..426269e759a 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -106,6 +106,7 @@ env.Library(
'variables.cpp',
],
LIBDEPS=[
+ '$BUILD_DIR/mongo/db/bson/dotted_path_support',
'$BUILD_DIR/mongo/db/commands/test_commands_enabled',
'$BUILD_DIR/mongo/db/exec/document_value/document_value',
'$BUILD_DIR/mongo/db/query/collation/collator_factory_interface',
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index d34784b1b34..f92c74cfbc2 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -4548,6 +4548,108 @@ ValueUnorderedSet arrayToUnorderedSet(const Value& val, const ValueComparator& v
}
} // namespace
+/* ------------------------ ExpressionSortArray ------------------------ */
+
+namespace {
+
+BSONObj createSortSpecObject(const BSONElement& sortClause) {
+ if (sortClause.type() == BSONType::Object) {
+ auto status = pattern_cmp::checkSortClause(sortClause.embeddedObject());
+ uassert(2942505, status.toString(), status.isOK());
+
+ return sortClause.embeddedObject();
+ } else if (sortClause.isNumber()) {
+ double orderVal = sortClause.Number();
+ uassert(2942506,
+ "The $sort element value must be either 1 or -1",
+ orderVal == -1 || orderVal == 1);
+
+ return BSON("" << orderVal);
+ } else {
+ uasserted(2942507,
+ "The $sort is invalid: use 1/-1 to sort the whole element, or {field:1/-1} to "
+ "sort embedded fields");
+ }
+}
+
+} // namespace
+
+intrusive_ptr<Expression> ExpressionSortArray::parse(ExpressionContext* const expCtx,
+ BSONElement expr,
+ const VariablesParseState& vps) {
+ uassert(2942500,
+ str::stream() << "$sortArray requires an object as an argument, found: "
+ << typeName(expr.type()),
+ expr.type() == Object);
+
+ boost::intrusive_ptr<Expression> input;
+ boost::optional<PatternValueCmp> sortBy;
+ for (auto&& elem : expr.Obj()) {
+ auto field = elem.fieldNameStringData();
+
+ if (field == "input") {
+ input = parseOperand(expCtx, elem, vps);
+ } else if (field == "sortBy") {
+ sortBy = PatternValueCmp(createSortSpecObject(elem), elem, expCtx->getCollator());
+ } else {
+ uasserted(2942501, str::stream() << "$sortArray found an unknown argument: " << field);
+ }
+ }
+
+ uassert(2942502, "$sortArray requires 'input' to be specified", input);
+ uassert(2942503, "$sortArray requires 'sortBy' to be specified", sortBy != boost::none);
+
+ return new ExpressionSortArray(expCtx, std::move(input), *sortBy);
+}
+
+Value ExpressionSortArray::evaluate(const Document& root, Variables* variables) const {
+ Value input(_input->evaluate(root, variables));
+
+ if (input.nullish()) {
+ return Value(BSONNULL);
+ }
+
+ uassert(2942504,
+ str::stream() << "The input argument to $sortArray must be an array, but was of type: "
+ << typeName(input.getType()),
+ input.isArray());
+
+ if (input.getArrayLength() < 2) {
+ return input;
+ }
+
+ std::vector<Value> array = input.getArray();
+ std::sort(array.begin(), array.end(), _sortBy);
+ return Value(array);
+}
+
+// TODO: SERVER-60207 change this when we enable the feature flag by default.
+REGISTER_EXPRESSION_CONDITIONALLY(sortArray,
+ ExpressionSortArray::parse,
+ AllowedWithApiStrict::kNeverInVersion1,
+ AllowedWithClientType::kAny,
+ boost::none,
+ feature_flags::gFeatureFlagSortArray.isEnabledAndIgnoreFCV());
+
+const char* ExpressionSortArray::getOpName() const {
+ return kName.rawData();
+}
+
+intrusive_ptr<Expression> ExpressionSortArray::optimize() {
+ _input = _input->optimize();
+ return this;
+}
+
+void ExpressionSortArray::_doAddDependencies(DepsTracker* deps) const {
+ _input->addDependencies(deps);
+}
+
+Value ExpressionSortArray::serialize(bool explain) const {
+ return Value(Document{{kName,
+ Document{{"input", _input->serialize(explain)},
+ {"sortBy", _sortBy.getOriginalElement()}}}});
+}
+
/* ----------------------- ExpressionSetDifference ---------------------------- */
Value ExpressionSetDifference::evaluate(const Document& root, Variables* variables) const {
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index 9bbca5ce58b..f7d751a79fd 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -54,6 +54,7 @@
#include "mongo/db/query/query_feature_flags_gen.h"
#include "mongo/db/query/sort_pattern.h"
#include "mongo/db/server_options.h"
+#include "mongo/db/update/pattern_cmp.h"
#include "mongo/util/intrusive_counter.h"
#include "mongo/util/str.h"
@@ -2889,6 +2890,41 @@ public:
}
};
+class ExpressionSortArray final : public Expression {
+public:
+ static constexpr auto kName = "$sortArray"_sd;
+ ExpressionSortArray(ExpressionContext* const expCtx,
+ boost::intrusive_ptr<Expression> input,
+ const PatternValueCmp& sortBy)
+ : Expression(expCtx, {std::move(input)}), _input(_children[0]), _sortBy(sortBy) {
+ expCtx->sbeCompatible = false;
+ }
+
+ Value evaluate(const Document& root, Variables* variables) const final;
+ boost::intrusive_ptr<Expression> optimize() final;
+ static boost::intrusive_ptr<Expression> parse(ExpressionContext* expCtx,
+ BSONElement expr,
+ const VariablesParseState& vps);
+ Value serialize(bool explain) const final;
+
+ void acceptVisitor(ExpressionMutableVisitor* visitor) final {
+ return visitor->visit(this);
+ }
+
+ void acceptVisitor(ExpressionConstVisitor* visitor) const final {
+ return visitor->visit(this);
+ }
+
+ const char* getOpName() const;
+
+protected:
+ void _doAddDependencies(DepsTracker* deps) const final;
+
+private:
+ boost::intrusive_ptr<Expression>& _input;
+
+ PatternValueCmp _sortBy;
+};
class ExpressionSlice final : public ExpressionRangedArity<ExpressionSlice, 2, 3> {
public:
diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp
index 3f6b6c73438..670326be083 100644
--- a/src/mongo/db/pipeline/expression_test.cpp
+++ b/src/mongo/db/pipeline/expression_test.cpp
@@ -262,6 +262,91 @@ TEST(ExpressionReverseArrayTest, ReturnsNullWithNullishInput) {
{{{Value(BSONNULL)}, Value(BSONNULL)}, {{Value(BSONUndefined)}, Value(BSONNULL)}});
}
+/* ------------------------ ExpressionSortArray -------------------- */
+
+TEST(ExpressionSortArrayTest, SortsNormalArrayForwards) {
+ RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true};
+
+ auto expCtx = ExpressionContextForTest{};
+ BSONObj expr = fromjson("{ $sortArray: { input: { $literal: [ 2, 1, 3 ] }, sortBy: 1 } }");
+
+ auto expressionSortArray =
+ ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState);
+ Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables);
+
+ ASSERT_EQ(val.getType(), BSONType::Array);
+ ASSERT_VALUE_EQ(val, Value(BSON_ARRAY(1 << 2 << 3)));
+}
+
+
+TEST(ExpressionSortArrayTest, SortsNormalArrayBackwards) {
+ RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true};
+
+ auto expCtx = ExpressionContextForTest{};
+ BSONObj expr = fromjson("{ $sortArray: { input: { $literal: [ 2, 1, 3 ] }, sortBy: -1 } }");
+
+ auto expressionSortArray =
+ ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState);
+ Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables);
+
+ ASSERT_EQ(val.getType(), BSONType::Array);
+ ASSERT_VALUE_EQ(val, Value(BSON_ARRAY(3 << 2 << 1)));
+}
+
+TEST(ExpressionSortArrayTest, SortsEmptyArray) {
+ RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true};
+
+ auto expCtx = ExpressionContextForTest{};
+ BSONObj expr = fromjson("{ $sortArray: { input: { $literal: [ ] }, sortBy: -1 } }");
+
+ auto expressionSortArray =
+ ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState);
+ Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables);
+
+ ASSERT_EQ(val.getType(), BSONType::Array);
+ ASSERT_VALUE_EQ(val, Value(std::vector<Value>()));
+}
+
+TEST(ExpressionSortArrayTest, SortsOneElementArray) {
+ RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true};
+
+ auto expCtx = ExpressionContextForTest{};
+ BSONObj expr = fromjson("{ $sortArray: { input: { $literal: [ 1 ] }, sortBy: -1 } }");
+
+ auto expressionSortArray =
+ ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState);
+ Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables);
+
+ ASSERT_EQ(val.getType(), BSONType::Array);
+ ASSERT_VALUE_EQ(val, Value(BSON_ARRAY(1)));
+}
+
+TEST(ExpressionSortArrayTest, ReturnsNullWithNullInput) {
+ RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true};
+
+ auto expCtx = ExpressionContextForTest{};
+ BSONObj expr = fromjson("{ $sortArray: { input: { $literal: null }, sortBy: -1 } }");
+
+ auto expressionSortArray =
+ ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState);
+ Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables);
+
+ ASSERT_VALUE_EQ(val, Value(BSONNULL));
+}
+
+TEST(ExpressionSortArrayTest, ReturnsNullWithUndefinedInput) {
+ RAIIServerParameterControllerForTest _controller{"featureFlagSortArray", true};
+
+ auto expCtx = ExpressionContextForTest{};
+ BSONObj expr = fromjson("{ $sortArray: { input: { $literal: undefined }, sortBy: -1 } }");
+
+ auto expressionSortArray =
+ ExpressionSortArray::parse(&expCtx, expr.firstElement(), expCtx.variablesParseState);
+ Value val = expressionSortArray->evaluate(MutableDocument().freeze(), &expCtx.variables);
+
+ ASSERT_VALUE_EQ(val, Value(BSONNULL));
+}
+
/* ------------------------- Old-style tests -------------------------- */
namespace Add {
diff --git a/src/mongo/db/pipeline/expression_visitor.h b/src/mongo/db/pipeline/expression_visitor.h
index 7db0d760ee6..7961c549c16 100644
--- a/src/mongo/db/pipeline/expression_visitor.h
+++ b/src/mongo/db/pipeline/expression_visitor.h
@@ -104,6 +104,7 @@ class ExpressionSetIsSubset;
class ExpressionSetUnion;
class ExpressionSize;
class ExpressionReverseArray;
+class ExpressionSortArray;
class ExpressionSlice;
class ExpressionIsArray;
class ExpressionRandom;
@@ -262,6 +263,7 @@ public:
virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionSetUnion>) = 0;
virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionSize>) = 0;
virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionReverseArray>) = 0;
+ virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionSortArray>) = 0;
virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionSlice>) = 0;
virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionIsArray>) = 0;
virtual void visit(expression_walker::MaybeConstPtr<IsConst, ExpressionRandom>) = 0;
@@ -438,6 +440,7 @@ struct SelectiveConstExpressionVisitorBase : public ExpressionConstVisitor {
void visit(const ExpressionSetUnion*) override {}
void visit(const ExpressionSize*) override {}
void visit(const ExpressionReverseArray*) override {}
+ void visit(const ExpressionSortArray*) override {}
void visit(const ExpressionSlice*) override {}
void visit(const ExpressionIsArray*) override {}
void visit(const ExpressionRound*) override {}
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 5a7d2384eb5..46554b8c094 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -387,6 +387,7 @@ public:
void visit(const ExpressionSetUnion* expr) final {}
void visit(const ExpressionSize* expr) final {}
void visit(const ExpressionReverseArray* expr) final {}
+ void visit(const ExpressionSortArray* expr) final {}
void visit(const ExpressionSlice* expr) final {}
void visit(const ExpressionIsArray* expr) final {}
void visit(const ExpressionRound* expr) final {}
@@ -617,6 +618,7 @@ public:
void visit(const ExpressionSetUnion* expr) final {}
void visit(const ExpressionSize* expr) final {}
void visit(const ExpressionReverseArray* expr) final {}
+ void visit(const ExpressionSortArray* expr) final {}
void visit(const ExpressionSlice* expr) final {}
void visit(const ExpressionIsArray* expr) final {}
void visit(const ExpressionRound* expr) final {}
@@ -2572,6 +2574,11 @@ public:
_context->pushExpr(
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(exprRevArr)));
}
+
+ void visit(const ExpressionSortArray* expr) final {
+ unsupportedExpression(expr->getOpName());
+ }
+
void visit(const ExpressionSlice* expr) final {
unsupportedExpression(expr->getOpName());
}
diff --git a/src/mongo/db/update/SConscript b/src/mongo/db/update/SConscript
index a8f0bd1dc2b..3e0eb55927f 100644
--- a/src/mongo/db/update/SConscript
+++ b/src/mongo/db/update/SConscript
@@ -133,7 +133,7 @@ env.CppUnitTest(
'pull_node_test.cpp',
'pullall_node_test.cpp',
'push_node_test.cpp',
- 'push_sorter_test.cpp',
+ 'pattern_cmp_test.cpp',
'rename_node_test.cpp',
'set_node_test.cpp',
'unset_node_test.cpp',
diff --git a/src/mongo/db/update/pattern_cmp.h b/src/mongo/db/update/pattern_cmp.h
new file mode 100644
index 00000000000..c612bdec29f
--- /dev/null
+++ b/src/mongo/db/update/pattern_cmp.h
@@ -0,0 +1,171 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/bson/mutable/document.h"
+#include "mongo/bson/mutable/element.h"
+#include "mongo/db/bson/dotted_path_support.h"
+#include "mongo/db/exec/document_value/document.h"
+#include "mongo/db/exec/document_value/value_comparator.h"
+#include "mongo/db/field_ref.h"
+#include "mongo/db/jsobj.h"
+#include "mongo/db/query/collation/collator_interface.h"
+
+namespace mongo {
+
+namespace pattern_cmp {
+
+/**
+ * When we include a sort specification, that object should pass the checks in
+ * this function.
+ *
+ * Checks include:
+ * 1. The sort pattern cannot be empty.
+ * 2. The value of each pattern element is 1 or -1.
+ * 3. The sort field cannot be empty.
+ * 4. If the sort field is a dotted field, it does not have any empty parts.
+ */
+static Status checkSortClause(const BSONObj& sortObject) {
+ if (sortObject.isEmpty()) {
+ return Status(ErrorCodes::BadValue,
+ "The sort pattern is empty when it should be a set of fields.");
+ }
+
+ for (auto&& patternElement : sortObject) {
+ double orderVal = patternElement.isNumber() ? patternElement.Number() : 0;
+ if (orderVal != -1 && orderVal != 1) {
+ return Status(ErrorCodes::BadValue, "The sort element value must be either 1 or -1");
+ }
+
+ FieldRef sortField(patternElement.fieldName());
+ if (sortField.numParts() == 0) {
+ return Status(ErrorCodes::BadValue, "The sort field cannot be empty");
+ }
+
+ for (size_t i = 0; i < sortField.numParts(); ++i) {
+ if (sortField.getPart(i).size() == 0) {
+ return Status(ErrorCodes::BadValue,
+ str::stream() << "The sort field is a dotted field "
+ "but has an empty part: "
+ << sortField.dottedField());
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+} // namespace pattern_cmp
+
+// Extracts the value for 'pattern' for both 'lhs' and 'rhs' and return true if 'lhs' <
+// 'rhs'. We expect that both 'lhs' and 'rhs' be key patterns.
+class PatternElementCmp {
+public:
+ PatternElementCmp() = default;
+
+ PatternElementCmp(const BSONObj& pattern, const CollatorInterface* collator)
+ : sortPattern(pattern.copy()),
+ useWholeValue(sortPattern.hasField("")),
+ collator(collator) {}
+
+ bool operator()(const mutablebson::Element& lhs, const mutablebson::Element& rhs) const {
+ namespace dps = ::mongo::dotted_path_support;
+ if (useWholeValue) {
+ const int comparedValue = lhs.compareWithElement(rhs, collator, false);
+
+ const bool reversed = (sortPattern.firstElement().number() < 0);
+
+ return (reversed ? comparedValue > 0 : comparedValue < 0);
+ } else {
+ BSONObj lhsObj =
+ lhs.getType() == Object ? lhs.getValueObject() : lhs.getValue().wrap("");
+ BSONObj rhsObj =
+ rhs.getType() == Object ? rhs.getValueObject() : rhs.getValue().wrap("");
+
+ BSONObj lhsKey = dps::extractElementsBasedOnTemplate(lhsObj, sortPattern, true);
+ BSONObj rhsKey = dps::extractElementsBasedOnTemplate(rhsObj, sortPattern, true);
+
+ return lhsKey.woCompare(rhsKey, sortPattern, false, collator) < 0;
+ }
+ }
+
+ BSONObj sortPattern;
+ bool useWholeValue = true;
+ const CollatorInterface* collator = nullptr;
+};
+
+class PatternValueCmp {
+public:
+ PatternValueCmp() = default;
+
+ PatternValueCmp(const BSONObj& pattern,
+ const BSONElement& originalElement,
+ const CollatorInterface* collator)
+ : sortPattern(pattern.copy()),
+ useWholeValue(sortPattern.hasField("")),
+ originalObj(BSONObj().addField(originalElement).copy()),
+ collator(collator) {}
+
+ bool operator()(const Value& lhs, const Value& rhs) const {
+ namespace dps = ::mongo::dotted_path_support;
+ if (useWholeValue) {
+ const bool ascending = ValueComparator().getLessThan()(lhs, rhs);
+
+ const bool reversed = (sortPattern.firstElement().number() < 0);
+
+ return (reversed ? !ascending : ascending);
+ } else {
+ BSONObj lhsObj = lhs.isObject() ? lhs.getDocument().toBson() : lhs.wrap("");
+ BSONObj rhsObj = rhs.isObject() ? rhs.getDocument().toBson() : rhs.wrap("");
+
+ BSONObj lhsKey = dps::extractElementsBasedOnTemplate(lhsObj, sortPattern, true);
+ BSONObj rhsKey = dps::extractElementsBasedOnTemplate(rhsObj, sortPattern, true);
+
+ return lhsKey.woCompare(rhsKey, sortPattern, false, collator) < 0;
+ }
+ }
+
+ // Returns the original element passed into the PatternValueCmp constructor.
+ BSONElement getOriginalElement() const {
+ return originalObj.firstElement();
+ }
+
+ BSONObj sortPattern;
+ bool useWholeValue = true;
+
+ /**
+ * We store the original element as an object so that we can call the copy() method on it.
+ * This way, the PatternValueCmp class can have its own copy of the object.
+ */
+ BSONObj originalObj;
+ const CollatorInterface* collator = nullptr;
+};
+
+} // namespace mongo
diff --git a/src/mongo/db/update/push_sorter_test.cpp b/src/mongo/db/update/pattern_cmp_test.cpp
index 6dfa3219dd8..a6b0f44aadd 100644
--- a/src/mongo/db/update/push_sorter_test.cpp
+++ b/src/mongo/db/update/pattern_cmp_test.cpp
@@ -27,7 +27,7 @@
* it in the license file.
*/
-#include "mongo/db/update/push_sorter.h"
+#include "mongo/db/update/pattern_cmp.h"
#include "mongo/bson/mutable/algorithm.h"
#include "mongo/bson/mutable/document.h"
diff --git a/src/mongo/db/update/push_node.cpp b/src/mongo/db/update/push_node.cpp
index f428269a924..3a4e7c1705c 100644
--- a/src/mongo/db/update/push_node.cpp
+++ b/src/mongo/db/update/push_node.cpp
@@ -49,40 +49,6 @@ const StringData PushNode::kPositionClauseName = "$position";
namespace {
/**
- * When the $sort clause in a $push modifer is an object, that object should pass the checks in
- * this function.
- */
-Status checkSortClause(const BSONObj& sortObject) {
- if (sortObject.isEmpty()) {
- return Status(ErrorCodes::BadValue,
- "The $sort pattern is empty when it should be a set of fields.");
- }
-
- for (auto&& patternElement : sortObject) {
- double orderVal = patternElement.isNumber() ? patternElement.Number() : 0;
- if (orderVal != -1 && orderVal != 1) {
- return Status(ErrorCodes::BadValue, "The $sort element value must be either 1 or -1");
- }
-
- FieldRef sortField(patternElement.fieldName());
- if (sortField.numParts() == 0) {
- return Status(ErrorCodes::BadValue, "The $sort field cannot be empty");
- }
-
- for (size_t i = 0; i < sortField.numParts(); ++i) {
- if (sortField.getPart(i).size() == 0) {
- return Status(ErrorCodes::BadValue,
- str::stream() << "The $sort field is a dotted field "
- "but has an empty part: "
- << sortField.dottedField());
- }
- }
- }
-
- return Status::OK();
-}
-
-/**
* std::abs(LLONG_MIN) results in undefined behavior on 2's complement systems because the
* absolute value of LLONG_MIN cannot be represented in a 'long long'.
*
@@ -157,7 +123,7 @@ Status PushNode::init(BSONElement modExpr, const boost::intrusive_ptr<Expression
auto sortClause = sortIt->second;
if (sortClause.type() == BSONType::Object) {
- auto status = checkSortClause(sortClause.embeddedObject());
+ auto status = pattern_cmp::checkSortClause(sortClause.embeddedObject());
if (status.isOK()) {
_sort = PatternElementCmp(sortClause.embeddedObject(), expCtx->getCollator());
diff --git a/src/mongo/db/update/push_node.h b/src/mongo/db/update/push_node.h
index 38b4b00bc66..77cc6b5d125 100644
--- a/src/mongo/db/update/push_node.h
+++ b/src/mongo/db/update/push_node.h
@@ -36,7 +36,7 @@
#include "mongo/base/string_data.h"
#include "mongo/db/update/modifier_node.h"
-#include "mongo/db/update/push_sorter.h"
+#include "mongo/db/update/pattern_cmp.h"
namespace mongo {
diff --git a/src/mongo/db/update/push_sorter.h b/src/mongo/db/update/push_sorter.h
deleted file mode 100644
index 139529533ae..00000000000
--- a/src/mongo/db/update/push_sorter.h
+++ /dev/null
@@ -1,75 +0,0 @@
-/**
- * Copyright (C) 2018-present MongoDB, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the Server Side Public License, version 1,
- * as published by MongoDB, Inc.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * Server Side Public License for more details.
- *
- * You should have received a copy of the Server Side Public License
- * along with this program. If not, see
- * <http://www.mongodb.com/licensing/server-side-public-license>.
- *
- * As a special exception, the copyright holders give permission to link the
- * code of portions of this program with the OpenSSL library under certain
- * conditions as described in each individual source file and distribute
- * linked combinations including the program with the OpenSSL library. You
- * must comply with the Server Side Public License in all respects for
- * all of the code used other than as permitted herein. If you modify file(s)
- * with this exception, you may extend this exception to your version of the
- * file(s), but you are not obligated to do so. If you do not wish to do so,
- * delete this exception statement from your version. If you delete this
- * exception statement from all source files in the program, then also delete
- * it in the license file.
- */
-
-#pragma once
-
-#include "mongo/bson/mutable/document.h"
-#include "mongo/bson/mutable/element.h"
-#include "mongo/db/bson/dotted_path_support.h"
-#include "mongo/db/jsobj.h"
-#include "mongo/db/query/collation/collator_interface.h"
-
-namespace mongo {
-
-// Extracts the value for 'pattern' for both 'lhs' and 'rhs' and return true if 'lhs' <
-// 'rhs'. We expect that both 'lhs' and 'rhs' be key patterns.
-struct PatternElementCmp {
- PatternElementCmp() = default;
-
- PatternElementCmp(const BSONObj& pattern, const CollatorInterface* collator)
- : sortPattern(pattern), useWholeValue(pattern.hasField("")), collator(collator) {}
-
- bool operator()(const mutablebson::Element& lhs, const mutablebson::Element& rhs) const {
- namespace dps = ::mongo::dotted_path_support;
- if (useWholeValue) {
- const int comparedValue = lhs.compareWithElement(rhs, collator, false);
-
- const bool reversed = (sortPattern.firstElement().number() < 0);
-
- return (reversed ? comparedValue > 0 : comparedValue < 0);
- } else {
- // TODO: Push on to mutable in the future, and to support non-contiguous Elements.
- BSONObj lhsObj =
- lhs.getType() == Object ? lhs.getValueObject() : lhs.getValue().wrap("");
- BSONObj rhsObj =
- rhs.getType() == Object ? rhs.getValueObject() : rhs.getValue().wrap("");
-
- BSONObj lhsKey = dps::extractElementsBasedOnTemplate(lhsObj, sortPattern, true);
- BSONObj rhsKey = dps::extractElementsBasedOnTemplate(rhsObj, sortPattern, true);
-
- return lhsKey.woCompare(rhsKey, sortPattern, false, collator) < 0;
- }
- }
-
- BSONObj sortPattern;
- bool useWholeValue = true;
- const CollatorInterface* collator = nullptr;
-};
-
-} // namespace mongo