summary refs log tree commit diff
diff options
context:
space:
mode:
author Ivan Fefer <ivan.fefer@mongodb.com> 2023-01-09 09:23:09 +0000
committer Evergreen Agent <no-reply@evergreen.mongodb.com> 2023-01-09 10:03:24 +0000
commit a5af803b678126c1ef5fb7a2a59414c9484cc21d (patch)
tree a798412e0b11ceb5cf01f07a992b7ac0638b8515
parent ad0c764b6f5020033092a91164e0187add0deb94 (diff)
download mongo-a5af803b678126c1ef5fb7a2a59414c9484cc21d.tar.gz
SERVER-71581 Convert set expressions to ABT
-rw-r--r-- jstests/aggregation/expressions/set.js 193
-rw-r--r-- jstests/libs/sbe_assert_error_override.js 7
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_expression.cpp 108
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_helpers.cpp 7
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_helpers.h 1
5 files changed, 221 insertions, 95 deletions
diff --git a/jstests/aggregation/expressions/set.js b/jstests/aggregation/expressions/set.js
index 4dbc19b32df..d90ddc20919 100644
--- a/jstests/aggregation/expressions/set.js
+++ b/jstests/aggregation/expressions/set.js
@@ -3,7 +3,8 @@
*/
(function() {
"use strict";
-load("jstests/aggregation/extras/utils.js");
+load("jstests/aggregation/extras/utils.js"); // For assertErrorCode.
+load('jstests/libs/sbe_assert_error_override.js'); // Override error-code-checking APIs.
const coll = db.expression_set;
coll.drop();
@@ -19,84 +20,132 @@ assert.commandWorked(coll.insert([
{_id: 7, arr1: [1, 2, 3], arr2: [1, 1, 2, 2, 3, 3]},
]));
-const result = coll.aggregate([
- {$sort: {_id: 1}},
- {
- $project: {
- union: {$setUnion: ["$arr1", "$arr2"]},
- intersection: {$setIntersection: ["$arr1", "$arr2"]},
- difference: {$setDifference: ["$arr1", "$arr2"]},
- isSubset: {$setIsSubset: ["$arr1", "$arr2"]},
- equals: {$setEquals: ["$arr1", "$arr2"]},
- }
- }
- ])
- .toArray();
-
// The order of the output array elements is undefined for $setUnion, $setDifference and
// $setIntersection expressions. Hence we do a sort operation to get a consistent order.
-const sortSetFields = document => Object.assign(document, {
- union: document.union.sort(),
- intersection: document.intersection.sort(),
- difference: document.difference.sort(),
-});
+const sortSetFields = doc => {  // Builds a new object from doc's own fields, sorting every array-valued field.
+ let result = {};
+ for (const key in doc) {
+ if (doc.hasOwnProperty(key)) {  // for...in also visits inherited keys; copy only the document's own fields.
+ const value = doc[key];
+ result[key] = Array.isArray(value) ? value.sort() : value;  // sort() reorders the array in place with the default comparator; non-array values are copied unchanged.
+ }
+ }
+ return result;
+};
+
+const runTest = function(pipeline, expectedResult) {  // Runs 'pipeline' against 'coll' and asserts that the output equals 'expectedResult'.
+ pipeline.push({$sort: {_id: 1}});  // Appends to the caller's array so that documents come back ordered by _id.
+ const result = coll.aggregate(pipeline).toArray();
+ assert.eq(expectedResult, result.map(sortSetFields));  // Array fields are sorted before comparing because set operator output order is undefined.
+};
+
+runTest(
+ [{$project: {union: {$setUnion: ["$arr1", "$arr2"]}}}],
+ [
+ {_id: 0, union: [1, 2, 3, 4]},
+ {_id: 1, union: [1, 2, 3, 4, 5, 6]},
+ {_id: 2, union: [1, 2, 3]},
+ {_id: 3, union: [4, 5, 6]},
+ {_id: 4, union: [1, 2, 3]},
+ {_id: 5, union: [2, 3, 4]},
+ {_id: 6, union: [1, 2, 3]},
+ {_id: 7, union: [1, 2, 3]},
+ ],
+);
+
+runTest(
+ [{$project: {intersection: {$setIntersection: ["$arr1", "$arr2"]}}}],
+ [
+ {_id: 0, intersection: [2, 3]},
+ {_id: 1, intersection: []},
+ {_id: 2, intersection: []},
+ {_id: 3, intersection: []},
+ {_id: 4, intersection: [2, 3]},
+ {_id: 5, intersection: [2, 3]},
+ {_id: 6, intersection: [1, 2, 3]},
+ {_id: 7, intersection: [1, 2, 3]},
+ ],
+);
-assert.eq(result.map(sortSetFields), [
- {
- _id: 0,
- union: [1, 2, 3, 4],
- intersection: [2, 3],
- difference: [1],
- isSubset: false,
- equals: false
- },
- {
- _id: 1,
- union: [1, 2, 3, 4, 5, 6],
- intersection: [],
- difference: [1, 2, 3],
- isSubset: false,
- equals: false
- },
- {
- _id: 2,
- union: [1, 2, 3],
- intersection: [],
- difference: [1, 2, 3],
- isSubset: false,
- equals: false
- },
- {_id: 3, union: [4, 5, 6], intersection: [], difference: [], isSubset: true, equals: false},
- {
- _id: 4,
- union: [1, 2, 3],
- intersection: [2, 3],
- difference: [1],
- isSubset: false,
- equals: false
- },
- {_id: 5, union: [2, 3, 4], intersection: [2, 3], difference: [], isSubset: true, equals: false},
- {
- _id: 6,
- union: [1, 2, 3],
- intersection: [1, 2, 3],
- difference: [],
- isSubset: true,
- equals: true
- },
- {
- _id: 7,
- union: [1, 2, 3],
- intersection: [1, 2, 3],
- difference: [],
- isSubset: true,
- equals: true
- },
-]);
+runTest(
+ [{$project: {difference: {$setDifference: ["$arr1", "$arr2"]}}}],
+ [
+ {_id: 0, difference: [1]},
+ {_id: 1, difference: [1, 2, 3]},
+ {_id: 2, difference: [1, 2, 3]},
+ {_id: 3, difference: []},
+ {_id: 4, difference: [1]},
+ {_id: 5, difference: []},
+ {_id: 6, difference: []},
+ {_id: 7, difference: []},
+ ],
+);
+
+runTest(
+ [{$project: {difference: {$setDifference: ["$arr2", "$arr1"]}}}],
+ [
+ {_id: 0, difference: [4]},
+ {_id: 1, difference: [4, 5, 6]},
+ {_id: 2, difference: []},
+ {_id: 3, difference: [4, 5, 6]},
+ {_id: 4, difference: []},
+ {_id: 5, difference: [4]},
+ {_id: 6, difference: []},
+ {_id: 7, difference: []},
+ ],
+);
+
+runTest(
+ [{$project: {equals: {$setEquals: ["$arr1", "$arr2"]}}}],
+ [
+ {_id: 0, equals: false},
+ {_id: 1, equals: false},
+ {_id: 2, equals: false},
+ {_id: 3, equals: false},
+ {_id: 4, equals: false},
+ {_id: 5, equals: false},
+ {_id: 6, equals: true},
+ {_id: 7, equals: true},
+ ],
+);
+
+runTest(
+ [{$project: {isSubset: {$setIsSubset: ["$arr1", "$arr2"]}}}],
+ [
+ {_id: 0, isSubset: false},
+ {_id: 1, isSubset: false},
+ {_id: 2, isSubset: false},
+ {_id: 3, isSubset: true},
+ {_id: 4, isSubset: false},
+ {_id: 5, isSubset: true},
+ {_id: 6, isSubset: true},
+ {_id: 7, isSubset: true},
+ ],
+);
// No sets to union should produce an empty set for all records so we only check the first one.
assert.eq(coll.aggregate([{$project: {x: {$setUnion: []}}}]).toArray()[0]['x'], []);
// No sets to intersect should produce an empty set for all records so we only check the first one.
assert.eq(coll.aggregate([{$project: {x: {$setIntersection: []}}}]).toArray()[0]['x'], []);
+
+const operators = [
+ ["$setUnion", 17043],
+ ["$setIntersection", 17047],
+ ["$setDifference", [17048, 17049]],
+ ["$setEquals", 17044],
+ ["$setIsSubset", [17042, 17046]]
+];
+const badDocuments = [
+ {arr1: "123", arr2: [1, 2, 3]},
+ {arr1: [1, 2, 3], arr2: "123"},
+ {arr1: "123", arr2: "123"},
+];
+for (const [operator, errorCodes] of operators) {
+ for (const badDocument of badDocuments) {
+ assert(coll.drop());
+ assert.commandWorked(coll.insertOne(badDocument));
+ assertErrorCode(coll, [{$project: {output: {[operator]: ["$arr1", "$arr2"]}}}], errorCodes);
+ }
+}
}());
diff --git a/jstests/libs/sbe_assert_error_override.js b/jstests/libs/sbe_assert_error_override.js
index 57bc8625176..7bc8e0504e6 100644
--- a/jstests/libs/sbe_assert_error_override.js
+++ b/jstests/libs/sbe_assert_error_override.js
@@ -129,6 +129,13 @@ const equivalentErrorCodesList = [
[5439105, 5439018, 7003906],
[5439106, 5439015, 7003909],
[5439107, 5439016, 7003910],
+ [17042, 5126900, 7158100],
+ [17043, 5126900, 7158100],
+ [17044, 5126900, 7158100],
+ [17046, 5126900, 7158100],
+ [17047, 5126900, 7158100],
+ [17048, 5126900, 7158100],
+ [17049, 5126900, 7158100],
];
// This map is generated based on the contents of 'equivalentErrorCodesList'. This map should _not_
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index 837282c485e..94939ed33c4 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -3871,6 +3871,9 @@ private:
using namespace std::literals;
size_t arity = expr->getChildren().size();
+ if (_context->hasAllAbtEligibleEntries(arity)) {
+ return generateABTSetExpression(expr, setOp);
+ }
_context->ensureArity(arity);
auto frameId = _context->state.frameId();
@@ -3889,27 +3892,8 @@ private:
checkExprsNotArray.reserve(arity);
auto collatorSlot = _context->state.data->env->getSlotIfExists("collator"_sd);
-
- auto [operatorName, setFunctionName] = [setOp, collatorSlot]() {
- switch (setOp) {
- case SetOperation::Difference:
- return std::make_pair("setDifference"_sd,
- collatorSlot ? "collSetDifference"_sd
- : "setDifference"_sd);
- case SetOperation::Intersection:
- return std::make_pair("setIntersection"_sd,
- collatorSlot ? "collSetIntersection"_sd
- : "setIntersection"_sd);
- case SetOperation::Union:
- return std::make_pair("setUnion"_sd,
- collatorSlot ? "collSetUnion"_sd : "setUnion"_sd);
- case SetOperation::Equals:
- return std::make_pair("setEquals"_sd,
- collatorSlot ? "collSetEquals"_sd : "setEquals"_sd);
- default:
- MONGO_UNREACHABLE;
- }
- }();
+ auto [operatorName, setFunctionName] =
+ getSetOperatorAndFunctionNames(setOp, collatorSlot.has_value());
if (collatorSlot) {
argVars.push_back(sbe::makeE<sbe::EVariable>(*collatorSlot));
@@ -3954,6 +3938,88 @@ private:
sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(setExpr)));
}
+ void generateABTSetExpression(const Expression* expr, SetOperation setOp) {  // Builds the ABT form of a set expression from the operands on the context and pushes the result back.
+ using namespace std::literals;
+
+ size_t arity = expr->getChildren().size();
+ _context->ensureArity(arity);
+
+ optimizer::ABTVector args;
+ optimizer::ProjectionNameVector argNames;
+ optimizer::ABTVector variables;
+
+ optimizer::ABTVector checkNulls;
+ optimizer::ABTVector checkNotArrays;
+
+ auto collatorSlot = _context->state.data->env->getSlotIfExists("collator"_sd);
+
+ args.reserve(arity);
+ argNames.reserve(arity);
+ variables.reserve(arity + (collatorSlot.has_value() ? 1 : 0));  // One extra entry for the collator variable when a collator slot exists.
+ checkNulls.reserve(arity);
+ checkNotArrays.reserve(arity);
+
+ auto [operatorName, setFunctionName] =
+ getSetOperatorAndFunctionNames(setOp, collatorSlot.has_value());
+ if (collatorSlot) {  // The collator variable is pushed first, so it precedes the operand variables in the function call.
+ variables.push_back(
+ optimizer::make<optimizer::Variable>(_context->registerVariable(*collatorSlot)));
+ }
+
+ for (size_t idx = 0; idx < arity; ++idx) {
+ args.push_back(_context->popABTExpr());
+ auto argName = makeLocalVariableName(_context->state.frameId(), 0);  // presumably frameId() yields a distinct id per call, giving each operand its own name — confirm.
+ argNames.push_back(argName);
+ variables.push_back(optimizer::make<optimizer::Variable>(argName));
+
+ checkNulls.push_back(generateABTNullOrMissing(argName));
+ checkNotArrays.push_back(generateABTNonArrayCheck(std::move(argName)));
+ }
+ // Reverse the args array to preserve the original order of the arguments, since some set
+ // operations, such as $setDifference, are not commutative.
+ std::reverse(std::begin(args), std::end(args));
+
+ auto checkNullAnyArgument =
+ makeBalancedBooleanOpTree(optimizer::Operations::Or, std::move(checkNulls));
+ auto checkNotArrayAnyArgument =
+ makeBalancedBooleanOpTree(optimizer::Operations::Or, std::move(checkNotArrays));
+ auto setExpr = buildABTMultiBranchConditional(  // Cases in order: any null-or-missing check holds -> null; any non-array check holds -> fail with 7158100; otherwise call the set function.
+ ABTCaseValuePair{std::move(checkNullAnyArgument), optimizer::Constant::null()},
+ ABTCaseValuePair{std::move(checkNotArrayAnyArgument),
+ makeABTFail(ErrorCodes::Error{7158100},
+ str::stream() << "All operands of $" << operatorName
+ << " must be arrays.")},
+ optimizer::make<optimizer::FunctionCall>(setFunctionName.toString(),
+ std::move(variables)));
+
+ for (size_t i = 0; i < arity; ++i) {  // Wrap the conditional in one Let per operand, binding argNames[i] to args[i].
+ setExpr = optimizer::make<optimizer::Let>(
+ std::move(argNames[i]), std::move(args[i]), setExpr);  // NOTE(review): setExpr is passed as an lvalue here (likely copied each iteration) — confirm whether std::move was intended.
+ }
+
+ _context->pushExpr(std::move(setExpr));
+ }
+
+ std::pair<StringData, StringData> getSetOperatorAndFunctionNames(SetOperation setOp,
+ bool hasCollator) const {  // Maps 'setOp' to {operator name, function name}; with a collator the function name is the "collSet..." variant.
+ switch (setOp) {
+ case SetOperation::Difference:
+ return std::make_pair("setDifference"_sd,
+ hasCollator ? "collSetDifference"_sd : "setDifference"_sd);
+ case SetOperation::Intersection:
+ return std::make_pair("setIntersection"_sd,
+ hasCollator ? "collSetIntersection"_sd
+ : "setIntersection"_sd);
+ case SetOperation::Union:
+ return std::make_pair("setUnion"_sd,
+ hasCollator ? "collSetUnion"_sd : "setUnion"_sd);
+ case SetOperation::Equals:
+ return std::make_pair("setEquals"_sd,
+ hasCollator ? "collSetEquals"_sd : "setEquals"_sd);
+ }
+ MONGO_UNREACHABLE;  // Only reached when 'setOp' matches none of the cases above.
+ }
+
/**
* Shared expression building logic for regex expressions.
*/
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
index 762c55d0e79..eb8b614d414 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp
@@ -1413,8 +1413,11 @@ optimizer::ABT generateABTNullOrMissing(optimizer::ProjectionName var) {
}
optimizer::ABT generateABTNonStringCheck(optimizer::ProjectionName var) {
- return makeNot(optimizer::make<optimizer::FunctionCall>(
- "isString", optimizer::ABTVector{optimizer::make<optimizer::Variable>(var)}));
+ return makeNot(makeABTFunction("isString", optimizer::make<optimizer::Variable>(var)));
+}
+
+optimizer::ABT generateABTNonArrayCheck(optimizer::ProjectionName var) {  // Builds an ABT that negates an "isArray" call on the variable named 'var'.
+ return makeNot(makeABTFunction("isArray", optimizer::make<optimizer::Variable>(var)));
+}
optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::ProjectionName var) {
diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h
index 74b27315ae0..dced2412675 100644
--- a/src/mongo/db/query/sbe_stage_builder_helpers.h
+++ b/src/mongo/db/query/sbe_stage_builder_helpers.h
@@ -1249,6 +1249,7 @@ optimizer::ProjectionName makeLocalVariableName(sbe::FrameId frameId, sbe::value
optimizer::ABT generateABTNullOrMissing(optimizer::ProjectionName var);
optimizer::ABT generateABTNonStringCheck(optimizer::ProjectionName var);
+optimizer::ABT generateABTNonArrayCheck(optimizer::ProjectionName var);
optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::ProjectionName var);
optimizer::ABT generateABTNegativeCheck(optimizer::ProjectionName var);
/**