summaryrefslogtreecommitdiff
path: root/src/mongo/db/query/fle/range_predicate.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/query/fle/range_predicate.cpp')
-rw-r--r--src/mongo/db/query/fle/range_predicate.cpp128
1 files changed, 86 insertions, 42 deletions
diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp
index 8c083b07dc1..dcd71777782 100644
--- a/src/mongo/db/query/fle/range_predicate.cpp
+++ b/src/mongo/db/query/fle/range_predicate.cpp
@@ -55,6 +55,77 @@ REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween,
RangePredicate,
gFeatureFlagFLE2Range);
+REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionCompare,
+ RangePredicate,
+ gFeatureFlagFLE2Range);
+
+namespace {
+// Validate the range operator passed in and return the fieldpath and payload for the rewrite. If
+// the passed-in expression is a comparison with $eq, $ne, or $cmp, none of which represent a range
+// predicate, then return null to the caller so that the rewrite can return null.
+std::pair<boost::intrusive_ptr<Expression>, Value> validateRangeOp(Expression* expr) {
+ auto children = [&]() {
+ if (auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr)) {
+ return betweenExpr->getChildren();
+ } else {
+ auto cmpExpr = dynamic_cast<ExpressionCompare*>(expr);
+ tassert(6720901,
+ "Range rewrite should only be called with $between or comparison operator.",
+ cmpExpr);
+ switch (cmpExpr->getOp()) {
+ case ExpressionCompare::GT:
+ case ExpressionCompare::GTE:
+ case ExpressionCompare::LT:
+ case ExpressionCompare::LTE:
+ return cmpExpr->getChildren();
+
+ case ExpressionCompare::EQ:
+ case ExpressionCompare::NE:
+ case ExpressionCompare::CMP:
+ return std::vector<boost::intrusive_ptr<Expression>>();
+ }
+ }
+ return std::vector<boost::intrusive_ptr<Expression>>();
+ }();
+ if (children.empty()) {
+ return {nullptr, Value()};
+ }
+ // Both ExpressionBetween and ExpressionCompare have a fixed arity of 2.
+ auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get());
+ uassert(6720903, "first argument should be a fieldpath", fieldpath);
+ auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get());
+ uassert(6720904, "second argument should be a constant", secondArg);
+ auto payload = secondArg->getValue();
+ return {children[0], payload};
+}
+} // namespace
+
+std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
+ StringData path, ParsedFindRangePayload payload) const {
+ auto* expCtx = _rewriter->getExpressionContext();
+ return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState),
+ payload);
+}
+
+std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
+ boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const {
+ tassert(7030501,
+ "$internalFleBetween can only be generated from a non-stub payload.",
+ !payload.isStub());
+ auto cm = payload.maxCounter;
+ ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken);
+ std::vector<ConstDataRange> edcTokens;
+ std::transform(std::make_move_iterator(payload.edges.value().begin()),
+ std::make_move_iterator(payload.edges.value().end()),
+ std::back_inserter(edcTokens),
+ [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); });
+
+ auto* expCtx = _rewriter->getExpressionContext();
+ return std::make_unique<ExpressionInternalFLEBetween>(
+ expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens));
+}
+
std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const {
auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload);
std::vector<PrfBlock> tags;
@@ -99,56 +170,23 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction(
return makeTagDisjunction(toBSONArray(generateTags(payload)));
}
-std::pair<boost::intrusive_ptr<Expression>, Value> validateBetween(Expression* expr) {
- auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr);
- tassert(6720901, "Range rewrite should only be called with $between operator.", betweenExpr);
- auto children = betweenExpr->getChildren();
- uassert(6720902, "$between should have two children.", children.size() == 2);
-
- auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get());
- uassert(6720903, "first argument should be a fieldpath", fieldpath);
- auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get());
- uassert(6720904, "second argument should be a constant", secondArg);
- auto payload = secondArg->getValue();
- return {children[0], payload};
-}
-
std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const {
- auto [_, payload] = validateBetween(expr);
+ auto [fieldpath, payload] = validateRangeOp(expr);
+ if (!fieldpath) {
+ return nullptr;
+ }
if (!isPayload(payload)) {
return nullptr;
}
+ if (isStub(std::ref(payload))) {
+ return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true));
+ }
+
auto tags = toValues(generateTags(std::ref(payload)));
return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags));
}
-std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
- StringData path, ParsedFindRangePayload payload) const {
- auto* expCtx = _rewriter->getExpressionContext();
- return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString(
- expCtx, path.toString(), expCtx->variablesParseState),
- payload);
-}
-
-std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
- boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const {
- tassert(7030501,
- "$internalFleBetween can only be generated from a non-stub payload.",
- !payload.isStub());
- auto cm = payload.maxCounter;
- ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken);
- std::vector<ConstDataRange> edcTokens;
- std::transform(std::make_move_iterator(payload.edges.value().begin()),
- std::make_move_iterator(payload.edges.value().end()),
- std::back_inserter(edcTokens),
- [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); });
-
- auto* expCtx = _rewriter->getExpressionContext();
- return std::make_unique<ExpressionInternalFLEBetween>(
- expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens));
-}
-
std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison(
MatchExpression* expr) const {
BSONElement ffp;
@@ -179,11 +217,17 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison(
}
std::unique_ptr<Expression> RangePredicate::rewriteToRuntimeComparison(Expression* expr) const {
- auto [fieldpath, ffp] = validateBetween(expr);
+ auto [fieldpath, ffp] = validateRangeOp(expr);
+ if (!fieldpath) {
+ return nullptr;
+ }
if (!isPayload(ffp)) {
return nullptr;
}
auto payload = parseFindPayload<ParsedFindRangePayload>(ffp);
+ if (payload.isStub()) {
+ return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true));
+ }
return fleBetweenFromPayload(fieldpath, payload);
}
} // namespace mongo::fle