diff options
author | Nikita Lapkov <nikita.lapkov@mongodb.com> | 2020-09-25 14:49:29 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2020-10-07 12:39:18 +0000 |
commit | 05b8d1423832addcbccd7ec04448ea13fc6ba1c2 (patch) | |
tree | 8972a3ebf8afe4d6dddafc9d4f584f0c172217cd | |
parent | 09a0264b07a47ad038e873e7c3cf9355a5266c81 (diff) | |
download | mongo-05b8d1423832addcbccd7ec04448ea13fc6ba1c2.tar.gz |
SERVER-50732 Support $filter expression in SBE
-rw-r--r-- | jstests/aggregation/bugs/server17943.js | 94 | ||||
-rw-r--r-- | jstests/aggregation/expressions/filter.js | 368 | ||||
-rw-r--r-- | jstests/libs/sbe_assert_error_override.js | 1 | ||||
-rw-r--r-- | src/mongo/db/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.h | 4 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 275 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_filter.cpp | 29 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_helpers.cpp | 120 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_helpers.h | 125 |
9 files changed, 791 insertions, 226 deletions
diff --git a/jstests/aggregation/bugs/server17943.js b/jstests/aggregation/bugs/server17943.js deleted file mode 100644 index 7a0bc52c2a6..00000000000 --- a/jstests/aggregation/bugs/server17943.js +++ /dev/null @@ -1,94 +0,0 @@ -// SERVER-17943: Add $filter aggregation expression. -// @tags: [ -// sbe_incompatible, -// ] - -// For assertErrorCode. -load('jstests/aggregation/extras/utils.js'); - -(function() { -'use strict'; - -var coll = db.agg_filter_expr; -coll.drop(); - -assert.commandWorked(coll.insert({_id: 0, a: [1, 2, 3, 4, 5]})); -assert.commandWorked(coll.insert({_id: 1, a: [2, 4]})); -assert.commandWorked(coll.insert({_id: 2, a: []})); -assert.commandWorked(coll.insert({_id: 3, a: [1]})); -assert.commandWorked(coll.insert({_id: 4, a: null})); -assert.commandWorked(coll.insert({_id: 5, a: undefined})); -assert.commandWorked(coll.insert({_id: 6})); - -// Create filter to only accept odd numbers. -filterDoc = {input: '$a', as: 'x', cond: {$eq: [1, {$mod: ['$$x', 2]}]}}; -var expectedResults = [ - {_id: 0, b: [1, 3, 5]}, - {_id: 1, b: []}, - {_id: 2, b: []}, - {_id: 3, b: [1]}, - {_id: 4, b: null}, - {_id: 5, b: null}, - {_id: 6, b: null}, -]; -var results = coll.aggregate([{$project: {b: {$filter: filterDoc}}}, {$sort: {_id: 1}}]).toArray(); -assert.eq(results, expectedResults); - -// create filter that uses the default variable name in 'cond' -filterDoc = { - input: '$a', - cond: {$eq: [2, '$$this']} -}; -expectedResults = [ - {_id: 0, b: [2]}, - {_id: 1, b: [2]}, - {_id: 2, b: []}, - {_id: 3, b: []}, - {_id: 4, b: null}, - {_id: 5, b: null}, - {_id: 6, b: null}, -]; -results = coll.aggregate([{$project: {b: {$filter: filterDoc}}}, {$sort: {_id: 1}}]).toArray(); -assert.eq(results, expectedResults); - -// Invalid filter expressions. - -// '$filter' is not a document. -var filterDoc = 'string'; -assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28646); - -// Extra field(s). 
-filterDoc = {input: '$a', as: 'x', cond: true, extra: 1}; -assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28647); - -// Missing 'input'. -filterDoc = { - as: 'x', - cond: true -}; -assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28648); - -// Missing 'cond'. -filterDoc = {input: '$a', as: 'x'}; -assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28650); - -// 'as' is not a valid variable name. -filterDoc = {input: '$a', as: '$x', cond: true}; -assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 16867); - -// 'input' is not an array. -filterDoc = {input: 'string', as: 'x', cond: true}; -assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28651); - -// 'cond' uses undefined variable name. -filterDoc = { - input: '$a', - cond: {$eq: [1, '$$var']} -}; -assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 17276); - -assert(coll.drop()); -assert.commandWorked(coll.insert({a: 'string'})); -filterDoc = {input: '$a', as: 'x', cond: true}; -assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28651); -}()); diff --git a/jstests/aggregation/expressions/filter.js b/jstests/aggregation/expressions/filter.js new file mode 100644 index 00000000000..70f3deca9b2 --- /dev/null +++ b/jstests/aggregation/expressions/filter.js @@ -0,0 +1,368 @@ +// Test $filter aggregation expression. +// +// @tags: [ +// # Can't set the 'failOnPoisonedFieldLookup' failpoint on mongos. +// assumes_against_mongod_not_mongos, +// ] + +load('jstests/aggregation/extras/utils.js'); // For assertErrorCode. +load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs. 
+ +(function() { +'use strict'; + +var coll = db.agg_filter_expr; +coll.drop(); + +assert.commandWorked(coll.insert({_id: 0, c: 1, d: 3, a: [1, 2, 3, 4, 5]})); +assert.commandWorked(coll.insert({_id: 1, c: 2, d: 4, a: [1, 2]})); +assert.commandWorked(coll.insert({_id: 2, c: 3, d: 5, a: []})); +assert.commandWorked(coll.insert({_id: 3, c: 4, d: 6, a: [4]})); +assert.commandWorked(coll.insert({_id: 4, c: 5, d: 7, a: null})); +assert.commandWorked(coll.insert({_id: 5, c: 6, d: 8, a: undefined})); +assert.commandWorked(coll.insert({_id: 6, c: 7, d: 9})); + +// Create filter to only accept numbers greater than 2. +filterDoc = {input: '$a', as: 'x', cond: {$gt: ['$$x', 2]}}; +var expectedResults = [ + {_id: 0, b: [3, 4, 5]}, + {_id: 1, b: []}, + {_id: 2, b: []}, + {_id: 3, b: [4]}, + {_id: 4, b: null}, + {_id: 5, b: null}, + {_id: 6, b: null}, +]; +var results = coll.aggregate([{$project: {b: {$filter: filterDoc}}}, {$sort: {_id: 1}}]).toArray(); +assert.eq(results, expectedResults); + +// Create filter that uses the default variable name in 'cond'. +filterDoc = { + input: '$a', + cond: {$eq: [2, '$$this']} +}; +expectedResults = [ + {_id: 0, b: [2]}, + {_id: 1, b: [2]}, + {_id: 2, b: []}, + {_id: 3, b: []}, + {_id: 4, b: null}, + {_id: 5, b: null}, + {_id: 6, b: null}, +]; +results = coll.aggregate([{$project: {b: {$filter: filterDoc}}}, {$sort: {_id: 1}}]).toArray(); +assert.eq(results, expectedResults); + +// Create filter with path expressions inside $let expression. +results = coll.aggregate([ + { + $project: { + b: { + $let: { + vars: { + value: '$d' + }, + in: { + $filter: { + input: '$a', + cond: {$gte: [{$add: ['$c', '$$this']}, '$$value']} + } + } + } + } + } + }, + {$sort: {_id: 1}} +]).toArray(); +expectedResults = [ + {_id: 0, b: [2, 3, 4, 5]}, + {_id: 1, b: [2]}, + {_id: 2, b: []}, + {_id: 3, b: [4]}, + {_id: 4, b: null}, + {_id: 5, b: null}, + {_id: 6, b: null}, +]; +assert.eq(results, expectedResults); + +// Create filter that uses the $and and $or. 
+filterDoc = { + input: '$a', + cond: {$or: [{$and: [{$gt: ['$$this', 1]}, {$lt: ['$$this', 3]}]}, {$eq: ['$$this', 5]}]} +}; +expectedResults = [ + {_id: 0, b: [2, 5]}, + {_id: 1, b: [2]}, + {_id: 2, b: []}, + {_id: 3, b: []}, + {_id: 4, b: null}, + {_id: 5, b: null}, + {_id: 6, b: null}, +]; +results = coll.aggregate([{$project: {b: {$filter: filterDoc}}}, {$sort: {_id: 1}}]).toArray(); +assert.eq(results, expectedResults); + +// Nested $filter expression. Queries below do not make sense from the user perspective, but allow +// us to test complex SBE trees generated by expressions like $and, $or, $cond and $switch with +// $filter inside them. + +// Create filter as an argument to $and and $or expressions. +expectedResults = [ + {_id: 0, b: true}, + {_id: 1, b: true}, + {_id: 2, b: true}, + {_id: 3, b: true}, + {_id: 4, b: false}, + {_id: 5, b: false}, + {_id: 6, b: false}, +]; +results = coll.aggregate([ + { + $project: { + b: { + $or: [ + { + $and: [ + { + $filter: { + input: '$a', + cond: { + $or: [ + { + $and: [ + {$gt: ['$$this', 1]}, + {$lt: ['$$this', 3]} + ] + }, + {$eq: ['$$this', 5]} + ] + } + } + }, + '$d' + ] + }, + {$filter: {input: '$a', cond: {$eq: ['$$this', 1]}}} + ] + } + } + }, + {$sort: {_id: 1}} + ]) + .toArray(); +assert.eq(results, expectedResults); + +// Create filter as an argument to $cond expression. +expectedResults = [ + {_id: 0, b: [2]}, + {_id: 1, b: [2]}, + {_id: 2, b: []}, + {_id: 3, b: []}, + {_id: 4, b: null}, + {_id: 5, b: null}, + {_id: 6, b: null}, +]; +results = coll.aggregate([ + { + $project: { + b: { + $cond: { + if: {$filter: {input: '$a', cond: {$eq: ['$$this', 1]}}}, + then: {$filter: {input: '$a', cond: {$eq: ['$$this', 2]}}}, + else: {$filter: {input: '$a', cond: {$eq: ['$$this', 3]}}} + } + } + } + }, + {$sort: {_id: 1}} + ]) + .toArray(); +assert.eq(results, expectedResults); + +// Create filter as an argument to $switch expression. 
+expectedResults = [ + {_id: 0, b: [2]}, + {_id: 1, b: [2]}, + {_id: 2, b: []}, + {_id: 3, b: []}, + {_id: 4, b: null}, + {_id: 5, b: null}, + {_id: 6, b: null}, +]; +results = + coll.aggregate([ + { + $project: { + b: { + $switch: { + branches: [ + { + case: {$filter: {input: '$a', cond: {$eq: ['$$this', 1]}}}, + then: {$filter: {input: '$a', cond: {$eq: ['$$this', 2]}}} + }, + { + case: {$filter: {input: '$a', cond: {$eq: ['$$this', 3]}}}, + then: {$filter: {input: '$a', cond: {$eq: ['$$this', 4]}}} + } + ], + default: {$filter: {input: '$a', cond: {$eq: ['$$this', 5]}}} + } + } + } + }, + {$sort: {_id: 1}} + ]) + .toArray(); +assert.eq(results, expectedResults); + +// Invalid filter expressions. + +// '$filter' is not a document. +var filterDoc = 'string'; +assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28646); + +// Extra field(s). +filterDoc = {input: '$a', as: 'x', cond: true, extra: 1}; +assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28647); + +// Missing 'input'. +filterDoc = { + as: 'x', + cond: true +}; +assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28648); + +// Missing 'cond'. +filterDoc = {input: '$a', as: 'x'}; +assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28650); + +// 'as' is not a valid variable name. +filterDoc = {input: '$a', as: '$x', cond: true}; +assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 16867); + +// 'input' is not an array. +filterDoc = {input: 'string', as: 'x', cond: true}; +assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28651); + +// 'cond' uses undefined variable name. 
+filterDoc = { + input: '$a', + cond: {$eq: [1, '$$var']} +}; +assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 17276); + +assert(coll.drop()); +assert.commandWorked(coll.insert({a: 'string'})); +filterDoc = {input: '$a', as: 'x', cond: true}; +assertErrorCode(coll, [{$project: {b: {$filter: filterDoc}}}], 28651); + +// Create filter with non-bool predicate. +assert(coll.drop()); +const date = new Date(); +assert.commandWorked( + coll.insert({_id: 0, a: [date, null, undefined, 0, false, NumberDecimal('1'), [], {c: 3}]})); +expectedResults = [ + {_id: 0, b: [date, NumberDecimal('1'), [], {c: 3}]}, +]; +results = + coll.aggregate([{$project: {b: {$filter: {input: '$a', as: 'x', cond: '$$x'}}}}]).toArray(); +assert.eq(results, expectedResults); + +// Create filter with deep path expressions. +assert(coll.drop()); +assert.commandWorked(coll.insert({ + _id: 0, + a: [ + {b: {c: {d: 1}}}, + {b: {c: {d: 2}}}, + {b: {c: {d: 3}}}, + {b: {c: {d: 4}}}, + ] +})); + +filterDoc = { + input: '$a', + cond: {$gt: ['$$this.b.c.d', 2]} +}; +expectedResults = [ + {_id: 0, b: [{b: {c: {d: 3}}}, {b: {c: {d: 4}}}]}, +]; +results = coll.aggregate([{$project: {b: {$filter: filterDoc}}}]).toArray(); +assert.eq(results, expectedResults); + +// Create nested filter. +assert(coll.drop()); +assert.commandWorked(coll.insert({_id: 0, a: [[1, 2, 3], null, [4, 5, 6]]})); + +expectedResults = [ + {_id: 0, b: [[1, 2, 3], [4, 5, 6]]}, +]; +results = coll.aggregate([{ + $project: { + b: { + $filter: { + input: '$a', + cond: {$filter: {input: '$$this', cond: {$gt: ['$$this', 3]}}} + } + } + } + }]) + .toArray(); +assert.eq(results, expectedResults); + +// Test short-circuiting in $and and $or inside $filter expression. +coll.drop(); +assert.commandWorked(coll.insert({_id: 0, a: [-1, -2, -3, -4]})); + +// Lookup of '$POISON' field will always fail with this fail point enabled. 
+assert.commandWorked( + db.adminCommand({configureFailPoint: "failOnPoisonedFieldLookup", mode: "alwaysOn"})); + +// Create filter with $and expression containing '$POISON' in it. +expectedResults = [ + {_id: 0, b: []}, +]; +results = + coll.aggregate([{ + $project: + {b: {$filter: {input: '$a', cond: {$and: [{$gt: ['$$this', 0]}, '$POISON']}}}} + }]) + .toArray(); +assert.eq(results, expectedResults); + +// Create filter with $or expression containing '$POISON' in it. +expectedResults = [ + {_id: 0, b: [-1, -2, -3, -4]}, +]; +results = + coll.aggregate([{ + $project: + {b: {$filter: {input: '$a', cond: {$or: [{$lt: ['$$this', 0]}, '$POISON']}}}} + }]) + .toArray(); +assert.eq(results, expectedResults); + +// Create filter with $and expression containing invalid call to $ln in it. +expectedResults = [ + {_id: 0, b: []}, +]; +results = + coll.aggregate([{ + $project: + {b: {$filter: {input: '$a', cond: {$and: [{$gt: ['$$this', 0]}, {$ln: '$$this'}]}}}} + }]) + .toArray(); +assert.eq(results, expectedResults); + +// Create filter with $or expression containing invalid call to $ln in it. +expectedResults = [ + {_id: 0, b: [-1, -2, -3, -4]}, +]; +results = + coll.aggregate([{ + $project: + {b: {$filter: {input: '$a', cond: {$or: [{$lt: ['$$this', 0]}, {$ln: '$$this'}]}}}} + }]) + .toArray(); +assert.eq(results, expectedResults); +}()); diff --git a/jstests/libs/sbe_assert_error_override.js b/jstests/libs/sbe_assert_error_override.js index 8cda6fd7113..ed468f971e9 100644 --- a/jstests/libs/sbe_assert_error_override.js +++ b/jstests/libs/sbe_assert_error_override.js @@ -21,6 +21,7 @@ // Below is the list of known equivalent error code groups. As new groups of equivalent error codes // are discovered, they should be added to this list. 
const equivalentErrorCodesList = [ + [28651, 5073201], [16020, 5066300], [16007, 5066300], [16608, 4848401], diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript index 387424a0b45..203f5e396b1 100644 --- a/src/mongo/db/SConscript +++ b/src/mongo/db/SConscript @@ -1198,6 +1198,7 @@ env.Library( 'query/sbe_stage_builder_filter.cpp', 'query/sbe_stage_builder_index_scan.cpp', 'query/sbe_stage_builder_projection.cpp', + 'query/sbe_stage_builder_helpers.cpp', 'query/sbe_sub_planner.cpp', 'query/stage_builder_util.cpp', 'query/wildcard_multikey_paths.cpp', diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 74ed09c8dfa..bd711560466 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -1472,6 +1472,10 @@ public: return visitor->visit(this); } + Variables::Id getVariableId() const { + return _varId; + } + protected: void _doAddDependencies(DepsTracker* deps) const final; diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index 3d857351f05..16cb27c5027 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -47,6 +47,7 @@ #include "mongo/db/pipeline/expression_visitor.h" #include "mongo/db/pipeline/expression_walker.h" #include "mongo/db/query/projection_parser.h" +#include "mongo/db/query/sbe_stage_builder_helpers.h" #include "mongo/util/str.h" namespace mongo::stage_builder { @@ -101,6 +102,15 @@ struct ExpressionVisitorContext { nextBranchResultSlot(nextBranchResultSlot) {} }; + struct FilterExpressionEvalFrame { + std::unique_ptr<sbe::PlanStage> traverseStage; + sbe::value::SlotVector relevantSlots; + + FilterExpressionEvalFrame(std::unique_ptr<sbe::PlanStage> traverseStage, + const sbe::value::SlotVector& relevantSlots) + : traverseStage(std::move(traverseStage)), relevantSlots(relevantSlots) {} + }; + 
ExpressionVisitorContext(std::unique_ptr<sbe::PlanStage> inputStage, sbe::value::SlotIdGenerator* slotIdGenerator, sbe::value::FrameIdGenerator* frameIdGenerator, @@ -282,7 +292,10 @@ struct ExpressionVisitorContext { std::map<Variables::Id, sbe::value::SlotId> environment; std::stack<VarsFrame> varsFrameStack; + // TODO SERVER-51356: Replace these stacks with single stack of evaluation frames. + std::stack<FilterExpressionEvalFrame> filterExpressionEvalFrameStack; std::stack<LogicalExpressionEvalFrame> logicalExpressionEvalFrameStack; + // See the comment above the generateExpression() declaration for an explanation of the // 'relevantSlots' list. sbe::value::SlotVector* relevantSlots; @@ -400,23 +413,6 @@ std::pair<sbe::value::SlotId, std::unique_ptr<sbe::PlanStage>> generateTraverse( } /** - * Generates an EExpression that checks if the input expression is null or missing. - */ -std::unique_ptr<sbe::EExpression> generateNullOrMissing(const sbe::EVariable& var) { - return sbe::makeE<sbe::EPrimBinary>( - sbe::EPrimBinary::logicOr, - sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(var.clone()))), - sbe::makeE<sbe::EFunction>("isNull", sbe::makeEs(var.clone()))); -} - -std::unique_ptr<sbe::EExpression> generateNullOrMissing(const sbe::FrameId frameId, - const sbe::value::SlotId slotId) { - sbe::EVariable var{frameId, slotId}; - return generateNullOrMissing(var); -} - -/** * Generates an EExpression that converts the input to upper or lower case. */ void generateStringCaseConversionExpression(ExpressionVisitorContext* _context, @@ -454,99 +450,6 @@ void generateStringCaseConversionExpression(ExpressionVisitorContext* _context, sbe::makeE<sbe::ELocalBind>(frameId, std::move(str), std::move(totalCaseConversionExpr))); } -/** - * Generates an EExpression that checks if the input expression is a non-numeric type _assuming - * that_ it has already been verified to be neither null nor missing. 
- */ -std::unique_ptr<sbe::EExpression> generateNonNumericCheck(const sbe::EVariable& var) { - return sbe::makeE<sbe::EPrimUnary>( - sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(var.clone()))); -} - -/** - * Generates an EExpression that checks if the input expression is the value NumberLong(-2^64). - */ -std::unique_ptr<sbe::EExpression> generateLongLongMinCheck(const sbe::EVariable& var) { - return sbe::makeE<sbe::EPrimBinary>( - sbe::EPrimBinary::logicAnd, - sbe::makeE<sbe::ETypeMatch>(var.clone(), - MatcherTypeSet{BSONType::NumberLong}.getBSONTypeMask()), - sbe::makeE<sbe::EPrimBinary>( - sbe::EPrimBinary::eq, - var.clone(), - sbe::makeE<sbe::EConstant>( - sbe::value::TypeTags::NumberInt64, - sbe::value::bitcastFrom(std::numeric_limits<int64_t>::min())))); -} - -/** - * Generates an EExpression that checks if the input expression is NaN _assuming that_ it has - * already been verified to be numeric. - */ -std::unique_ptr<sbe::EExpression> generateNaNCheck(const sbe::EVariable& var) { - return sbe::makeE<sbe::EFunction>("isNaN", sbe::makeEs(var.clone())); -} - -/** - * Generates an EExpression that checks if the input expression is a non-positive number (i.e. <= 0) - * _assuming that_ it has already been verified to be numeric. - */ -std::unique_ptr<sbe::EExpression> generateNonPositiveCheck(const sbe::EVariable& var) { - return sbe::makeE<sbe::EPrimBinary>( - sbe::EPrimBinary::EPrimBinary::lessEq, - var.clone(), - sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::NumberInt32, sbe::value::bitcastFrom(0))); -} - -/** - * Generates an EExpression that checks if the input expression is a negative (i.e., < 0) number - * _assuming that_ it has already been verified to be numeric. 
- */ -std::unique_ptr<sbe::EExpression> generateNegativeCheck(const sbe::EVariable& var) { - return sbe::makeE<sbe::EPrimBinary>( - sbe::EPrimBinary::EPrimBinary::less, - var.clone(), - sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::NumberInt32, sbe::value::bitcastFrom(0))); -} - -/** - * Generates an EExpression that checks if the input expression is _not_ an object, _assuming that_ - * it has already been verified to be neither null nor missing. - */ -std::unique_ptr<sbe::EExpression> generateNonObjectCheck(const sbe::EVariable& var) { - return sbe::makeE<sbe::EPrimUnary>( - sbe::EPrimUnary::logicNot, - sbe::makeE<sbe::EFunction>("isObject", sbe::makeEs(var.clone()))); -} - -/** - * A pair representing a 1) true/false condition and 2) the value that should be returned if that - * condition evaluates to true. - */ -using CaseValuePair = - std::pair<std::unique_ptr<sbe::EExpression>, std::unique_ptr<sbe::EExpression>>; - -/** - * Convert a list of CaseValuePairs into a chain of EIf expressions, with the final else case - * evaluating to the 'defaultValue' EExpression. - */ -template <typename... Ts> -std::unique_ptr<sbe::EExpression> buildMultiBranchConditional(Ts... cases); - -template <typename... Ts> -std::unique_ptr<sbe::EExpression> buildMultiBranchConditional(CaseValuePair headCase, Ts... 
rest) { - return sbe::makeE<sbe::EIf>(std::move(headCase.first), - std::move(headCase.second), - buildMultiBranchConditional(std::move(rest)...)); -} - -template <> -std::unique_ptr<sbe::EExpression> buildMultiBranchConditional( - std::unique_ptr<sbe::EExpression> defaultCase) { - return defaultCase; -} - class ExpressionPreVisitor final : public ExpressionVisitor { public: ExpressionPreVisitor(ExpressionVisitorContext* context) : _context{context} {} @@ -752,7 +655,21 @@ public: void visit(ExpressionDivide* expr) final {} void visit(ExpressionExp* expr) final {} void visit(ExpressionFieldPath* expr) final {} - void visit(ExpressionFilter* expr) final {} + void visit(ExpressionFilter* expr) final { + // This visitor executes after visiting the expression that will evaluate to the array for + // filtering and before visiting the filter condition expression. + auto variableId = expr->getVariableId(); + invariant(_context->environment.find(variableId) == _context->environment.end()); + + auto currentElementSlot = _context->slotIdGenerator->generate(); + _context->environment.insert({variableId, currentElementSlot}); + + // Temporarily reset 'traverseStage' with limit 1/coscan tree to prevent from being + // rewritten by filter predicate's generated sub-tree. + _context->filterExpressionEvalFrameStack.emplace(std::move(_context->traverseStage), + *_context->relevantSlots); + _context->traverseStage = makeLimitCoScanTree(_context->planNodeId); + } void visit(ExpressionFloor* expr) final {} void visit(ExpressionIfNull* expr) final {} void visit(ExpressionIn* expr) final {} @@ -1741,7 +1658,141 @@ public: _context->relevantSlots->push_back(outputSlot); } void visit(ExpressionFilter* expr) final { - unsupportedExpression("$filter"); + _context->ensureArity(2); + + auto filterPredicate = _context->popExpr(); + auto input = _context->popExpr(); + + // Extract 'traverseStage' generated for filter predicate. 
+ auto filterTraverseStage = std::move(_context->traverseStage); + + // Restore old value of 'traverseStage' and 'relevantSlots' after filter predicate tree + // was built. + auto& filterPredicateEvalFrame = _context->filterExpressionEvalFrameStack.top(); + _context->traverseStage = std::move(filterPredicateEvalFrame.traverseStage); + *_context->relevantSlots = filterPredicateEvalFrame.relevantSlots; + _context->filterExpressionEvalFrameStack.pop(); + + // Filter predicate of $filter expression expects current array element to be stored in the + // specific variable. We already allocated slot for it in the "in" visitor, now we just need + // to retrieve it from the environment. + // This slot will be used in the traverse stage twice - to store the input array and to + // store current element in this array. + auto currentElementVariable = expr->getVariableId(); + invariant(_context->environment.count(currentElementVariable)); + auto inputArraySlot = _context->environment.at(currentElementVariable); + + // We no longer need this mapping because filter predicate which expects it was already + // compiled. + _context->environment.erase(currentElementVariable); + + // Construct 'from' branch of traverse stage. 
SBE tree stored in 'fromBranch' variable looks + // like this: + // + // project inputIsNotNullishSlot = !(isNull(inputArraySlot) || !exists(inputArraySlot)) + // project inputArraySlot = ( + // let inputRef = input + // in + // if isArray(inputRef) || isNull(inputRef) || !exists(inputRef) + // inputRef + // else + // fail() + // ) + // _context->traverseStage + auto frameId = _context->frameIdGenerator->generate(); + auto binds = sbe::makeEs(std::move(input)); + sbe::EVariable inputRef(frameId, 0); + + auto inputIsArrayOrNullish = sbe::makeE<sbe::EPrimBinary>( + sbe::EPrimBinary::logicOr, + generateNullOrMissing(inputRef), + sbe::makeE<sbe::EFunction>("isArray", sbe::makeEs(inputRef.clone()))); + auto checkInputArrayType = + sbe::makeE<sbe::EIf>(std::move(inputIsArrayOrNullish), + inputRef.clone(), + sbe::makeE<sbe::EFail>(ErrorCodes::Error{5073201}, + "input to $filter must be an array")); + auto inputArray = + sbe::makeE<sbe::ELocalBind>(frameId, std::move(binds), std::move(checkInputArrayType)); + + sbe::EVariable inputArrayVariable{inputArraySlot}; + auto projectInputArray = sbe::makeProjectStage(std::move(_context->traverseStage), + _context->planNodeId, + inputArraySlot, + std::move(inputArray)); + + auto inputIsNotNullish = makeNot(generateNullOrMissing(inputArrayVariable)); + auto inputIsNotNullishSlot = _context->slotIdGenerator->generate(); + auto fromBranch = sbe::makeProjectStage(std::move(projectInputArray), + _context->planNodeId, + inputIsNotNullishSlot, + std::move(inputIsNotNullish)); + + // Construct 'in' branch of traverse stage. SBE tree stored in 'inBranch' variable looks + // like this: + // + // cfilter Variable{inputIsNotNullishSlot} + // filter filterPredicate + // filterTraverseStage + // + // Filter predicate can return non-boolean values. To fix this, we generate expression to + // coerce it to bool type. 
+ frameId = _context->frameIdGenerator->generate(); + auto boolFilterPredicate = + sbe::makeE<sbe::ELocalBind>(frameId, + sbe::makeEs(std::move(filterPredicate)), + generateCoerceToBoolExpression(sbe::EVariable{frameId, 0})); + auto filterWithPredicate = sbe::makeS<sbe::FilterStage<false>>( + std::move(filterTraverseStage), std::move(boolFilterPredicate), _context->planNodeId); + + // If input array is null or missing, we do not evaluate filter predicate and return EOF. + auto innerBranch = + sbe::makeS<sbe::FilterStage<true>>(std::move(filterWithPredicate), + sbe::makeE<sbe::EVariable>(inputIsNotNullishSlot), + _context->planNodeId); + + // Relevant slots from the _context->traverseStage might be used in the traverse 'in' branch + // by filter predicate through path expressions and variables. We need to pass them + // explicitly as correlated to traverse 'from' branch. + auto outerCorrelatedSlots = *_context->relevantSlots; + + // Add all variables from the environment. + for (const auto& item : _context->environment) { + outerCorrelatedSlots.push_back(item.second); + } + + // inputIsNotNullishSlot is used explicitly by cfilter stage added on top of traverse 'in' + // branch. 
+ outerCorrelatedSlots.push_back(inputIsNotNullishSlot); + + // Construct traverse stage with the following slots: + // * inputArraySlot - slot containing input array of $filter expression + // * filteredArraySlot - slot containing the array with items on which filter predicate has + // evaluated to true + // * inputArraySlot - slot where 'in' branch of traverse stage stores current array + // element if it satisfies the filter predicate + auto filteredArraySlot = _context->slotIdGenerator->generate(); + auto traverseStage = + sbe::makeS<sbe::TraverseStage>(std::move(fromBranch), + std::move(innerBranch), + inputArraySlot /* inField */, + filteredArraySlot /* outField */, + inputArraySlot /* outFieldInner */, + std::move(outerCorrelatedSlots) /* outerCorrelated */, + nullptr /* foldExpr */, + nullptr /* finalExpr */, + _context->planNodeId, + 1 /* nestedArraysDepth */); + + // If input array is null or missing, 'in' stage of traverse will return EOF. In this case + // traverse sets output slot (filteredArraySlot) to Nothing. We replace it with Null to + // match $filter expression behaviour. 
+ auto result = sbe::makeE<sbe::EFunction>( + "fillEmpty", + sbe::makeEs(sbe::makeE<sbe::EVariable>(filteredArraySlot), + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Null, 0))); + + _context->pushExpr(std::move(result), std::move(traverseStage)); } void visit(ExpressionFloor* expr) final { auto frameId = _context->frameIdGenerator->generate(); diff --git a/src/mongo/db/query/sbe_stage_builder_filter.cpp b/src/mongo/db/query/sbe_stage_builder_filter.cpp index 070ddcce12d..c963307114d 100644 --- a/src/mongo/db/query/sbe_stage_builder_filter.cpp +++ b/src/mongo/db/query/sbe_stage_builder_filter.cpp @@ -70,10 +70,12 @@ #include "mongo/db/matcher/schema/expression_internal_schema_unique_items.h" #include "mongo/db/matcher/schema/expression_internal_schema_xor.h" #include "mongo/db/query/sbe_stage_builder_expression.h" +#include "mongo/db/query/sbe_stage_builder_helpers.h" #include "mongo/util/str.h" namespace mongo::stage_builder { namespace { + /** * EvalExpr is a wrapper around an EExpression that can also carry a SlotId. */ @@ -159,21 +161,8 @@ struct EvalFrame { /** * Helper functions for building common EExpressions and PlanStage trees. 
*/ -EvalStage makeLimitCoScanTree(PlanNodeId planNodeId, long long limit = 1) { - return {sbe::makeS<sbe::LimitSkipStage>( - sbe::makeS<sbe::CoScanStage>(planNodeId), limit, boost::none, planNodeId), - sbe::makeSV()}; -} - -std::unique_ptr<sbe::EExpression> makeNot(std::unique_ptr<sbe::EExpression> e) { - return sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, std::move(e)); -} - -std::unique_ptr<sbe::EExpression> makeFillEmptyFalse(std::unique_ptr<sbe::EExpression> e) { - using namespace std::literals; - return sbe::makeE<sbe::EFunction>( - "fillEmpty"sv, - sbe::makeEs(std::move(e), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Boolean, 0))); +EvalStage makeLimitCoScanStage(PlanNodeId planNodeId, long long limit = 1) { + return {makeLimitCoScanTree(planNodeId, limit), sbe::makeSV()}; } template <bool IsConst> @@ -182,7 +171,7 @@ EvalStage makeFilter(EvalStage stage, PlanNodeId planNodeId) { if (!stage.stage) { // If 'stage' is null, set it to be a limit-1/coscan tree. - stage = makeLimitCoScanTree(planNodeId); + stage = makeLimitCoScanStage(planNodeId); } return {sbe::makeS<sbe::FilterStage<IsConst>>( @@ -194,7 +183,7 @@ template <typename... Ts> EvalStage makeProject(EvalStage stage, PlanNodeId planNodeId, Ts&&... pack) { if (!stage.stage) { // If 'stage' is null, set it to be a limit-1/coscan tree. - stage = makeLimitCoScanTree(planNodeId); + stage = makeLimitCoScanStage(planNodeId); } auto outSlots = std::move(stage.outSlots); @@ -219,12 +208,12 @@ EvalStage makeTraverse(EvalStage outer, boost::optional<size_t> nestedArraysDepth) { if (!outer.stage) { // If 'outer' is null, set it to be a limit-1/coscan tree. - outer = makeLimitCoScanTree(planNodeId); + outer = makeLimitCoScanStage(planNodeId); } if (!inner.stage) { // If 'inner' is null, set it to be a limit-1/coscan tree. 
- inner = makeLimitCoScanTree(planNodeId); + inner = makeLimitCoScanStage(planNodeId); } sbe::value::SlotVector outerCorrelated; @@ -1216,7 +1205,7 @@ public: invariant(frame.inputSlot == _context->inputSlot); if (!frame.stage.stage) { - frame.stage = makeLimitCoScanTree(_context->planNodeId); + frame.stage = makeLimitCoScanStage(_context->planNodeId); } auto&& [_, expr, stage] = generateExpression(_context->opCtx, diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_helpers.cpp new file mode 100644 index 00000000000..33546b836a5 --- /dev/null +++ b/src/mongo/db/query/sbe_stage_builder_helpers.cpp @@ -0,0 +1,120 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. 
If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/query/sbe_stage_builder_helpers.h" + +#include "mongo/db/exec/sbe/expressions/expression.h" +#include "mongo/db/exec/sbe/stages/co_scan.h" +#include "mongo/db/exec/sbe/stages/limit_skip.h" +#include "mongo/db/matcher/matcher_type_set.h" + +namespace mongo::stage_builder { + +std::unique_ptr<sbe::EExpression> generateNullOrMissing(const sbe::EVariable& var) { + return sbe::makeE<sbe::EPrimBinary>( + sbe::EPrimBinary::logicOr, + sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, + sbe::makeE<sbe::EFunction>("exists", sbe::makeEs(var.clone()))), + sbe::makeE<sbe::EFunction>("isNull", sbe::makeEs(var.clone()))); +} + +std::unique_ptr<sbe::EExpression> generateNullOrMissing(const sbe::FrameId frameId, + const sbe::value::SlotId slotId) { + sbe::EVariable var{frameId, slotId}; + return generateNullOrMissing(var); +} + +std::unique_ptr<sbe::EExpression> generateNonNumericCheck(const sbe::EVariable& var) { + return sbe::makeE<sbe::EPrimUnary>( + sbe::EPrimUnary::logicNot, + sbe::makeE<sbe::EFunction>("isNumber", sbe::makeEs(var.clone()))); +} + +std::unique_ptr<sbe::EExpression> generateLongLongMinCheck(const sbe::EVariable& var) { + return sbe::makeE<sbe::EPrimBinary>( + sbe::EPrimBinary::logicAnd, + sbe::makeE<sbe::ETypeMatch>(var.clone(), + MatcherTypeSet{BSONType::NumberLong}.getBSONTypeMask()), + sbe::makeE<sbe::EPrimBinary>( + sbe::EPrimBinary::eq, + var.clone(), + sbe::makeE<sbe::EConstant>( + sbe::value::TypeTags::NumberInt64, + sbe::value::bitcastFrom(std::numeric_limits<int64_t>::min())))); +} + +std::unique_ptr<sbe::EExpression> generateNaNCheck(const sbe::EVariable& var) { + return sbe::makeE<sbe::EFunction>("isNaN", sbe::makeEs(var.clone())); +} + +std::unique_ptr<sbe::EExpression> 
generateNonPositiveCheck(const sbe::EVariable& var) { + return sbe::makeE<sbe::EPrimBinary>( + sbe::EPrimBinary::EPrimBinary::lessEq, + var.clone(), + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::NumberInt32, sbe::value::bitcastFrom(0))); +} + +std::unique_ptr<sbe::EExpression> generateNegativeCheck(const sbe::EVariable& var) { + return sbe::makeE<sbe::EPrimBinary>( + sbe::EPrimBinary::EPrimBinary::less, + var.clone(), + sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::NumberInt32, sbe::value::bitcastFrom(0))); +} + +std::unique_ptr<sbe::EExpression> generateNonObjectCheck(const sbe::EVariable& var) { + return sbe::makeE<sbe::EPrimUnary>( + sbe::EPrimUnary::logicNot, + sbe::makeE<sbe::EFunction>("isObject", sbe::makeEs(var.clone()))); +} + +template <> +std::unique_ptr<sbe::EExpression> buildMultiBranchConditional( + std::unique_ptr<sbe::EExpression> defaultCase) { + return defaultCase; +} + +std::unique_ptr<sbe::PlanStage> makeLimitCoScanTree(PlanNodeId planNodeId, long long limit) { + return sbe::makeS<sbe::LimitSkipStage>( + sbe::makeS<sbe::CoScanStage>(planNodeId), limit, boost::none, planNodeId); +} + +std::unique_ptr<sbe::EExpression> makeNot(std::unique_ptr<sbe::EExpression> e) { + return sbe::makeE<sbe::EPrimUnary>(sbe::EPrimUnary::logicNot, std::move(e)); +} + +std::unique_ptr<sbe::EExpression> makeFillEmptyFalse(std::unique_ptr<sbe::EExpression> e) { + using namespace std::literals; + return sbe::makeE<sbe::EFunction>( + "fillEmpty"sv, + sbe::makeEs(std::move(e), sbe::makeE<sbe::EConstant>(sbe::value::TypeTags::Boolean, 0))); +} + +} // namespace mongo::stage_builder diff --git a/src/mongo/db/query/sbe_stage_builder_helpers.h b/src/mongo/db/query/sbe_stage_builder_helpers.h new file mode 100644 index 00000000000..7e6b0ffc978 --- /dev/null +++ b/src/mongo/db/query/sbe_stage_builder_helpers.h @@ -0,0 +1,125 @@ +/** + * Copyright (C) 2020-present MongoDB, Inc. 
+ * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <memory> +#include <string> +#include <utility> + +#include "mongo/db/exec/sbe/expressions/expression.h" +#include "mongo/db/query/stage_types.h" + +namespace mongo::stage_builder { + +/** + * Generates an EExpression that checks if the input expression is null or missing. 
+ */ +std::unique_ptr<sbe::EExpression> generateNullOrMissing(const sbe::EVariable& var); + +std::unique_ptr<sbe::EExpression> generateNullOrMissing(const sbe::FrameId frameId, + const sbe::value::SlotId slotId); + +/** + * Generates an EExpression that checks if the input expression is a non-numeric type _assuming + * that_ it has already been verified to be neither null nor missing. + */ +std::unique_ptr<sbe::EExpression> generateNonNumericCheck(const sbe::EVariable& var); + +/** + * Generates an EExpression that checks if the input expression is the value NumberLong(-2^63). + */ +std::unique_ptr<sbe::EExpression> generateLongLongMinCheck(const sbe::EVariable& var); + +/** + * Generates an EExpression that checks if the input expression is NaN _assuming that_ it has + * already been verified to be numeric. + */ +std::unique_ptr<sbe::EExpression> generateNaNCheck(const sbe::EVariable& var); + +/** + * Generates an EExpression that checks if the input expression is a non-positive number (i.e. <= 0) + * _assuming that_ it has already been verified to be numeric. + */ +std::unique_ptr<sbe::EExpression> generateNonPositiveCheck(const sbe::EVariable& var); + +/** + * Generates an EExpression that checks if the input expression is a negative (i.e., < 0) number + * _assuming that_ it has already been verified to be numeric. + */ +std::unique_ptr<sbe::EExpression> generateNegativeCheck(const sbe::EVariable& var); + +/** + * Generates an EExpression that checks if the input expression is _not_ an object, _assuming that_ + * it has already been verified to be neither null nor missing. + */ +std::unique_ptr<sbe::EExpression> generateNonObjectCheck(const sbe::EVariable& var); + +/** + * A pair representing a 1) true/false condition and 2) the value that should be returned if that + * condition evaluates to true. 
+ */ +using CaseValuePair = + std::pair<std::unique_ptr<sbe::EExpression>, std::unique_ptr<sbe::EExpression>>; + +/** + * Convert a list of CaseValuePairs into a chain of EIf expressions, with the final else case + * evaluating to the 'defaultCase' EExpression. + */ +template <typename... Ts> +std::unique_ptr<sbe::EExpression> buildMultiBranchConditional(Ts... cases); + +template <typename... Ts> +std::unique_ptr<sbe::EExpression> buildMultiBranchConditional(CaseValuePair headCase, Ts... rest) { + return sbe::makeE<sbe::EIf>(std::move(headCase.first), + std::move(headCase.second), + buildMultiBranchConditional(std::move(rest)...)); +} + +template <> +std::unique_ptr<sbe::EExpression> buildMultiBranchConditional( + std::unique_ptr<sbe::EExpression> defaultCase); + +/** + * Create tree consisting of coscan stage followed by limit stage. + */ +std::unique_ptr<sbe::PlanStage> makeLimitCoScanTree(PlanNodeId planNodeId, long long limit = 1); + +/** + * Wrap expression into logical negation. + */ +std::unique_ptr<sbe::EExpression> makeNot(std::unique_ptr<sbe::EExpression> e); + +/** + * Check if expression returns Nothing and return boolean false if so. Otherwise, return the + * expression. + */ +std::unique_ptr<sbe::EExpression> makeFillEmptyFalse(std::unique_ptr<sbe::EExpression> e); + +} // namespace mongo::stage_builder |