summaryrefslogtreecommitdiff
path: root/src/mongo/db/query/fle/range_predicate.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/query/fle/range_predicate.cpp')
-rw-r--r--src/mongo/db/query/fle/range_predicate.cpp27
1 files changed, 24 insertions, 3 deletions
diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp
index 0ebf33790fc..184f4f4baa2 100644
--- a/src/mongo/db/query/fle/range_predicate.cpp
+++ b/src/mongo/db/query/fle/range_predicate.cpp
@@ -35,6 +35,7 @@
#include "mongo/crypto/fle_crypto.h"
#include "mongo/crypto/fle_tags.h"
#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/pipeline/expression.h"
#include "mongo/db/query/fle/encrypted_predicate.h"
namespace mongo::fle {
@@ -42,6 +43,9 @@ namespace mongo::fle {
REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(BETWEEN,
RangePredicate,
gFeatureFlagFLE2Range);
+REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween,
+ RangePredicate,
+ gFeatureFlagFLE2Range);
std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const {
auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload);
@@ -62,7 +66,9 @@ std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const {
std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction(
MatchExpression* expr) const {
- invariant(expr->matchType() == MatchExpression::BETWEEN);
+ tassert(6720900,
+ "Range rewrite should only be called with $between operator.",
+ expr->matchType() == MatchExpression::BETWEEN);
auto betExpr = static_cast<BetweenMatchExpression*>(expr);
auto payload = betExpr->rhs();
@@ -72,9 +78,24 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction(
return makeTagDisjunction(toBSONArray(generateTags(payload)));
}
-// TODO: SERVER-67209 Server-side rewrite for agg expressions with $between.
std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const {
- return nullptr;
+ auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr);
+ tassert(6720901, "Range rewrite should only be called with $between operator.", betweenExpr);
+ auto children = betweenExpr->getChildren();
+ uassert(6720902, "$between should have two children.", children.size() == 2);
+
+ auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get());
+ uassert(6720903, "first argument should be a fieldpath", fieldpath);
+ auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get());
+ uassert(6720904, "second argument should be a constant", secondArg);
+ auto payload = secondArg->getValue();
+
+ if (!isPayload(payload)) {
+ return nullptr;
+ }
+ auto tags = toValues(generateTags(std::ref(payload)));
+
+ return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags));
}
// TODO: SERVER-67267 Rewrite $between to $_internalFleBetween when number of tags exceeds