diff options
author | Davis Haupt <davis.haupt@mongodb.com> | 2022-11-01 19:35:57 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-11-01 20:45:01 +0000 |
commit | 464abab59fa63b45d187a4145dceec3281bbcf0d (patch) | |
tree | d059d053023fa0c95f301a63a87d03ac26d2c550 | |
parent | 5d04b5969c880cc7b7acba50490dc762c0964f97 (diff) | |
download | mongo-464abab59fa63b45d187a4145dceec3281bbcf0d.tar.gz |
SERVER-70306 support aggregate comparison operators in server-side encrypted range rewrite
-rw-r--r-- | src/mongo/db/query/fle/encrypted_predicate.h | 67 | ||||
-rw-r--r-- | src/mongo/db/query/fle/query_rewriter.cpp | 29 | ||||
-rw-r--r-- | src/mongo/db/query/fle/query_rewriter_test.cpp | 192 | ||||
-rw-r--r-- | src/mongo/db/query/fle/range_predicate.cpp | 128 | ||||
-rw-r--r-- | src/mongo/db/query/fle/range_predicate.h | 6 | ||||
-rw-r--r-- | src/mongo/db/query/fle/range_predicate_test.cpp | 154 |
6 files changed, 469 insertions, 107 deletions
diff --git a/src/mongo/db/query/fle/encrypted_predicate.h b/src/mongo/db/query/fle/encrypted_predicate.h index 5fcac966dbf..fc78aa28014 100644 --- a/src/mongo/db/query/fle/encrypted_predicate.h +++ b/src/mongo/db/query/fle/encrypted_predicate.h @@ -40,6 +40,7 @@ #include "mongo/db/matcher/expression_leaf.h" #include "mongo/db/pipeline/expression.h" #include "mongo/db/query/fle/query_rewriter_interface.h" +#include "mongo/stdx/unordered_map.h" /** * This file contains an abstract class that describes rewrites on agg Expressions and @@ -194,33 +195,37 @@ private: * are keyed on the dynamic type for the Expression subclass. */ -using ExpressionToRewriteMap = stdx::unordered_map< - std::type_index, - std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>>; +using ExpressionRewriteFunction = + std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>; +using ExpressionToRewriteMap = + stdx::unordered_map<std::type_index, std::vector<ExpressionRewriteFunction>>; extern ExpressionToRewriteMap aggPredicateRewriteMap; -using MatchTypeToRewriteMap = stdx::unordered_map< - MatchExpression::MatchType, - std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>>; +using MatchRewriteFunction = + std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>; +using MatchTypeToRewriteMap = + stdx::unordered_map<MatchExpression::MatchType, std::vector<MatchRewriteFunction>>; extern MatchTypeToRewriteMap matchPredicateRewriteMap; /** * Register an agg rewrite if a condition is true at startup time. */ -#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \ - MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className)(InitializerContext*) { \ - \ - invariant(aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()); \ - aggPredicateRewriteMap[typeid(className)] = [&](auto* rewriter, auto* expr) { \ - if (isEnabledExpr) { \ - return rewriteClass{rewriter}.rewrite(expr); \ - } else { \ - return std::unique_ptr<Expression>(nullptr); \ - } \ - }; \ - } +#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \ + MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className##_##rewriteClass) \ + (InitializerContext*) { \ + if (aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()) { \ + aggPredicateRewriteMap[typeid(className)] = std::vector<ExpressionRewriteFunction>(); \ + } \ + aggPredicateRewriteMap[typeid(className)].push_back([](auto* rewriter, auto* expr) { \ + if (isEnabledExpr) { \ + return rewriteClass{rewriter}.rewrite(expr); \ + } else { \ + return std::unique_ptr<Expression>(nullptr); \ + } \ + }); \ + }; /** * Register an agg rewrite unconditionally. @@ -239,17 +244,21 @@ extern MatchTypeToRewriteMap matchPredicateRewriteMap; * Register a MatchExpression rewrite if a condition is true at startup time. */ #define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, isEnabledExpr) \ - MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType)(InitializerContext*) { \ - \ - invariant(matchPredicateRewriteMap.find(MatchExpression::matchType) == \ - matchPredicateRewriteMap.end()); \ - matchPredicateRewriteMap[MatchExpression::matchType] = [&](auto* rewriter, auto* expr) { \ - if (isEnabledExpr) { \ - return rewriteClass{rewriter}.rewrite(expr); \ - } else { \ - return std::unique_ptr<MatchExpression>(nullptr); \ - } \ - }; \ + MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType##_##rewriteClass) \ + (InitializerContext*) { \ + if (matchPredicateRewriteMap.find(MatchExpression::matchType) == \ + matchPredicateRewriteMap.end()) { \ + matchPredicateRewriteMap[MatchExpression::matchType] = \ + std::vector<MatchRewriteFunction>(); \ + } \ + matchPredicateRewriteMap[MatchExpression::matchType].push_back( \ + [](auto* rewriter, auto* expr) { \ + if (isEnabledExpr) { \ + return rewriteClass{rewriter}.rewrite(expr); \ + } else { \ + return std::unique_ptr<MatchExpression>(nullptr); \ + } \ + }); \ }; /** * Register a MatchExpression rewrite unconditionally. diff --git a/src/mongo/db/query/fle/query_rewriter.cpp b/src/mongo/db/query/fle/query_rewriter.cpp index 441f436ec00..da4fb05ecc7 100644 --- a/src/mongo/db/query/fle/query_rewriter.cpp +++ b/src/mongo/db/query/fle/query_rewriter.cpp @@ -40,12 +40,15 @@ public: : queryRewriter(queryRewriter), exprRewrites(exprRewrites){}; std::unique_ptr<Expression> postVisit(Expression* exp) { - if (auto rewrite = exprRewrites.find(typeid(*exp)); rewrite != exprRewrites.end()) { - auto expr = rewrite->second(queryRewriter, exp); - if (expr != nullptr) { - didRewrite = true; + if (auto rewriteEntry = exprRewrites.find(typeid(*exp)); + rewriteEntry != exprRewrites.end()) { + for (auto& rewrite : rewriteEntry->second) { + auto expr = rewrite(queryRewriter, exp); + if (expr != nullptr) { + didRewrite = true; + return expr; + } } - return expr; } return nullptr; } @@ -109,13 +112,17 @@ std::unique_ptr<MatchExpression> QueryRewriter::_rewrite(MatchExpression* expr) return nullptr; } default: { - if (auto rewrite = _matchRewrites.find(expr->matchType()); - rewrite != _matchRewrites.end()) { - auto rewritten = rewrite->second(this, expr); - if (rewritten != nullptr) { - _rewroteLastExpression = true; + if (auto rewriteEntry = _matchRewrites.find(expr->matchType()); + rewriteEntry != _matchRewrites.end()) { + for (auto& rewrite : rewriteEntry->second) { + auto rewritten = rewrite(this, expr); + // Only one rewrite can be applied to an expression, so return as soon as a + // rewrite returns something other than nullptr. + if (rewritten != nullptr) { + _rewroteLastExpression = true; + return rewritten; + } } - return rewritten; } return nullptr; } diff --git a/src/mongo/db/query/fle/query_rewriter_test.cpp b/src/mongo/db/query/fle/query_rewriter_test.cpp index d779f6be1cc..057bae0dd2b 100644 --- a/src/mongo/db/query/fle/query_rewriter_test.cpp +++ b/src/mongo/db/query/fle/query_rewriter_test.cpp @@ -71,7 +71,7 @@ protected: if (!elt.isABSONObj()) { return false; } - return elt.Obj().firstElementFieldNameStringData() == "encrypt"_sd; + return elt.Obj().hasField("encrypt"_sd); } bool isPayload(const Value& v) const override { if (!v.isObject()) { @@ -97,9 +97,110 @@ protected: }; std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override { + auto eqMatch = dynamic_cast<ExpressionCompare*>(expr); + invariant(eqMatch); + // Only operate over equality comparisons. + if (eqMatch->getOp() != ExpressionCompare::EQ) { + return nullptr; + } + auto payload = dynamic_cast<ExpressionConstant*>(eqMatch->getOperandList()[1].get()); + // If the comparison doesn't hold a constant, then don't rewrite. + if (!payload) { + return nullptr; + } + + // If the constant is not considered a payload, then don't rewrite. + if (!isPayload(payload->getValue())) { + return nullptr; + } + auto cmp = std::make_unique<ExpressionCompare>(eqMatch->getExpressionContext(), + ExpressionCompare::GT); + cmp->addOperand(eqMatch->getOperandList()[0]); + cmp->addOperand( + ExpressionConstant::create(eqMatch->getExpressionContext(), + payload->getValue().getDocument().getField("encrypt"))); + return cmp; + } + + std::unique_ptr<MatchExpression> rewriteToRuntimeComparison( + MatchExpression* expr) const override { + return nullptr; + } + + std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override { return nullptr; } +private: + // This method is not used in mock implementations of the EncryptedPredicate since isPayload(), + // which normally calls encryptedBinDataType(), is overridden to look for plain objects rather + // than BinData. Since this method is pure virtual on the superclass and needs to be + // implemented, it is set to kPlaceholder (0). + EncryptedBinDataType encryptedBinDataType() const override { + return EncryptedBinDataType::kPlaceholder; + } +}; + +// A second mock rewrite which replaces documents with the key "foo" into $lt operations. We need +// two different rewrites that are registered on the same operator to verify that all rewrites are +// iterated through. +class OtherMockPredicateRewriter : public fle::EncryptedPredicate { +public: + OtherMockPredicateRewriter(const fle::QueryRewriterInterface* rewriter) + : EncryptedPredicate(rewriter) {} + +protected: + bool isPayload(const BSONElement& elt) const override { + if (!elt.isABSONObj()) { + return false; + } + return elt.Obj().hasField("foo"_sd); + } + bool isPayload(const Value& v) const override { + if (!v.isObject()) { + return false; + } + return !v.getDocument().getField("foo").missing(); + } + + std::vector<PrfBlock> generateTags(fle::BSONValue payload) const override { + return {}; + }; + + // Encrypted values will be rewritten from $eq to $lt. This is an arbitrary decision just to + // make sure that the rewrite works properly. + std::unique_ptr<MatchExpression> rewriteToTagDisjunction(MatchExpression* expr) const override { + invariant(expr->matchType() == MatchExpression::EQ); + auto eqMatch = static_cast<EqualityMatchExpression*>(expr); + if (!isPayload(eqMatch->getData())) { + return nullptr; + } + return std::make_unique<LTMatchExpression>(eqMatch->path(), + eqMatch->getData().Obj().firstElement()); + }; + + std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override { + auto eqMatch = dynamic_cast<ExpressionCompare*>(expr); + invariant(eqMatch); + if (eqMatch->getOp() != ExpressionCompare::EQ) { + return nullptr; + } + auto payload = dynamic_cast<ExpressionConstant*>(eqMatch->getOperandList()[1].get()); + if (!payload) { + return nullptr; + } + + if (!isPayload(payload->getValue())) { + return nullptr; + } + auto cmp = std::make_unique<ExpressionCompare>(eqMatch->getExpressionContext(), + ExpressionCompare::LT); + cmp->addOperand(eqMatch->getOperandList()[0]); + cmp->addOperand(ExpressionConstant::create( + eqMatch->getExpressionContext(), payload->getValue().getDocument().getField("foo"))); + return cmp; + } + std::unique_ptr<MatchExpression> rewriteToRuntimeComparison( MatchExpression* expr) const override { return nullptr; @@ -111,7 +212,7 @@ protected: private: EncryptedBinDataType encryptedBinDataType() const override { - return EncryptedBinDataType::kPlaceholder; // return the 0 type. this isn't used anywhere. + return EncryptedBinDataType::kPlaceholder; } }; @@ -119,8 +220,17 @@ void setMockRewriteMaps(fle::MatchTypeToRewriteMap& match, fle::ExpressionToRewriteMap& agg, fle::TagMap& tags, std::set<StringData>& encryptedFields) { - match[MatchExpression::EQ] = [&](auto* rewriter, auto* expr) { - return MockPredicateRewriter{rewriter}.rewrite(expr); + match[MatchExpression::EQ] = { + [&](auto* rewriter, auto* expr) { return MockPredicateRewriter{rewriter}.rewrite(expr); }, + [&](auto* rewriter, auto* expr) { + return OtherMockPredicateRewriter{rewriter}.rewrite(expr); + }, + }; + agg[typeid(ExpressionCompare)] = { + [&](auto* rewriter, auto* expr) { return MockPredicateRewriter{rewriter}.rewrite(expr); }, + [&](auto* rewriter, auto* expr) { + return OtherMockPredicateRewriter{rewriter}.rewrite(expr); + }, }; } @@ -137,9 +247,17 @@ public: return res ? res.value() : obj; } + BSONObj rewriteAggExpressionForTest(const BSONObj& obj) { + auto expr = Expression::parseExpression(&_expCtx, obj, _expCtx.variablesParseState); + auto result = rewriteExpression(expr.get()); + return result ? result->serialize(false).getDocument().toBson() + : expr->serialize(false).getDocument().toBson(); + } + private: fle::TagMap _tags; std::set<StringData> _encryptedFields; + ExpressionContextForTest _expCtx; }; class FLEServerRewriteTest : public unittest::Test { @@ -167,22 +285,53 @@ protected: ASSERT_MATCH_EXPRESSION_REWRITE(input, expected); \ } +#define ASSERT_AGG_EXPRESSION_REWRITE(input, expected) \ + auto actual = _mock->rewriteAggExpressionForTest(fromjson(input)); \ + ASSERT_BSONOBJ_EQ(actual, fromjson(expected)); + +#define TEST_FLE_REWRITE_AGG(name, input, expected) \ + TEST_F(FLEServerRewriteTest, name##_AggExpression) { \ + ASSERT_AGG_EXPRESSION_REWRITE(input, expected); \ + } + TEST_FLE_REWRITE_MATCH(TopLevel_DottedPath, "{'user.ssn': {$eq: {encrypt: 2}}}", "{'user.ssn': {$gt: 2}}"); +TEST_FLE_REWRITE_AGG(TopLevel_DottedPath, + "{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}", + "{$gt: ['$user.ssn', {$const: 2}]}"); + TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_BothEncrypted, "{$and: [{ssn: {encrypt: 2}}, {age: {encrypt: 4}}]}", "{$and: [{ssn: {$gt: 2}}, {age: {$gt: 4}}]}"); +TEST_FLE_REWRITE_AGG( + TopLevel_Conjunction_BothEncrypted, + "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {encrypt: " + "4}}]}]}", + "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$gt: ['$age', {$const: 4}]}]}"); + TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncrypted, "{$and: [{ssn: {encrypt: 2}}, {age: {plain: 4}}]}", "{$and: [{ssn: {$gt: 2}}, {age: {$eq: {plain: 4}}}]}"); +TEST_FLE_REWRITE_AGG( + TopLevel_Conjunction_PartlyEncrypted, + "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {plain: 4}}]}]}", + "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$eq: ['$age', {$const: {plain: 4}}]}]}"); + TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncryptedWithUnregisteredOperator, "{$and: [{ssn: {encrypt: 2}}, {age: {$lt: {encrypt: 4}}}]}", "{$and: [{ssn: {$gt: 2}}, {age: {$lt: {encrypt: 4}}}]}"); +TEST_FLE_REWRITE_AGG( + TopLevel_Conjunction_PartlyEncryptedWithUnregisteredOperator, + "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$lt: ['$age', {$const: {encrypt: " + "4}}]}]}", + "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$lt: ['$age', {$const: {encrypt: " + "4}}]}]}"); + TEST_FLE_REWRITE_MATCH(TopLevel_Encrypted_Nested_Unecrypted, "{$and: [{ssn: {encrypt: 2}}, {user: {region: 'US'}}]}", "{$and: [{ssn: {$gt: 2}}, {user: {$eq: {region: 'US'}}}]}"); @@ -191,6 +340,10 @@ TEST_FLE_REWRITE_MATCH(TopLevel_Not, "{ssn: {$not: {$eq: {encrypt: 5}}}}", "{ssn: {$not: {$gt: 5}}}"); +TEST_FLE_REWRITE_AGG(TopLevel_Not, + "{$not: [{$eq: ['$ssn', {$const: {encrypt: 2}}]}]}", + "{$not: [{$gt: ['$ssn', {$const: 2}]}]}") + TEST_FLE_REWRITE_MATCH(TopLevel_Neq, "{ssn: {$ne: {encrypt: 5}}}", "{ssn: {$not: {$gt: 5}}}}"); TEST_FLE_REWRITE_MATCH( @@ -198,6 +351,12 @@ TEST_FLE_REWRITE_MATCH( "{$and: [{$and: [{ssn: {encrypt: 2}}, {other: 'field'}]}, {otherSsn: {encrypt: 3}}]}", "{$and: [{$and: [{ssn: {$gt: 2}}, {other: {$eq: 'field'}}]}, {otherSsn: {$gt: 3}}]}"); +TEST_FLE_REWRITE_AGG(NestedConjunction, + "{$and: [{$and: [{$eq: ['$ssn', {$const: {encrypt: 2}}]},{$eq: ['$other', " + "'field']}]},{$eq: ['$age',{$const: {encrypt: 4}}]}]}", + "{$and: [{$and: [{$gt: ['$ssn', {$const: 2}]},{$eq: ['$other', " + "{$const: 'field'}]}]},{$gt: ['$age',{$const: 4}]}]}"); + TEST_FLE_REWRITE_MATCH(TopLevel_Nor, "{$nor: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}", "{$nor: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}"); @@ -205,5 +364,30 @@ TEST_FLE_REWRITE_MATCH(TopLevel_Nor, TEST_FLE_REWRITE_MATCH(TopLevel_Or, "{$or: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}", "{$or: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}"); + +TEST_FLE_REWRITE_AGG( + TopLevel_Or, + "{$or: [{$eq: ['$ssn', {$const: {encrypt: 2}}]}, {$eq: ['$ssn', {$const: {encrypt: 4}}]}]}", + "{$or: [{$gt: ['$ssn', {$const: 2}]}, {$gt: ['$ssn', {$const: 4}]}]}") + + +// Test that the rewriter will work from any rewrite registered to an expression. The test rewriter +// has two rewrites registered on $eq. + +TEST_FLE_REWRITE_MATCH(OtherRewrite_Basic, "{'ssn': {$eq: {foo: 2}}}", "{'ssn': {$lt: 2}}"); + +TEST_FLE_REWRITE_AGG(OtherRewrite_Basic, + "{$eq: ['$user.ssn', {$const: {foo: 2}}]}", + "{$lt: ['$user.ssn', {$const: 2}]}"); + +TEST_FLE_REWRITE_MATCH(OtherRewrite_Conjunction_BothEncrypted, + "{$and: [{ssn: {encrypt: 2}}, {age: {foo: 4}}]}", + "{$and: [{ssn: {$gt: 2}}, {age: {$lt: 4}}]}"); + +TEST_FLE_REWRITE_AGG( + OtherRewrite_Conjunction_BothEncrypted, + "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {foo: " + "4}}]}]}", + "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$lt: ['$age', {$const: 4}]}]}"); } // namespace } // namespace mongo diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp index 8c083b07dc1..dcd71777782 100644 --- a/src/mongo/db/query/fle/range_predicate.cpp +++ b/src/mongo/db/query/fle/range_predicate.cpp @@ -55,6 +55,77 @@ REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween, RangePredicate, gFeatureFlagFLE2Range); +REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionCompare, + RangePredicate, + gFeatureFlagFLE2Range); + +namespace { +// Validate the range operator passed in and return the fieldpath and payload for the rewrite. If +// the passed-in expression is a comparison with $eq, $ne, or $cmp, none of which represent a range +// predicate, then return null to the caller so that the rewrite can return null. +std::pair<boost::intrusive_ptr<Expression>, Value> validateRangeOp(Expression* expr) { + auto children = [&]() { + if (auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr)) { + return betweenExpr->getChildren(); + } else { + auto cmpExpr = dynamic_cast<ExpressionCompare*>(expr); + tassert(6720901, + "Range rewrite should only be called with $between or comparison operator.", + cmpExpr); + switch (cmpExpr->getOp()) { + case ExpressionCompare::GT: + case ExpressionCompare::GTE: + case ExpressionCompare::LT: + case ExpressionCompare::LTE: + return cmpExpr->getChildren(); + + case ExpressionCompare::EQ: + case ExpressionCompare::NE: + case ExpressionCompare::CMP: + return std::vector<boost::intrusive_ptr<Expression>>(); + } + } + return std::vector<boost::intrusive_ptr<Expression>>(); + }(); + if (children.empty()) { + return {nullptr, Value()}; + } + // Both ExpressionBetween and ExpressionCompare have a fixed arity of 2. + auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get()); + uassert(6720903, "first argument should be a fieldpath", fieldpath); + auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get()); + uassert(6720904, "second argument should be a constant", secondArg); + auto payload = secondArg->getValue(); + return {children[0], payload}; +} +} // namespace + +std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( + StringData path, ParsedFindRangePayload payload) const { + auto* expCtx = _rewriter->getExpressionContext(); + return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState), + payload); +} + +std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( + boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const { + tassert(7030501, + "$internalFleBetween can only be generated from a non-stub payload.", + !payload.isStub()); + auto cm = payload.maxCounter; + ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken); + std::vector<ConstDataRange> edcTokens; + std::transform(std::make_move_iterator(payload.edges.value().begin()), + std::make_move_iterator(payload.edges.value().end()), + std::back_inserter(edcTokens), + [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); }); + + auto* expCtx = _rewriter->getExpressionContext(); + return std::make_unique<ExpressionInternalFLEBetween>( + expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens)); +} + std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const { auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload); std::vector<PrfBlock> tags; @@ -99,56 +170,23 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction( return makeTagDisjunction(toBSONArray(generateTags(payload))); } -std::pair<boost::intrusive_ptr<Expression>, Value> validateBetween(Expression* expr) { - auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr); - tassert(6720901, "Range rewrite should only be called with $between operator.", betweenExpr); - auto children = betweenExpr->getChildren(); - uassert(6720902, "$between should have two children.", children.size() == 2); - - auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get()); - uassert(6720903, "first argument should be a fieldpath", fieldpath); - auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get()); - uassert(6720904, "second argument should be a constant", secondArg); - auto payload = secondArg->getValue(); - return {children[0], payload}; -} - std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const { - auto [_, payload] = validateBetween(expr); + auto [fieldpath, payload] = validateRangeOp(expr); + if (!fieldpath) { + return nullptr; + } if (!isPayload(payload)) { return nullptr; } + if (isStub(std::ref(payload))) { + return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true)); + } + auto tags = toValues(generateTags(std::ref(payload))); return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags)); } -std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( - StringData path, ParsedFindRangePayload payload) const { - auto* expCtx = _rewriter->getExpressionContext(); - return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString( - expCtx, path.toString(), expCtx->variablesParseState), - payload); -} - -std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload( - boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const { - tassert(7030501, - "$internalFleBetween can only be generated from a non-stub payload.", - !payload.isStub()); - auto cm = payload.maxCounter; - ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken); - std::vector<ConstDataRange> edcTokens; - std::transform(std::make_move_iterator(payload.edges.value().begin()), - std::make_move_iterator(payload.edges.value().end()), - std::back_inserter(edcTokens), - [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); }); - - auto* expCtx = _rewriter->getExpressionContext(); - return std::make_unique<ExpressionInternalFLEBetween>( - expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens)); -} - std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison( MatchExpression* expr) const { BSONElement ffp; @@ -179,11 +217,17 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison( } std::unique_ptr<Expression> RangePredicate::rewriteToRuntimeComparison(Expression* expr) const { - auto [fieldpath, ffp] = validateBetween(expr); + auto [fieldpath, ffp] = validateRangeOp(expr); + if (!fieldpath) { + return nullptr; + } if (!isPayload(ffp)) { return nullptr; } auto payload = parseFindPayload<ParsedFindRangePayload>(ffp); + if (payload.isStub()) { + return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true)); + } return fleBetweenFromPayload(fieldpath, payload); } } // namespace mongo::fle diff --git a/src/mongo/db/query/fle/range_predicate.h b/src/mongo/db/query/fle/range_predicate.h index bc2a47fa173..5fbd7484ad0 100644 --- a/src/mongo/db/query/fle/range_predicate.h +++ b/src/mongo/db/query/fle/range_predicate.h @@ -55,6 +55,12 @@ protected: return parsedPayload.isStub(); } + virtual bool isStub(Value elt) const { + auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(elt); + return parsedPayload.isStub(); + } + + private: EncryptedBinDataType encryptedBinDataType() const override { return EncryptedBinDataType::kFLE2FindRangePayload; diff --git a/src/mongo/db/query/fle/range_predicate_test.cpp b/src/mongo/db/query/fle/range_predicate_test.cpp index f81723da5eb..61110c19041 100644 --- a/src/mongo/db/query/fle/range_predicate_test.cpp +++ b/src/mongo/db/query/fle/range_predicate_test.cpp @@ -68,7 +68,11 @@ protected: return isStubPayload; } - std::vector<PrfBlock> generateTags(BSONValue payload) const { + bool isStub(Value elt) const override { + return isStubPayload; + } + + std::vector<PrfBlock> generateTags(BSONValue payload) const override { return stdx::visit( OverloadedVisitor{[&](BSONElement p) { if (p.isABSONObj()) { @@ -126,8 +130,6 @@ TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_NoStub) { TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_Stub) { RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - std::vector<PrfBlock> allTags = {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}}; - auto expCtx = make_intrusive<ExpressionContextForTest>(); std::vector<StringData> operators = {"$between", "$gt", "$gte", "$lte", "$lt"}; @@ -159,29 +161,98 @@ TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_Stub) { } } +TEST_F(RangePredicateRewriteTest, AggRangeRewrite_Stub) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); + + { + auto input = fromjson(str::stream() << "{$between: [\"$age\", {$literal: [1, 2, 3]}]}"); + auto inputExpr = + ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = ExpressionConstant::create(&_expCtx, Value(true)); + + _predicate.isStubPayload = true; + auto actual = _predicate.rewrite(inputExpr.get()); + ASSERT(actual); + ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), + expected->serialize(false).getDocument().toBson()); + } + + auto ops = {"$gt", "$lt", "$gte", "$lte"}; + for (auto& op : ops) { + auto input = fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}"); + auto inputExpr = + ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = ExpressionConstant::create(&_expCtx, Value(true)); + + _predicate.isStubPayload = true; + auto actual = _predicate.rewrite(inputExpr.get()); + ASSERT(actual); + ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), + expected->serialize(false).getDocument().toBson()); + } +} + TEST_F(RangePredicateRewriteTest, AggRangeRewrite) { - auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})"); - auto inputExpr = - ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + { + auto op = "$between"; + auto input = fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}"); + auto inputExpr = + ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}})); - auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}})); + auto actual = _predicate.rewrite(inputExpr.get()); - auto actual = _predicate.rewrite(inputExpr.get()); + ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), + expected->serialize(false).getDocument().toBson()); + } + { + auto ops = {"$gt", "$lt", "$gte", "$lte"}; + for (auto& op : ops) { + auto input = + fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}"); + auto inputExpr = + ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}})); - ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), - expected->serialize(false).getDocument().toBson()); + auto actual = _predicate.rewrite(inputExpr.get()); + + ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), + expected->serialize(false).getDocument().toBson()); + } + } } TEST_F(RangePredicateRewriteTest, AggRangeRewriteNoOp) { - auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})"); - auto inputExpr = - ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + { + auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})"); + auto inputExpr = + ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState); - auto expected = inputExpr; + auto expected = inputExpr; - _predicate.payloadValid = false; - auto actual = _predicate.rewrite(inputExpr.get()); - ASSERT(actual == nullptr); + _predicate.payloadValid = false; + auto actual = _predicate.rewrite(inputExpr.get()); + ASSERT(actual == nullptr); + } + { + auto ops = {"$gt", "$lt", "$gte", "$lte"}; + for (auto& op : ops) { + auto input = + fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}"); + auto inputExpr = + ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = inputExpr; + + _predicate.payloadValid = false; + auto actual = _predicate.rewrite(inputExpr.get()); + ASSERT(actual == nullptr); + } + } } BSONObj generateFFP(StringData path, int lb, int ub, int min, int max) { @@ -219,6 +290,17 @@ std::unique_ptr<Expression> generateBetweenWithFFP(ExpressionContext* expCtx, return std::make_unique<ExpressionBetween>(expCtx, std::move(children)); } +std::unique_ptr<Expression> generateBetweenWithFFP( + ExpressionContext* expCtx, ExpressionCompare::CmpOp op, StringData path, int lb, int ub) { + auto ffp = Value(generateFFP(path, lb, ub, 0, 255).firstElement()); + auto ffpExpr = make_intrusive<ExpressionConstant>(expCtx, ffp); + auto fieldpath = ExpressionFieldPath::createPathFromString( + expCtx, path.toString(), expCtx->variablesParseState); + std::vector<boost::intrusive_ptr<Expression>> children = {std::move(fieldpath), + std::move(ffpExpr)}; + return std::make_unique<ExpressionCompare>(expCtx, op, std::move(children)); +} + TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) { _mock.setForceEncryptedCollScanForTest(); auto expected = fromjson(R"({ @@ -273,9 +355,6 @@ TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) { TEST_F(RangePredicateRewriteTest, CollScanRewriteAgg) { _mock.setForceEncryptedCollScanForTest(); - auto input = generateBetweenWithFFP(&_expCtx, "age", 23, 35); - auto result = _predicate.rewrite(input.get()); - ASSERT(result); auto expected = fromjson(R"({ "$_internalFleBetween": { "field": "$age", @@ -309,7 +388,40 @@ TEST_F(RangePredicateRewriteTest, CollScanRewriteAgg) { } } })"); - ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected); + { + auto input = generateBetweenWithFFP(&_expCtx, "age", 23, 35); + auto result = _predicate.rewrite(input.get()); + ASSERT(result); + ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected); + } + { + auto ops = {ExpressionCompare::GT, + ExpressionCompare::GTE, + ExpressionCompare::LT, + ExpressionCompare::LTE}; + for (auto& op : ops) { + auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35); + auto result = _predicate.rewrite(input.get()); + ASSERT(result); + ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected); + } + } +} + + +TEST_F(RangePredicateRewriteTest, UnsupportedComparisonOps) { + auto ops = {ExpressionCompare::CMP, ExpressionCompare::EQ, ExpressionCompare::NE}; + for (auto& op : ops) { + auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35); + auto result = _predicate.rewrite(input.get()); + ASSERT(result == nullptr); + } + _mock.setForceEncryptedCollScanForTest(); + for (auto& op : ops) { + auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35); + auto result = _predicate.rewrite(input.get()); + ASSERT(result == nullptr); + } } }; // namespace |