diff options
author | Davis Haupt <davis.haupt@mongodb.com> | 2022-10-03 17:27:11 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-10-03 19:15:24 +0000 |
commit | 46ad1bc5186a6cc4d449d077b12951a400595f4e (patch) | |
tree | faa45e1bb8c5ae93363b4334941b62c987ec4e99 /src/mongo/db | |
parent | 801e5203e0efcb62ffcef66e22da95d645b2dca2 (diff) | |
download | mongo-46ad1bc5186a6cc4d449d077b12951a400595f4e.tar.gz |
SERVER-67209 Server-side rewrite for aggregation to tag disjunction
Diffstat (limited to 'src/mongo/db')
-rw-r--r-- | src/mongo/db/query/fle/encrypted_predicate.cpp | 15 | ||||
-rw-r--r-- | src/mongo/db/query/fle/encrypted_predicate.h | 3 | ||||
-rw-r--r-- | src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h | 1 | ||||
-rw-r--r-- | src/mongo/db/query/fle/equality_predicate.cpp | 34 | ||||
-rw-r--r-- | src/mongo/db/query/fle/range_predicate.cpp | 27 | ||||
-rw-r--r-- | src/mongo/db/query/fle/range_predicate_test.cpp | 92 |
6 files changed, 117 insertions, 55 deletions
diff --git a/src/mongo/db/query/fle/encrypted_predicate.cpp b/src/mongo/db/query/fle/encrypted_predicate.cpp index aeed43cc884..74d2ef8fbf4 100644 --- a/src/mongo/db/query/fle/encrypted_predicate.cpp +++ b/src/mongo/db/query/fle/encrypted_predicate.cpp @@ -72,5 +72,20 @@ std::vector<Value> toValues(std::vector<PrfBlock>&& vec) { } return output; } + +std::unique_ptr<Expression> makeTagDisjunction(ExpressionContext* expCtx, + std::vector<Value>&& tags) { + std::vector<boost::intrusive_ptr<Expression>> orListElems; + for (auto&& tagElt : tags) { + // ... and for each tag, construct expression {$in: [tag, + // "$__safeContent__"]}. + std::vector<boost::intrusive_ptr<Expression>> inVec{ + ExpressionConstant::create(expCtx, tagElt), + ExpressionFieldPath::createPathFromString( + expCtx, kSafeContent, expCtx->variablesParseState)}; + orListElems.push_back(make_intrusive<ExpressionIn>(expCtx, std::move(inVec))); + } + return std::make_unique<ExpressionOr>(expCtx, std::move(orListElems)); +} } // namespace fle } // namespace mongo diff --git a/src/mongo/db/query/fle/encrypted_predicate.h b/src/mongo/db/query/fle/encrypted_predicate.h index d6c379f6350..5fcac966dbf 100644 --- a/src/mongo/db/query/fle/encrypted_predicate.h +++ b/src/mongo/db/query/fle/encrypted_predicate.h @@ -72,6 +72,9 @@ T parseFindPayload(BSONValue payload) { payload); } +std::unique_ptr<Expression> makeTagDisjunction(ExpressionContext* expCtx, + std::vector<Value>&& tags); + /** * Convert a vector of PrfBlocks to a BSONArray for use in MatchExpression tag generation. */ diff --git a/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h index 88a6c85983a..5c06e89a0da 100644 --- a/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h +++ b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h @@ -102,5 +102,6 @@ public: protected: MockServerRewrite _mock{}; + ExpressionContextForTest _expCtx; }; } // namespace mongo::fle diff --git a/src/mongo/db/query/fle/equality_predicate.cpp b/src/mongo/db/query/fle/equality_predicate.cpp index a349962145b..88396949ee8 100644 --- a/src/mongo/db/query/fle/equality_predicate.cpp +++ b/src/mongo/db/query/fle/equality_predicate.cpp @@ -286,20 +286,8 @@ std::unique_ptr<Expression> EqualityPredicate::rewriteToTagDisjunction(Expressio std::vector<boost::intrusive_ptr<Expression>> orListElems; auto payload = constChild->getValue(); auto tags = toValues(generateTags(std::ref(payload))); - for (auto&& tagElt : tags) { - // ... and for each tag, construct expression {$in: [tag, - // "$__safeContent__"]}. - std::vector<boost::intrusive_ptr<Expression>> inVec{ - ExpressionConstant::create(_rewriter->getExpressionContext(), tagElt), - ExpressionFieldPath::createPathFromString( - _rewriter->getExpressionContext(), - kSafeContent, - _rewriter->getExpressionContext()->variablesParseState)}; - orListElems.push_back( - make_intrusive<ExpressionIn>(_rewriter->getExpressionContext(), std::move(inVec))); - } - auto disjunction = std::make_unique<ExpressionOr>(_rewriter->getExpressionContext(), - std::move(orListElems)); + auto disjunction = makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags)); + if (eqExpr->getOp() == ExpressionCompare::NE) { std::vector<boost::intrusive_ptr<Expression>> notChild{disjunction.release()}; return std::make_unique<ExpressionNot>(_rewriter->getExpressionContext(), @@ -312,27 +300,19 @@ std::unique_ptr<Expression> EqualityPredicate::rewriteToTagDisjunction(Expressio return nullptr; } auto& equalitiesList = inList->getChildren(); - std::vector<boost::intrusive_ptr<Expression>> orListElems; - auto expCtx = _rewriter->getExpressionContext(); + std::vector<Value> allTags; for (auto& equality : equalitiesList) { // For each expression representing a FleFindPayload... if (auto constChild = dynamic_cast<ExpressionConstant*>(equality.get())) { // ... rewrite the payload to a list of tags... auto payload = constChild->getValue(); auto tags = toValues(generateTags(std::ref(payload))); - for (auto&& tagElt : tags) { - // ... and for each tag, construct expression {$in: [tag, - // "$__safeContent__"]}. - std::vector<boost::intrusive_ptr<Expression>> inVec{ - ExpressionConstant::create(expCtx, tagElt), - ExpressionFieldPath::createPathFromString( - expCtx, kSafeContent, expCtx->variablesParseState)}; - orListElems.push_back( - make_intrusive<ExpressionIn>(expCtx, std::move(inVec))); - } + allTags.insert(allTags.end(), + std::make_move_iterator(tags.begin()), + std::make_move_iterator(tags.end())); } } - return std::make_unique<ExpressionOr>(expCtx, std::move(orListElems)); + return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(allTags)); } return nullptr; } diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp index 0ebf33790fc..184f4f4baa2 100644 --- a/src/mongo/db/query/fle/range_predicate.cpp +++ b/src/mongo/db/query/fle/range_predicate.cpp @@ -35,6 +35,7 @@ #include "mongo/crypto/fle_crypto.h" #include "mongo/crypto/fle_tags.h" #include "mongo/db/matcher/expression_leaf.h" +#include "mongo/db/pipeline/expression.h" #include "mongo/db/query/fle/encrypted_predicate.h" namespace mongo::fle { @@ -42,6 +43,9 @@ namespace mongo::fle { REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(BETWEEN, RangePredicate, gFeatureFlagFLE2Range); +REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween, + RangePredicate, + gFeatureFlagFLE2Range); std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const { auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload); @@ -62,7 +66,9 @@ std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const { std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction( MatchExpression* expr) const { - invariant(expr->matchType() == MatchExpression::BETWEEN); + tassert(6720900, + "Range rewrite should only be called with $between operator.", + expr->matchType() == MatchExpression::BETWEEN); auto betExpr = static_cast<BetweenMatchExpression*>(expr); auto payload = betExpr->rhs(); @@ -72,9 +78,24 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction( return makeTagDisjunction(toBSONArray(generateTags(payload))); } -// TODO: SERVER-67209 Server-side rewrite for agg expressions with $between. std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const { - return nullptr; + auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr); + tassert(6720901, "Range rewrite should only be called with $between operator.", betweenExpr); + auto children = betweenExpr->getChildren(); + uassert(6720902, "$between should have two children.", children.size() == 2); + + auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get()); + uassert(6720903, "first argument should be a fieldpath", fieldpath); + auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get()); + uassert(6720904, "second argument should be a constant", secondArg); + auto payload = secondArg->getValue(); + + if (!isPayload(payload)) { + return nullptr; + } + auto tags = toValues(generateTags(std::ref(payload))); + + return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags)); } // TODO: SERVER-67267 Rewrite $between to $_internalFleBetween when number of tags exceeds diff --git a/src/mongo/db/query/fle/range_predicate_test.cpp b/src/mongo/db/query/fle/range_predicate_test.cpp index bc8bca184f3..17f33b55e52 100644 --- a/src/mongo/db/query/fle/range_predicate_test.cpp +++ b/src/mongo/db/query/fle/range_predicate_test.cpp @@ -29,6 +29,9 @@ #include "mongo/crypto/fle_crypto.h" #include "mongo/db/matcher/expression_leaf.h" +#include "mongo/db/pipeline/expression.h" +#include "mongo/db/pipeline/expression_context_for_test.h" +#include "mongo/db/query/fle/encrypted_predicate.h" #include "mongo/db/query/fle/encrypted_predicate_test_fixtures.h" #include "mongo/db/query/fle/range_predicate.h" #include "mongo/idl/server_parameter_test_util.h" @@ -51,39 +54,52 @@ public: } + bool payloadValid = true; + protected: bool isPayload(const BSONElement& elt) const override { - return true; + return payloadValid; } bool isPayload(const Value& v) const override { - return true; + return payloadValid; } std::vector<PrfBlock> generateTags(BSONValue payload) const { return stdx::visit( - OverloadedVisitor{ - [&](BSONElement p) { - auto parsedPayload = p.Obj().firstElement(); - auto fieldName = parsedPayload.fieldNameStringData(); - - std::vector<BSONElement> range; - auto payloadAsArray = parsedPayload.Array(); - for (auto&& elt : payloadAsArray) { - range.push_back(elt); - } - - std::vector<PrfBlock> allTags; - for (auto i = range[0].Number(); i <= range[1].Number(); i++) { - ASSERT(_tags.find({fieldName, i}) != _tags.end()); - auto temp = _tags.find({fieldName, i})->second; - for (auto tag : temp) { - allTags.push_back(tag); - } - } - return allTags; - }, - [&](std::reference_wrapper<Value> v) { return std::vector<PrfBlock>{}; }}, + OverloadedVisitor{[&](BSONElement p) { + auto parsedPayload = p.Obj().firstElement(); + auto fieldName = parsedPayload.fieldNameStringData(); + + std::vector<BSONElement> range; + auto payloadAsArray = parsedPayload.Array(); + for (auto&& elt : payloadAsArray) { + range.push_back(elt); + } + + std::vector<PrfBlock> allTags; + for (auto i = range[0].Number(); i <= range[1].Number(); i++) { + ASSERT(_tags.find({fieldName, i}) != _tags.end()); + auto temp = _tags.find({fieldName, i})->second; + for (auto tag : temp) { + allTags.push_back(tag); + } + } + return allTags; + }, + [&](std::reference_wrapper<Value> v) { + if (v.get().isArray()) { + auto arr = v.get().getArray(); + std::vector<PrfBlock> allTags; + for (auto& val : arr) { + allTags.push_back(PrfBlock( + {static_cast<unsigned char>(val.coerceToInt())})); + } + return allTags; + } else { + return std::vector<PrfBlock>{}; + } + }}, payload); } @@ -99,7 +115,7 @@ protected: MockRangePredicate _predicate; }; -TEST_F(RangePredicateRewriteTest, BasicRangeRewrite) { +TEST_F(RangePredicateRewriteTest, MatchRangeRewrite) { RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); int start = 1; @@ -126,5 +142,31 @@ TEST_F(RangePredicateRewriteTest, BasicRangeRewrite) { assertRewriteToTags(_predicate, &inputExpr, toBSONArray(std::move(allTags))); } + +TEST_F(RangePredicateRewriteTest, AggRangeRewrite) { + auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})"); + auto inputExpr = + ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}})); + + auto actual = _predicate.rewrite(inputExpr.get()); + + ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(), + expected->serialize(false).getDocument().toBson()); +} + +TEST_F(RangePredicateRewriteTest, AggRangeRewriteNoOp) { + auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})"); + auto inputExpr = + ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState); + + auto expected = inputExpr; + + _predicate.payloadValid = false; + auto actual = _predicate.rewrite(inputExpr.get()); + ASSERT(actual == nullptr); +} + }; // namespace } // namespace mongo::fle |