From a5af803b678126c1ef5fb7a2a59414c9484cc21d Mon Sep 17 00:00:00 2001 From: Ivan Fefer Date: Mon, 9 Jan 2023 09:23:09 +0000 Subject: SERVER-71581 Convert set expressions to ABT --- jstests/aggregation/expressions/set.js | 193 +++++++++++++-------- jstests/libs/sbe_assert_error_override.js | 7 + .../db/query/sbe_stage_builder_expression.cpp | 108 +++++++++--- src/mongo/db/query/sbe_stage_builder_helpers.cpp | 7 +- src/mongo/db/query/sbe_stage_builder_helpers.h | 1 + 5 files changed, 221 insertions(+), 95 deletions(-) diff --git a/jstests/aggregation/expressions/set.js b/jstests/aggregation/expressions/set.js index 4dbc19b32df..d90ddc20919 100644 --- a/jstests/aggregation/expressions/set.js +++ b/jstests/aggregation/expressions/set.js @@ -3,7 +3,8 @@ */ (function() { "use strict"; -load("jstests/aggregation/extras/utils.js"); +load("jstests/aggregation/extras/utils.js"); // For assertErrorCode. +load('jstests/libs/sbe_assert_error_override.js'); // Override error-code-checking APIs. const coll = db.expression_set; coll.drop(); @@ -19,84 +20,132 @@ assert.commandWorked(coll.insert([ {_id: 7, arr1: [1, 2, 3], arr2: [1, 1, 2, 2, 3, 3]}, ])); -const result = coll.aggregate([ - {$sort: {_id: 1}}, - { - $project: { - union: {$setUnion: ["$arr1", "$arr2"]}, - intersection: {$setIntersection: ["$arr1", "$arr2"]}, - difference: {$setDifference: ["$arr1", "$arr2"]}, - isSubset: {$setIsSubset: ["$arr1", "$arr2"]}, - equals: {$setEquals: ["$arr1", "$arr2"]}, - } - } - ]) - .toArray(); - // The order of the output array elements is undefined for $setUnion, $setDifference and // $setIntersection expressions. Hence we do a sort operation to get a consistent order. -const sortSetFields = document => Object.assign(document, { - union: document.union.sort(), - intersection: document.intersection.sort(), - difference: document.difference.sort(), -}); +const sortSetFields = doc => { + let result = {}; + for (const key in doc) { + if (doc.hasOwnProperty(key)) { + const value = doc[key]; + result[key] = Array.isArray(value) ? value.sort() : value; + } + } + return result; +}; + +const runTest = function(pipeline, expectedResult) { + pipeline.push({$sort: {_id: 1}}); + const result = coll.aggregate(pipeline).toArray(); + assert.eq(expectedResult, result.map(sortSetFields)); +}; + +runTest( + [{$project: {union: {$setUnion: ["$arr1", "$arr2"]}}}], + [ + {_id: 0, union: [1, 2, 3, 4]}, + {_id: 1, union: [1, 2, 3, 4, 5, 6]}, + {_id: 2, union: [1, 2, 3]}, + {_id: 3, union: [4, 5, 6]}, + {_id: 4, union: [1, 2, 3]}, + {_id: 5, union: [2, 3, 4]}, + {_id: 6, union: [1, 2, 3]}, + {_id: 7, union: [1, 2, 3]}, + ], +); + +runTest( + [{$project: {intersection: {$setIntersection: ["$arr1", "$arr2"]}}}], + [ + {_id: 0, intersection: [2, 3]}, + {_id: 1, intersection: []}, + {_id: 2, intersection: []}, + {_id: 3, intersection: []}, + {_id: 4, intersection: [2, 3]}, + {_id: 5, intersection: [2, 3]}, + {_id: 6, intersection: [1, 2, 3]}, + {_id: 7, intersection: [1, 2, 3]}, + ], +); -assert.eq(result.map(sortSetFields), [ - { - _id: 0, - union: [1, 2, 3, 4], - intersection: [2, 3], - difference: [1], - isSubset: false, - equals: false - }, - { - _id: 1, - union: [1, 2, 3, 4, 5, 6], - intersection: [], - difference: [1, 2, 3], - isSubset: false, - equals: false - }, - { - _id: 2, - union: [1, 2, 3], - intersection: [], - difference: [1, 2, 3], - isSubset: false, - equals: false - }, - {_id: 3, union: [4, 5, 6], intersection: [], difference: [], isSubset: true, equals: false}, - { - _id: 4, - union: [1, 2, 3], - intersection: [2, 3], - difference: [1], - isSubset: false, - equals: false - }, - {_id: 5, union: [2, 3, 4], intersection: [2, 3], difference: [], isSubset: true, equals: false}, - { - _id: 6, - union: [1, 2, 3], - intersection: [1, 2, 3], - difference: [], - isSubset: true, - equals: true - }, - { - _id: 7, - union: [1, 2, 3], - intersection: [1, 2, 3], - difference: [], - isSubset: true, - equals: true - }, -]); +runTest( + [{$project: {difference: {$setDifference: ["$arr1", "$arr2"]}}}], + [ + {_id: 0, difference: [1]}, + {_id: 1, difference: [1, 2, 3]}, + {_id: 2, difference: [1, 2, 3]}, + {_id: 3, difference: []}, + {_id: 4, difference: [1]}, + {_id: 5, difference: []}, + {_id: 6, difference: []}, + {_id: 7, difference: []}, + ], +); + +runTest( + [{$project: {difference: {$setDifference: ["$arr2", "$arr1"]}}}], + [ + {_id: 0, difference: [4]}, + {_id: 1, difference: [4, 5, 6]}, + {_id: 2, difference: []}, + {_id: 3, difference: [4, 5, 6]}, + {_id: 4, difference: []}, + {_id: 5, difference: [4]}, + {_id: 6, difference: []}, + {_id: 7, difference: []}, + ], +); + +runTest( + [{$project: {equals: {$setEquals: ["$arr1", "$arr2"]}}}], + [ + {_id: 0, equals: false}, + {_id: 1, equals: false}, + {_id: 2, equals: false}, + {_id: 3, equals: false}, + {_id: 4, equals: false}, + {_id: 5, equals: false}, + {_id: 6, equals: true}, + {_id: 7, equals: true}, + ], +); + +runTest( + [{$project: {isSubset: {$setIsSubset: ["$arr1", "$arr2"]}}}], + [ + {_id: 0, isSubset: false}, + {_id: 1, isSubset: false}, + {_id: 2, isSubset: false}, + {_id: 3, isSubset: true}, + {_id: 4, isSubset: false}, + {_id: 5, isSubset: true}, + {_id: 6, isSubset: true}, + {_id: 7, isSubset: true}, + ], +); // No sets to union should produce an empty set for all records so we only check the first one. assert.eq(coll.aggregate([{$project: {x: {$setUnion: []}}}]).toArray()[0]['x'], []); // No sets to intersect should produce an empty set for all records so we only check the first one. assert.eq(coll.aggregate([{$project: {x: {$setIntersection: []}}}]).toArray()[0]['x'], []); + +const operators = [ + ["$setUnion", 17043], + ["$setIntersection", 17047], + ["$setDifference", [17048, 17049]], + ["$setEquals", 17044], + ["$setIsSubset", [17042, 17046]] +]; +const badDocuments = [ + {arr1: "123", arr2: [1, 2, 3]}, + {arr1: [1, 2, 3], arr2: "123"}, + {arr1: "123", arr2: "123"}, +]; +for (const [operator, errorCodes] of operators) { + for (const badDocument of badDocuments) { + assert(coll.drop()); + assert.commandWorked(coll.insertOne(badDocument)); + assertErrorCode(coll, [{$project: {output: {[operator]: ["$arr1", "$arr2"]}}}], errorCodes); + } +} }()); diff --git a/jstests/libs/sbe_assert_error_override.js b/jstests/libs/sbe_assert_error_override.js index 57bc8625176..7bc8e0504e6 100644 --- a/jstests/libs/sbe_assert_error_override.js +++ b/jstests/libs/sbe_assert_error_override.js @@ -129,6 +129,13 @@ const equivalentErrorCodesList = [ [5439105, 5439018, 7003906], [5439106, 5439015, 7003909], [5439107, 5439016, 7003910], + [17042, 5126900, 7158100], + [17043, 5126900, 7158100], + [17044, 5126900, 7158100], + [17046, 5126900, 7158100], + [17047, 5126900, 7158100], + [17048, 5126900, 7158100], + [17049, 5126900, 7158100], ]; // This map is generated based on the contents of 'equivalentErrorCodesList'. This map should _not_ diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 837282c485e..94939ed33c4 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -3871,6 +3871,9 @@ private: using namespace std::literals; size_t arity = expr->getChildren().size(); + if (_context->hasAllAbtEligibleEntries(arity)) { + return generateABTSetExpression(expr, setOp); + } _context->ensureArity(arity); auto frameId = _context->state.frameId(); @@ -3889,27 +3892,8 @@ private: checkExprsNotArray.reserve(arity); auto collatorSlot = _context->state.data->env->getSlotIfExists("collator"_sd); - - auto [operatorName, setFunctionName] = [setOp, collatorSlot]() { - switch (setOp) { - case SetOperation::Difference: - return std::make_pair("setDifference"_sd, - collatorSlot ? "collSetDifference"_sd - : "setDifference"_sd); - case SetOperation::Intersection: - return std::make_pair("setIntersection"_sd, - collatorSlot ? "collSetIntersection"_sd - : "setIntersection"_sd); - case SetOperation::Union: - return std::make_pair("setUnion"_sd, - collatorSlot ? "collSetUnion"_sd : "setUnion"_sd); - case SetOperation::Equals: - return std::make_pair("setEquals"_sd, - collatorSlot ? "collSetEquals"_sd : "setEquals"_sd); - default: - MONGO_UNREACHABLE; - } - }(); + auto [operatorName, setFunctionName] = + getSetOperatorAndFunctionNames(setOp, collatorSlot.has_value()); if (collatorSlot) { argVars.push_back(sbe::makeE(*collatorSlot)); @@ -3954,6 +3938,88 @@ private: sbe::makeE(frameId, std::move(binds), std::move(setExpr))); } + void generateABTSetExpression(const Expression* expr, SetOperation setOp) { + using namespace std::literals; + + size_t arity = expr->getChildren().size(); + _context->ensureArity(arity); + + optimizer::ABTVector args; + optimizer::ProjectionNameVector argNames; + optimizer::ABTVector variables; + + optimizer::ABTVector checkNulls; + optimizer::ABTVector checkNotArrays; + + auto collatorSlot = _context->state.data->env->getSlotIfExists("collator"_sd); + + args.reserve(arity); + argNames.reserve(arity); + variables.reserve(arity + (collatorSlot.has_value() ? 1 : 0)); + checkNulls.reserve(arity); + checkNotArrays.reserve(arity); + + auto [operatorName, setFunctionName] = + getSetOperatorAndFunctionNames(setOp, collatorSlot.has_value()); + if (collatorSlot) { + variables.push_back( + optimizer::make(_context->registerVariable(*collatorSlot))); + } + + for (size_t idx = 0; idx < arity; ++idx) { + args.push_back(_context->popABTExpr()); + auto argName = makeLocalVariableName(_context->state.frameId(), 0); + argNames.push_back(argName); + variables.push_back(optimizer::make(argName)); + + checkNulls.push_back(generateABTNullOrMissing(argName)); + checkNotArrays.push_back(generateABTNonArrayCheck(std::move(argName))); + } + // Reverse the args array to preserve the original order of the arguments, since some set + // operations, such as $setDifference, are not commutative. + std::reverse(std::begin(args), std::end(args)); + + auto checkNullAnyArgument = + makeBalancedBooleanOpTree(optimizer::Operations::Or, std::move(checkNulls)); + auto checkNotArrayAnyArgument = + makeBalancedBooleanOpTree(optimizer::Operations::Or, std::move(checkNotArrays)); + auto setExpr = buildABTMultiBranchConditional( + ABTCaseValuePair{std::move(checkNullAnyArgument), optimizer::Constant::null()}, + ABTCaseValuePair{std::move(checkNotArrayAnyArgument), + makeABTFail(ErrorCodes::Error{7158100}, + str::stream() << "All operands of $" << operatorName + << " must be arrays.")}, + optimizer::make(setFunctionName.toString(), + std::move(variables))); + + for (size_t i = 0; i < arity; ++i) { + setExpr = optimizer::make( + std::move(argNames[i]), std::move(args[i]), setExpr); + } + + _context->pushExpr(std::move(setExpr)); + } + + std::pair getSetOperatorAndFunctionNames(SetOperation setOp, + bool hasCollator) const { + switch (setOp) { + case SetOperation::Difference: + return std::make_pair("setDifference"_sd, + hasCollator ? "collSetDifference"_sd : "setDifference"_sd); + case SetOperation::Intersection: + return std::make_pair("setIntersection"_sd, + hasCollator ? "collSetIntersection"_sd + : "setIntersection"_sd); + case SetOperation::Union: + return std::make_pair("setUnion"_sd, + hasCollator ? "collSetUnion"_sd : "setUnion"_sd); + case SetOperation::Equals: + return std::make_pair("setEquals"_sd, + hasCollator ? "collSetEquals"_sd : "setEquals"_sd); + } + MONGO_UNREACHABLE; + } + /** * Shared expression building logic for regex expressions. */ diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp index 762c55d0e79..eb8b614d414 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.cpp +++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp @@ -1413,8 +1413,11 @@ optimizer::ABT generateABTNullOrMissing(optimizer::ProjectionName var) { } optimizer::ABT generateABTNonStringCheck(optimizer::ProjectionName var) { - return makeNot(optimizer::make( - "isString", optimizer::ABTVector{optimizer::make(var)})); + return makeNot(makeABTFunction("isString", optimizer::make(var))); +} + +optimizer::ABT generateABTNonArrayCheck(optimizer::ProjectionName var) { + return makeNot(makeABTFunction("isArray", optimizer::make(var))); } optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::ProjectionName var) { diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h index 74b27315ae0..dced2412675 100644 --- a/src/mongo/db/query/sbe_stage_builder_helpers.h +++ b/src/mongo/db/query/sbe_stage_builder_helpers.h @@ -1249,6 +1249,7 @@ optimizer::ProjectionName makeLocalVariableName(sbe::FrameId frameId, sbe::value optimizer::ABT generateABTNullOrMissing(optimizer::ProjectionName var); optimizer::ABT generateABTNonStringCheck(optimizer::ProjectionName var); +optimizer::ABT generateABTNonArrayCheck(optimizer::ProjectionName var); optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::ProjectionName var); optimizer::ABT generateABTNegativeCheck(optimizer::ProjectionName var); /** -- cgit v1.2.1