diff options
Diffstat (limited to 'src/mongo/db/query/fle/range_predicate.cpp')
-rw-r--r-- | src/mongo/db/query/fle/range_predicate.cpp | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp index 0ebf33790fc..184f4f4baa2 100644 --- a/src/mongo/db/query/fle/range_predicate.cpp +++ b/src/mongo/db/query/fle/range_predicate.cpp @@ -35,6 +35,7 @@ #include "mongo/crypto/fle_crypto.h" #include "mongo/crypto/fle_tags.h" #include "mongo/db/matcher/expression_leaf.h" +#include "mongo/db/pipeline/expression.h" #include "mongo/db/query/fle/encrypted_predicate.h" namespace mongo::fle { @@ -42,6 +43,9 @@ namespace mongo::fle { REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(BETWEEN, RangePredicate, gFeatureFlagFLE2Range); +REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween, + RangePredicate, + gFeatureFlagFLE2Range); std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const { auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload); @@ -62,7 +66,9 @@ std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const { std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction( MatchExpression* expr) const { - invariant(expr->matchType() == MatchExpression::BETWEEN); + tassert(6720900, + "Range rewrite should only be called with $between operator.", + expr->matchType() == MatchExpression::BETWEEN); auto betExpr = static_cast<BetweenMatchExpression*>(expr); auto payload = betExpr->rhs(); @@ -72,9 +78,24 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction( return makeTagDisjunction(toBSONArray(generateTags(payload))); } -// TODO: SERVER-67209 Server-side rewrite for agg expressions with $between. std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const { - return nullptr; + auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr); + tassert(6720901, "Range rewrite should only be called with $between operator.", betweenExpr); + auto children = betweenExpr->getChildren(); + uassert(6720902, "$between should have two children.", children.size() == 2); + + auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get()); + uassert(6720903, "first argument should be a fieldpath", fieldpath); + auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get()); + uassert(6720904, "second argument should be a constant", secondArg); + auto payload = secondArg->getValue(); + + if (!isPayload(payload)) { + return nullptr; + } + auto tags = toValues(generateTags(std::ref(payload))); + + return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags)); } // TODO: SERVER-67267 Rewrite $between to $_internalFleBetween when number of tags exceeds |