summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavis Haupt <davis.haupt@mongodb.com>2022-11-01 19:35:57 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-11-01 20:45:01 +0000
commit464abab59fa63b45d187a4145dceec3281bbcf0d (patch)
treed059d053023fa0c95f301a63a87d03ac26d2c550
parent5d04b5969c880cc7b7acba50490dc762c0964f97 (diff)
downloadmongo-464abab59fa63b45d187a4145dceec3281bbcf0d.tar.gz
SERVER-70306 support aggregate comparison operators in server-side encrypted range rewrite
-rw-r--r--src/mongo/db/query/fle/encrypted_predicate.h67
-rw-r--r--src/mongo/db/query/fle/query_rewriter.cpp29
-rw-r--r--src/mongo/db/query/fle/query_rewriter_test.cpp192
-rw-r--r--src/mongo/db/query/fle/range_predicate.cpp128
-rw-r--r--src/mongo/db/query/fle/range_predicate.h6
-rw-r--r--src/mongo/db/query/fle/range_predicate_test.cpp154
6 files changed, 469 insertions, 107 deletions
diff --git a/src/mongo/db/query/fle/encrypted_predicate.h b/src/mongo/db/query/fle/encrypted_predicate.h
index 5fcac966dbf..fc78aa28014 100644
--- a/src/mongo/db/query/fle/encrypted_predicate.h
+++ b/src/mongo/db/query/fle/encrypted_predicate.h
@@ -40,6 +40,7 @@
#include "mongo/db/matcher/expression_leaf.h"
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/query/fle/query_rewriter_interface.h"
+#include "mongo/stdx/unordered_map.h"
/**
* This file contains an abstract class that describes rewrites on agg Expressions and
@@ -194,33 +195,37 @@ private:
* are keyed on the dynamic type for the Expression subclass.
*/
-using ExpressionToRewriteMap = stdx::unordered_map<
- std::type_index,
- std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>>;
+using ExpressionRewriteFunction =
+ std::function<std::unique_ptr<Expression>(QueryRewriterInterface*, Expression*)>;
+using ExpressionToRewriteMap =
+ stdx::unordered_map<std::type_index, std::vector<ExpressionRewriteFunction>>;
extern ExpressionToRewriteMap aggPredicateRewriteMap;
-using MatchTypeToRewriteMap = stdx::unordered_map<
- MatchExpression::MatchType,
- std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>>;
+using MatchRewriteFunction =
+ std::function<std::unique_ptr<MatchExpression>(QueryRewriterInterface*, MatchExpression*)>;
+using MatchTypeToRewriteMap =
+ stdx::unordered_map<MatchExpression::MatchType, std::vector<MatchRewriteFunction>>;
extern MatchTypeToRewriteMap matchPredicateRewriteMap;
/**
* Register an agg rewrite if a condition is true at startup time.
*/
-#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \
- MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className)(InitializerContext*) { \
- \
- invariant(aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()); \
- aggPredicateRewriteMap[typeid(className)] = [&](auto* rewriter, auto* expr) { \
- if (isEnabledExpr) { \
- return rewriteClass{rewriter}.rewrite(expr); \
- } else { \
- return std::unique_ptr<Expression>(nullptr); \
- } \
- }; \
- }
+#define REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_GUARDED(className, rewriteClass, isEnabledExpr) \
+ MONGO_INITIALIZER(encryptedAggPredicateRewriteFor_##className##_##rewriteClass) \
+ (InitializerContext*) { \
+ if (aggPredicateRewriteMap.find(typeid(className)) == aggPredicateRewriteMap.end()) { \
+ aggPredicateRewriteMap[typeid(className)] = std::vector<ExpressionRewriteFunction>(); \
+ } \
+ aggPredicateRewriteMap[typeid(className)].push_back([](auto* rewriter, auto* expr) { \
+ if (isEnabledExpr) { \
+ return rewriteClass{rewriter}.rewrite(expr); \
+ } else { \
+ return std::unique_ptr<Expression>(nullptr); \
+ } \
+ }); \
+ };
/**
* Register an agg rewrite unconditionally.
@@ -239,17 +244,21 @@ extern MatchTypeToRewriteMap matchPredicateRewriteMap;
* Register a MatchExpression rewrite if a condition is true at startup time.
*/
#define REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_GUARDED(matchType, rewriteClass, isEnabledExpr) \
- MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType)(InitializerContext*) { \
- \
- invariant(matchPredicateRewriteMap.find(MatchExpression::matchType) == \
- matchPredicateRewriteMap.end()); \
- matchPredicateRewriteMap[MatchExpression::matchType] = [&](auto* rewriter, auto* expr) { \
- if (isEnabledExpr) { \
- return rewriteClass{rewriter}.rewrite(expr); \
- } else { \
- return std::unique_ptr<MatchExpression>(nullptr); \
- } \
- }; \
+ MONGO_INITIALIZER(encryptedMatchPredicateRewriteFor_##matchType##_##rewriteClass) \
+ (InitializerContext*) { \
+ if (matchPredicateRewriteMap.find(MatchExpression::matchType) == \
+ matchPredicateRewriteMap.end()) { \
+ matchPredicateRewriteMap[MatchExpression::matchType] = \
+ std::vector<MatchRewriteFunction>(); \
+ } \
+ matchPredicateRewriteMap[MatchExpression::matchType].push_back( \
+ [](auto* rewriter, auto* expr) { \
+ if (isEnabledExpr) { \
+ return rewriteClass{rewriter}.rewrite(expr); \
+ } else { \
+ return std::unique_ptr<MatchExpression>(nullptr); \
+ } \
+ }); \
};
/**
* Register a MatchExpression rewrite unconditionally.
diff --git a/src/mongo/db/query/fle/query_rewriter.cpp b/src/mongo/db/query/fle/query_rewriter.cpp
index 441f436ec00..da4fb05ecc7 100644
--- a/src/mongo/db/query/fle/query_rewriter.cpp
+++ b/src/mongo/db/query/fle/query_rewriter.cpp
@@ -40,12 +40,15 @@ public:
: queryRewriter(queryRewriter), exprRewrites(exprRewrites){};
std::unique_ptr<Expression> postVisit(Expression* exp) {
- if (auto rewrite = exprRewrites.find(typeid(*exp)); rewrite != exprRewrites.end()) {
- auto expr = rewrite->second(queryRewriter, exp);
- if (expr != nullptr) {
- didRewrite = true;
+ if (auto rewriteEntry = exprRewrites.find(typeid(*exp));
+ rewriteEntry != exprRewrites.end()) {
+ for (auto& rewrite : rewriteEntry->second) {
+ auto expr = rewrite(queryRewriter, exp);
+ if (expr != nullptr) {
+ didRewrite = true;
+ return expr;
+ }
}
- return expr;
}
return nullptr;
}
@@ -109,13 +112,17 @@ std::unique_ptr<MatchExpression> QueryRewriter::_rewrite(MatchExpression* expr)
return nullptr;
}
default: {
- if (auto rewrite = _matchRewrites.find(expr->matchType());
- rewrite != _matchRewrites.end()) {
- auto rewritten = rewrite->second(this, expr);
- if (rewritten != nullptr) {
- _rewroteLastExpression = true;
+ if (auto rewriteEntry = _matchRewrites.find(expr->matchType());
+ rewriteEntry != _matchRewrites.end()) {
+ for (auto& rewrite : rewriteEntry->second) {
+ auto rewritten = rewrite(this, expr);
+ // Only one rewrite can be applied to an expression, so return as soon as a
+ // rewrite returns something other than nullptr.
+ if (rewritten != nullptr) {
+ _rewroteLastExpression = true;
+ return rewritten;
+ }
}
- return rewritten;
}
return nullptr;
}
diff --git a/src/mongo/db/query/fle/query_rewriter_test.cpp b/src/mongo/db/query/fle/query_rewriter_test.cpp
index d779f6be1cc..057bae0dd2b 100644
--- a/src/mongo/db/query/fle/query_rewriter_test.cpp
+++ b/src/mongo/db/query/fle/query_rewriter_test.cpp
@@ -71,7 +71,7 @@ protected:
if (!elt.isABSONObj()) {
return false;
}
- return elt.Obj().firstElementFieldNameStringData() == "encrypt"_sd;
+ return elt.Obj().hasField("encrypt"_sd);
}
bool isPayload(const Value& v) const override {
if (!v.isObject()) {
@@ -97,9 +97,110 @@ protected:
};
std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override {
+ auto eqMatch = dynamic_cast<ExpressionCompare*>(expr);
+ invariant(eqMatch);
+ // Only operate over equality comparisons.
+ if (eqMatch->getOp() != ExpressionCompare::EQ) {
+ return nullptr;
+ }
+ auto payload = dynamic_cast<ExpressionConstant*>(eqMatch->getOperandList()[1].get());
+ // If the comparison doesn't hold a constant, then don't rewrite.
+ if (!payload) {
+ return nullptr;
+ }
+
+ // If the constant is not considered a payload, then don't rewrite.
+ if (!isPayload(payload->getValue())) {
+ return nullptr;
+ }
+ auto cmp = std::make_unique<ExpressionCompare>(eqMatch->getExpressionContext(),
+ ExpressionCompare::GT);
+ cmp->addOperand(eqMatch->getOperandList()[0]);
+ cmp->addOperand(
+ ExpressionConstant::create(eqMatch->getExpressionContext(),
+ payload->getValue().getDocument().getField("encrypt")));
+ return cmp;
+ }
+
+ std::unique_ptr<MatchExpression> rewriteToRuntimeComparison(
+ MatchExpression* expr) const override {
+ return nullptr;
+ }
+
+ std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override {
return nullptr;
}
+private:
+ // This method is not used in mock implementations of the EncryptedPredicate since isPayload(),
+ // which normally calls encryptedBinDataType(), is overridden to look for plain objects rather
+ // than BinData. Since this method is pure virtual on the superclass and needs to be
+ // implemented, it is set to kPlaceholder (0).
+ EncryptedBinDataType encryptedBinDataType() const override {
+ return EncryptedBinDataType::kPlaceholder;
+ }
+};
+
+// A second mock rewrite which replaces documents with the key "foo" into $lt operations. We need
+// two different rewrites that are registered on the same operator to verify that all rewrites are
+// iterated through.
+class OtherMockPredicateRewriter : public fle::EncryptedPredicate {
+public:
+ OtherMockPredicateRewriter(const fle::QueryRewriterInterface* rewriter)
+ : EncryptedPredicate(rewriter) {}
+
+protected:
+ bool isPayload(const BSONElement& elt) const override {
+ if (!elt.isABSONObj()) {
+ return false;
+ }
+ return elt.Obj().hasField("foo"_sd);
+ }
+ bool isPayload(const Value& v) const override {
+ if (!v.isObject()) {
+ return false;
+ }
+ return !v.getDocument().getField("foo").missing();
+ }
+
+ std::vector<PrfBlock> generateTags(fle::BSONValue payload) const override {
+ return {};
+ };
+
+ // Encrypted values will be rewritten from $eq to $lt. This is an arbitrary decision just to
+ // make sure that the rewrite works properly.
+ std::unique_ptr<MatchExpression> rewriteToTagDisjunction(MatchExpression* expr) const override {
+ invariant(expr->matchType() == MatchExpression::EQ);
+ auto eqMatch = static_cast<EqualityMatchExpression*>(expr);
+ if (!isPayload(eqMatch->getData())) {
+ return nullptr;
+ }
+ return std::make_unique<LTMatchExpression>(eqMatch->path(),
+ eqMatch->getData().Obj().firstElement());
+ };
+
+ std::unique_ptr<Expression> rewriteToTagDisjunction(Expression* expr) const override {
+ auto eqMatch = dynamic_cast<ExpressionCompare*>(expr);
+ invariant(eqMatch);
+ if (eqMatch->getOp() != ExpressionCompare::EQ) {
+ return nullptr;
+ }
+ auto payload = dynamic_cast<ExpressionConstant*>(eqMatch->getOperandList()[1].get());
+ if (!payload) {
+ return nullptr;
+ }
+
+ if (!isPayload(payload->getValue())) {
+ return nullptr;
+ }
+ auto cmp = std::make_unique<ExpressionCompare>(eqMatch->getExpressionContext(),
+ ExpressionCompare::LT);
+ cmp->addOperand(eqMatch->getOperandList()[0]);
+ cmp->addOperand(ExpressionConstant::create(
+ eqMatch->getExpressionContext(), payload->getValue().getDocument().getField("foo")));
+ return cmp;
+ }
+
std::unique_ptr<MatchExpression> rewriteToRuntimeComparison(
MatchExpression* expr) const override {
return nullptr;
@@ -111,7 +212,7 @@ protected:
private:
EncryptedBinDataType encryptedBinDataType() const override {
- return EncryptedBinDataType::kPlaceholder; // return the 0 type. this isn't used anywhere.
+ return EncryptedBinDataType::kPlaceholder;
}
};
@@ -119,8 +220,17 @@ void setMockRewriteMaps(fle::MatchTypeToRewriteMap& match,
fle::ExpressionToRewriteMap& agg,
fle::TagMap& tags,
std::set<StringData>& encryptedFields) {
- match[MatchExpression::EQ] = [&](auto* rewriter, auto* expr) {
- return MockPredicateRewriter{rewriter}.rewrite(expr);
+ match[MatchExpression::EQ] = {
+ [&](auto* rewriter, auto* expr) { return MockPredicateRewriter{rewriter}.rewrite(expr); },
+ [&](auto* rewriter, auto* expr) {
+ return OtherMockPredicateRewriter{rewriter}.rewrite(expr);
+ },
+ };
+ agg[typeid(ExpressionCompare)] = {
+ [&](auto* rewriter, auto* expr) { return MockPredicateRewriter{rewriter}.rewrite(expr); },
+ [&](auto* rewriter, auto* expr) {
+ return OtherMockPredicateRewriter{rewriter}.rewrite(expr);
+ },
};
}
@@ -137,9 +247,17 @@ public:
return res ? res.value() : obj;
}
+ BSONObj rewriteAggExpressionForTest(const BSONObj& obj) {
+ auto expr = Expression::parseExpression(&_expCtx, obj, _expCtx.variablesParseState);
+ auto result = rewriteExpression(expr.get());
+ return result ? result->serialize(false).getDocument().toBson()
+ : expr->serialize(false).getDocument().toBson();
+ }
+
private:
fle::TagMap _tags;
std::set<StringData> _encryptedFields;
+ ExpressionContextForTest _expCtx;
};
class FLEServerRewriteTest : public unittest::Test {
@@ -167,22 +285,53 @@ protected:
ASSERT_MATCH_EXPRESSION_REWRITE(input, expected); \
}
+#define ASSERT_AGG_EXPRESSION_REWRITE(input, expected) \
+ auto actual = _mock->rewriteAggExpressionForTest(fromjson(input)); \
+ ASSERT_BSONOBJ_EQ(actual, fromjson(expected));
+
+#define TEST_FLE_REWRITE_AGG(name, input, expected) \
+ TEST_F(FLEServerRewriteTest, name##_AggExpression) { \
+ ASSERT_AGG_EXPRESSION_REWRITE(input, expected); \
+ }
+
TEST_FLE_REWRITE_MATCH(TopLevel_DottedPath,
"{'user.ssn': {$eq: {encrypt: 2}}}",
"{'user.ssn': {$gt: 2}}");
+TEST_FLE_REWRITE_AGG(TopLevel_DottedPath,
+ "{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}",
+ "{$gt: ['$user.ssn', {$const: 2}]}");
+
TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_BothEncrypted,
"{$and: [{ssn: {encrypt: 2}}, {age: {encrypt: 4}}]}",
"{$and: [{ssn: {$gt: 2}}, {age: {$gt: 4}}]}");
+TEST_FLE_REWRITE_AGG(
+ TopLevel_Conjunction_BothEncrypted,
+ "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {encrypt: "
+ "4}}]}]}",
+ "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$gt: ['$age', {$const: 4}]}]}");
+
TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncrypted,
"{$and: [{ssn: {encrypt: 2}}, {age: {plain: 4}}]}",
"{$and: [{ssn: {$gt: 2}}, {age: {$eq: {plain: 4}}}]}");
+TEST_FLE_REWRITE_AGG(
+ TopLevel_Conjunction_PartlyEncrypted,
+ "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {plain: 4}}]}]}",
+ "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$eq: ['$age', {$const: {plain: 4}}]}]}");
+
TEST_FLE_REWRITE_MATCH(TopLevel_Conjunction_PartlyEncryptedWithUnregisteredOperator,
"{$and: [{ssn: {encrypt: 2}}, {age: {$lt: {encrypt: 4}}}]}",
"{$and: [{ssn: {$gt: 2}}, {age: {$lt: {encrypt: 4}}}]}");
+TEST_FLE_REWRITE_AGG(
+ TopLevel_Conjunction_PartlyEncryptedWithUnregisteredOperator,
+ "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$lt: ['$age', {$const: {encrypt: "
+ "4}}]}]}",
+ "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$lt: ['$age', {$const: {encrypt: "
+ "4}}]}]}");
+
TEST_FLE_REWRITE_MATCH(TopLevel_Encrypted_Nested_Unecrypted,
"{$and: [{ssn: {encrypt: 2}}, {user: {region: 'US'}}]}",
"{$and: [{ssn: {$gt: 2}}, {user: {$eq: {region: 'US'}}}]}");
@@ -191,6 +340,10 @@ TEST_FLE_REWRITE_MATCH(TopLevel_Not,
"{ssn: {$not: {$eq: {encrypt: 5}}}}",
"{ssn: {$not: {$gt: 5}}}");
+TEST_FLE_REWRITE_AGG(TopLevel_Not,
+ "{$not: [{$eq: ['$ssn', {$const: {encrypt: 2}}]}]}",
+ "{$not: [{$gt: ['$ssn', {$const: 2}]}]}")
+
TEST_FLE_REWRITE_MATCH(TopLevel_Neq, "{ssn: {$ne: {encrypt: 5}}}", "{ssn: {$not: {$gt: 5}}}}");
TEST_FLE_REWRITE_MATCH(
@@ -198,6 +351,12 @@ TEST_FLE_REWRITE_MATCH(
"{$and: [{$and: [{ssn: {encrypt: 2}}, {other: 'field'}]}, {otherSsn: {encrypt: 3}}]}",
"{$and: [{$and: [{ssn: {$gt: 2}}, {other: {$eq: 'field'}}]}, {otherSsn: {$gt: 3}}]}");
+TEST_FLE_REWRITE_AGG(NestedConjunction,
+ "{$and: [{$and: [{$eq: ['$ssn', {$const: {encrypt: 2}}]},{$eq: ['$other', "
+ "'field']}]},{$eq: ['$age',{$const: {encrypt: 4}}]}]}",
+ "{$and: [{$and: [{$gt: ['$ssn', {$const: 2}]},{$eq: ['$other', "
+ "{$const: 'field'}]}]},{$gt: ['$age',{$const: 4}]}]}");
+
TEST_FLE_REWRITE_MATCH(TopLevel_Nor,
"{$nor: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}",
"{$nor: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}");
@@ -205,5 +364,30 @@ TEST_FLE_REWRITE_MATCH(TopLevel_Nor,
TEST_FLE_REWRITE_MATCH(TopLevel_Or,
"{$or: [{ssn: {encrypt: 5}}, {other: {$eq: 'field'}}]}",
"{$or: [{ssn: {$gt: 5}}, {other: {$eq: 'field'}}]}");
+
+TEST_FLE_REWRITE_AGG(
+ TopLevel_Or,
+ "{$or: [{$eq: ['$ssn', {$const: {encrypt: 2}}]}, {$eq: ['$ssn', {$const: {encrypt: 4}}]}]}",
+ "{$or: [{$gt: ['$ssn', {$const: 2}]}, {$gt: ['$ssn', {$const: 4}]}]}")
+
+
+// Test that the rewriter will work from any rewrite registered to an expression. The test rewriter
+// has two rewrites registered on $eq.
+
+TEST_FLE_REWRITE_MATCH(OtherRewrite_Basic, "{'ssn': {$eq: {foo: 2}}}", "{'ssn': {$lt: 2}}");
+
+TEST_FLE_REWRITE_AGG(OtherRewrite_Basic,
+ "{$eq: ['$user.ssn', {$const: {foo: 2}}]}",
+ "{$lt: ['$user.ssn', {$const: 2}]}");
+
+TEST_FLE_REWRITE_MATCH(OtherRewrite_Conjunction_BothEncrypted,
+ "{$and: [{ssn: {encrypt: 2}}, {age: {foo: 4}}]}",
+ "{$and: [{ssn: {$gt: 2}}, {age: {$lt: 4}}]}");
+
+TEST_FLE_REWRITE_AGG(
+ OtherRewrite_Conjunction_BothEncrypted,
+ "{$and: [{$eq: ['$user.ssn', {$const: {encrypt: 2}}]}, {$eq: ['$age', {$const: {foo: "
+ "4}}]}]}",
+ "{$and: [{$gt: ['$user.ssn', {$const: 2}]}, {$lt: ['$age', {$const: 4}]}]}");
} // namespace
} // namespace mongo
diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp
index 8c083b07dc1..dcd71777782 100644
--- a/src/mongo/db/query/fle/range_predicate.cpp
+++ b/src/mongo/db/query/fle/range_predicate.cpp
@@ -55,6 +55,77 @@ REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween,
RangePredicate,
gFeatureFlagFLE2Range);
+REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionCompare,
+ RangePredicate,
+ gFeatureFlagFLE2Range);
+
+namespace {
+// Validate the range operator passed in and return the fieldpath and payload for the rewrite. If
+// the passed-in expression is a comparison with $eq, $ne, or $cmp, none of which represent a range
+// predicate, then return null to the caller so that the rewrite can return null.
+std::pair<boost::intrusive_ptr<Expression>, Value> validateRangeOp(Expression* expr) {
+ auto children = [&]() {
+ if (auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr)) {
+ return betweenExpr->getChildren();
+ } else {
+ auto cmpExpr = dynamic_cast<ExpressionCompare*>(expr);
+ tassert(6720901,
+ "Range rewrite should only be called with $between or comparison operator.",
+ cmpExpr);
+ switch (cmpExpr->getOp()) {
+ case ExpressionCompare::GT:
+ case ExpressionCompare::GTE:
+ case ExpressionCompare::LT:
+ case ExpressionCompare::LTE:
+ return cmpExpr->getChildren();
+
+ case ExpressionCompare::EQ:
+ case ExpressionCompare::NE:
+ case ExpressionCompare::CMP:
+ return std::vector<boost::intrusive_ptr<Expression>>();
+ }
+ }
+ return std::vector<boost::intrusive_ptr<Expression>>();
+ }();
+ if (children.empty()) {
+ return {nullptr, Value()};
+ }
+ // Both ExpressionBetween and ExpressionCompare have a fixed arity of 2.
+ auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get());
+ uassert(6720903, "first argument should be a fieldpath", fieldpath);
+ auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get());
+ uassert(6720904, "second argument should be a constant", secondArg);
+ auto payload = secondArg->getValue();
+ return {children[0], payload};
+}
+} // namespace
+
+std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
+ StringData path, ParsedFindRangePayload payload) const {
+ auto* expCtx = _rewriter->getExpressionContext();
+ return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState),
+ payload);
+}
+
+std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
+ boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const {
+ tassert(7030501,
+ "$internalFleBetween can only be generated from a non-stub payload.",
+ !payload.isStub());
+ auto cm = payload.maxCounter;
+ ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken);
+ std::vector<ConstDataRange> edcTokens;
+ std::transform(std::make_move_iterator(payload.edges.value().begin()),
+ std::make_move_iterator(payload.edges.value().end()),
+ std::back_inserter(edcTokens),
+ [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); });
+
+ auto* expCtx = _rewriter->getExpressionContext();
+ return std::make_unique<ExpressionInternalFLEBetween>(
+ expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens));
+}
+
std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const {
auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload);
std::vector<PrfBlock> tags;
@@ -99,56 +170,23 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction(
return makeTagDisjunction(toBSONArray(generateTags(payload)));
}
-std::pair<boost::intrusive_ptr<Expression>, Value> validateBetween(Expression* expr) {
- auto betweenExpr = dynamic_cast<ExpressionBetween*>(expr);
- tassert(6720901, "Range rewrite should only be called with $between operator.", betweenExpr);
- auto children = betweenExpr->getChildren();
- uassert(6720902, "$between should have two children.", children.size() == 2);
-
- auto fieldpath = dynamic_cast<ExpressionFieldPath*>(children[0].get());
- uassert(6720903, "first argument should be a fieldpath", fieldpath);
- auto secondArg = dynamic_cast<ExpressionConstant*>(children[1].get());
- uassert(6720904, "second argument should be a constant", secondArg);
- auto payload = secondArg->getValue();
- return {children[0], payload};
-}
-
std::unique_ptr<Expression> RangePredicate::rewriteToTagDisjunction(Expression* expr) const {
- auto [_, payload] = validateBetween(expr);
+ auto [fieldpath, payload] = validateRangeOp(expr);
+ if (!fieldpath) {
+ return nullptr;
+ }
if (!isPayload(payload)) {
return nullptr;
}
+ if (isStub(std::ref(payload))) {
+ return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true));
+ }
+
auto tags = toValues(generateTags(std::ref(payload)));
return makeTagDisjunction(_rewriter->getExpressionContext(), std::move(tags));
}
-std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
- StringData path, ParsedFindRangePayload payload) const {
- auto* expCtx = _rewriter->getExpressionContext();
- return fleBetweenFromPayload(ExpressionFieldPath::createPathFromString(
- expCtx, path.toString(), expCtx->variablesParseState),
- payload);
-}
-
-std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
- boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const {
- tassert(7030501,
- "$internalFleBetween can only be generated from a non-stub payload.",
- !payload.isStub());
- auto cm = payload.maxCounter;
- ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken);
- std::vector<ConstDataRange> edcTokens;
- std::transform(std::make_move_iterator(payload.edges.value().begin()),
- std::make_move_iterator(payload.edges.value().end()),
- std::back_inserter(edcTokens),
- [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); });
-
- auto* expCtx = _rewriter->getExpressionContext();
- return std::make_unique<ExpressionInternalFLEBetween>(
- expCtx, fieldpath, serverToken.toCDR(), cm, std::move(edcTokens));
-}
-
std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison(
MatchExpression* expr) const {
BSONElement ffp;
@@ -179,11 +217,17 @@ std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison(
}
std::unique_ptr<Expression> RangePredicate::rewriteToRuntimeComparison(Expression* expr) const {
- auto [fieldpath, ffp] = validateBetween(expr);
+ auto [fieldpath, ffp] = validateRangeOp(expr);
+ if (!fieldpath) {
+ return nullptr;
+ }
if (!isPayload(ffp)) {
return nullptr;
}
auto payload = parseFindPayload<ParsedFindRangePayload>(ffp);
+ if (payload.isStub()) {
+ return std::make_unique<ExpressionConstant>(_rewriter->getExpressionContext(), Value(true));
+ }
return fleBetweenFromPayload(fieldpath, payload);
}
} // namespace mongo::fle
diff --git a/src/mongo/db/query/fle/range_predicate.h b/src/mongo/db/query/fle/range_predicate.h
index bc2a47fa173..5fbd7484ad0 100644
--- a/src/mongo/db/query/fle/range_predicate.h
+++ b/src/mongo/db/query/fle/range_predicate.h
@@ -55,6 +55,12 @@ protected:
return parsedPayload.isStub();
}
+ virtual bool isStub(Value elt) const {
+ auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(elt);
+ return parsedPayload.isStub();
+ }
+
+
private:
EncryptedBinDataType encryptedBinDataType() const override {
return EncryptedBinDataType::kFLE2FindRangePayload;
diff --git a/src/mongo/db/query/fle/range_predicate_test.cpp b/src/mongo/db/query/fle/range_predicate_test.cpp
index f81723da5eb..61110c19041 100644
--- a/src/mongo/db/query/fle/range_predicate_test.cpp
+++ b/src/mongo/db/query/fle/range_predicate_test.cpp
@@ -68,7 +68,11 @@ protected:
return isStubPayload;
}
- std::vector<PrfBlock> generateTags(BSONValue payload) const {
+ bool isStub(Value elt) const override {
+ return isStubPayload;
+ }
+
+ std::vector<PrfBlock> generateTags(BSONValue payload) const override {
return stdx::visit(
OverloadedVisitor{[&](BSONElement p) {
if (p.isABSONObj()) {
@@ -126,8 +130,6 @@ TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_NoStub) {
TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_Stub) {
RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
- std::vector<PrfBlock> allTags = {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}};
-
auto expCtx = make_intrusive<ExpressionContextForTest>();
std::vector<StringData> operators = {"$between", "$gt", "$gte", "$lte", "$lt"};
@@ -159,29 +161,98 @@ TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_Stub) {
}
}
+TEST_F(RangePredicateRewriteTest, AggRangeRewrite_Stub) {
+ RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
+
+ {
+ auto input = fromjson(str::stream() << "{$between: [\"$age\", {$literal: [1, 2, 3]}]}");
+ auto inputExpr =
+ ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = ExpressionConstant::create(&_expCtx, Value(true));
+
+ _predicate.isStubPayload = true;
+ auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT(actual);
+ ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
+ expected->serialize(false).getDocument().toBson());
+ }
+
+ auto ops = {"$gt", "$lt", "$gte", "$lte"};
+ for (auto& op : ops) {
+ auto input = fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}");
+ auto inputExpr =
+ ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = ExpressionConstant::create(&_expCtx, Value(true));
+
+ _predicate.isStubPayload = true;
+ auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT(actual);
+ ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
+ expected->serialize(false).getDocument().toBson());
+ }
+}
+
TEST_F(RangePredicateRewriteTest, AggRangeRewrite) {
- auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})");
- auto inputExpr =
- ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+ {
+ auto op = "$between";
+ auto input = fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}");
+ auto inputExpr =
+ ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}}));
- auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}}));
+ auto actual = _predicate.rewrite(inputExpr.get());
- auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
+ expected->serialize(false).getDocument().toBson());
+ }
+ {
+ auto ops = {"$gt", "$lt", "$gte", "$lte"};
+ for (auto& op : ops) {
+ auto input =
+ fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}");
+ auto inputExpr =
+ ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = makeTagDisjunction(&_expCtx, toValues({{1}, {2}, {3}}));
- ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
- expected->serialize(false).getDocument().toBson());
+ auto actual = _predicate.rewrite(inputExpr.get());
+
+ ASSERT_BSONOBJ_EQ(actual->serialize(false).getDocument().toBson(),
+ expected->serialize(false).getDocument().toBson());
+ }
+ }
}
TEST_F(RangePredicateRewriteTest, AggRangeRewriteNoOp) {
- auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})");
- auto inputExpr =
- ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+ {
+ auto input = fromjson(R"({$between: ["$age", {$literal: [1, 2, 3]}]})");
+ auto inputExpr =
+ ExpressionBetween::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
- auto expected = inputExpr;
+ auto expected = inputExpr;
- _predicate.payloadValid = false;
- auto actual = _predicate.rewrite(inputExpr.get());
- ASSERT(actual == nullptr);
+ _predicate.payloadValid = false;
+ auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT(actual == nullptr);
+ }
+ {
+ auto ops = {"$gt", "$lt", "$gte", "$lte"};
+ for (auto& op : ops) {
+ auto input =
+ fromjson(str::stream() << "{" << op << ": [\"$age\", {$literal: [1, 2, 3]}]}");
+ auto inputExpr =
+ ExpressionCompare::parseExpression(&_expCtx, input, _expCtx.variablesParseState);
+
+ auto expected = inputExpr;
+
+ _predicate.payloadValid = false;
+ auto actual = _predicate.rewrite(inputExpr.get());
+ ASSERT(actual == nullptr);
+ }
+ }
}
BSONObj generateFFP(StringData path, int lb, int ub, int min, int max) {
@@ -219,6 +290,17 @@ std::unique_ptr<Expression> generateBetweenWithFFP(ExpressionContext* expCtx,
return std::make_unique<ExpressionBetween>(expCtx, std::move(children));
}
+std::unique_ptr<Expression> generateBetweenWithFFP(
+ ExpressionContext* expCtx, ExpressionCompare::CmpOp op, StringData path, int lb, int ub) {
+ auto ffp = Value(generateFFP(path, lb, ub, 0, 255).firstElement());
+ auto ffpExpr = make_intrusive<ExpressionConstant>(expCtx, ffp);
+ auto fieldpath = ExpressionFieldPath::createPathFromString(
+ expCtx, path.toString(), expCtx->variablesParseState);
+ std::vector<boost::intrusive_ptr<Expression>> children = {std::move(fieldpath),
+ std::move(ffpExpr)};
+ return std::make_unique<ExpressionCompare>(expCtx, op, std::move(children));
+}
+
TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) {
_mock.setForceEncryptedCollScanForTest();
auto expected = fromjson(R"({
@@ -273,9 +355,6 @@ TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) {
TEST_F(RangePredicateRewriteTest, CollScanRewriteAgg) {
_mock.setForceEncryptedCollScanForTest();
- auto input = generateBetweenWithFFP(&_expCtx, "age", 23, 35);
- auto result = _predicate.rewrite(input.get());
- ASSERT(result);
auto expected = fromjson(R"({
"$_internalFleBetween": {
"field": "$age",
@@ -309,7 +388,40 @@ TEST_F(RangePredicateRewriteTest, CollScanRewriteAgg) {
}
}
})");
- ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected);
+ {
+ auto input = generateBetweenWithFFP(&_expCtx, "age", 23, 35);
+ auto result = _predicate.rewrite(input.get());
+ ASSERT(result);
+ ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected);
+ }
+ {
+ auto ops = {ExpressionCompare::GT,
+ ExpressionCompare::GTE,
+ ExpressionCompare::LT,
+ ExpressionCompare::LTE};
+ for (auto& op : ops) {
+ auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35);
+ auto result = _predicate.rewrite(input.get());
+ ASSERT(result);
+ ASSERT_BSONOBJ_EQ(result->serialize(false).getDocument().toBson(), expected);
+ }
+ }
+}
+
+
+TEST_F(RangePredicateRewriteTest, UnsupportedComparisonOps) {
+ auto ops = {ExpressionCompare::CMP, ExpressionCompare::EQ, ExpressionCompare::NE};
+ for (auto& op : ops) {
+ auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35);
+ auto result = _predicate.rewrite(input.get());
+ ASSERT(result == nullptr);
+ }
+ _mock.setForceEncryptedCollScanForTest();
+ for (auto& op : ops) {
+ auto input = generateBetweenWithFFP(&_expCtx, op, "age", 23, 35);
+ auto result = _predicate.rewrite(input.get());
+ ASSERT(result == nullptr);
+ }
}
}; // namespace