summaryrefslogtreecommitdiff
path: root/src/mongo/db
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db')
-rw-r--r--src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h8
-rw-r--r--src/mongo/db/query/fle/range_predicate.cpp46
-rw-r--r--src/mongo/db/query/fle/range_predicate.h5
-rw-r--r--src/mongo/db/query/fle/range_predicate_test.cpp133
4 files changed, 134 insertions, 58 deletions
diff --git a/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h
index dfc6aa2b9b0..d521b6f9989 100644
--- a/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h
+++ b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h
@@ -100,6 +100,14 @@ public:
static_cast<MatchExpression*>(expected.get())->serialize());
}
+ template <typename T>
+ void assertRewriteForOp(const EncryptedPredicate& pred,
+ BSONElement rhs,
+ std::vector<PrfBlock> allTags) {
+ auto inputExpr = T("age", rhs);
+ assertRewriteToTags(pred, &inputExpr, toBSONArray(std::move(allTags)));
+ }
+
protected:
MockServerRewrite _mock{};
ExpressionContextForTest _expCtx;
diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp
index d4bfd1f6bc2..8c083b07dc1 100644
--- a/src/mongo/db/query/fle/range_predicate.cpp
+++ b/src/mongo/db/query/fle/range_predicate.cpp
@@ -34,6 +34,7 @@
#include "mongo/crypto/encryption_fields_gen.h"
#include "mongo/crypto/fle_crypto.h"
#include "mongo/crypto/fle_tags.h"
+#include "mongo/db/matcher/expression_always_boolean.h"
#include "mongo/db/matcher/expression_expr.h"
#include "mongo/db/matcher/expression_leaf.h"
#include "mongo/db/pipeline/expression.h"
@@ -44,6 +45,12 @@ namespace mongo::fle {
REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(BETWEEN,
RangePredicate,
gFeatureFlagFLE2Range);
+
+REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(GT, RangePredicate, gFeatureFlagFLE2Range);
+REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(GTE, RangePredicate, gFeatureFlagFLE2Range);
+REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(LT, RangePredicate, gFeatureFlagFLE2Range);
+REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(LTE, RangePredicate, gFeatureFlagFLE2Range);
+
REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween,
RangePredicate,
gFeatureFlagFLE2Range);
@@ -51,7 +58,8 @@ REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween,
std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const {
auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(payload);
std::vector<PrfBlock> tags;
- for (auto& edge : parsedPayload.edges) {
+ tassert(7030500, "Must generate tags from a non-stub payload.", !parsedPayload.isStub());
+ for (auto& edge : parsedPayload.edges.value()) {
auto tagsForEdge = readTags(*_rewriter->getEscReader(),
*_rewriter->getEccReader(),
edge.esc,
@@ -67,6 +75,18 @@ std::vector<PrfBlock> RangePredicate::generateTags(BSONValue payload) const {
std::unique_ptr<MatchExpression> RangePredicate::rewriteToTagDisjunction(
MatchExpression* expr) const {
+ if (auto compExpr = dynamic_cast<ComparisonMatchExpression*>(expr)) {
+ auto payload = compExpr->getData();
+ if (!isPayload(payload)) {
+ return nullptr;
+ }
+ // If this is a stub expression, replace expression with $alwaysTrue.
+ if (isStub(payload)) {
+ return std::make_unique<AlwaysTrueMatchExpression>();
+ }
+ return makeTagDisjunction(toBSONArray(generateTags(payload)));
+ }
+
tassert(6720900,
"Range rewrite should only be called with $between operator.",
expr->matchType() == MatchExpression::BETWEEN);
@@ -113,11 +133,14 @@ std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayl
std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayload(
boost::intrusive_ptr<Expression> fieldpath, ParsedFindRangePayload payload) const {
+ tassert(7030501,
+ "$internalFleBetween can only be generated from a non-stub payload.",
+ !payload.isStub());
auto cm = payload.maxCounter;
ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken);
std::vector<ConstDataRange> edcTokens;
- std::transform(std::make_move_iterator(payload.edges.begin()),
- std::make_move_iterator(payload.edges.end()),
+ std::transform(std::make_move_iterator(payload.edges.value().begin()),
+ std::make_move_iterator(payload.edges.value().end()),
std::back_inserter(edcTokens),
[](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); });
@@ -128,8 +151,21 @@ std::unique_ptr<ExpressionInternalFLEBetween> RangePredicate::fleBetweenFromPayl
std::unique_ptr<MatchExpression> RangePredicate::rewriteToRuntimeComparison(
MatchExpression* expr) const {
- auto between = static_cast<BetweenMatchExpression*>(expr);
- auto ffp = between->rhs();
+ BSONElement ffp;
+ if (auto compExpr = dynamic_cast<ComparisonMatchExpression*>(expr)) {
+ auto payload = compExpr->getData();
+ if (!isPayload(payload)) {
+ return nullptr;
+ }
+ // If this is a stub expression, replace expression with $alwaysTrue.
+ if (isStub(payload)) {
+ return std::make_unique<AlwaysTrueMatchExpression>();
+ }
+ ffp = payload;
+ } else {
+ auto between = static_cast<BetweenMatchExpression*>(expr);
+ ffp = between->rhs();
+ }
if (!isPayload(ffp)) {
return nullptr;
diff --git a/src/mongo/db/query/fle/range_predicate.h b/src/mongo/db/query/fle/range_predicate.h
index 3fdc7db0b2b..bc2a47fa173 100644
--- a/src/mongo/db/query/fle/range_predicate.h
+++ b/src/mongo/db/query/fle/range_predicate.h
@@ -50,6 +50,11 @@ protected:
MatchExpression* expr) const override;
std::unique_ptr<Expression> rewriteToRuntimeComparison(Expression* expr) const override;
+ virtual bool isStub(BSONElement elt) const {
+ auto parsedPayload = parseFindPayload<ParsedFindRangePayload>(elt);
+ return parsedPayload.isStub();
+ }
+
private:
EncryptedBinDataType encryptedBinDataType() const override {
return EncryptedBinDataType::kFLE2FindRangePayload;
diff --git a/src/mongo/db/query/fle/range_predicate_test.cpp b/src/mongo/db/query/fle/range_predicate_test.cpp
index fe100b6bc5a..38666ec93e6 100644
--- a/src/mongo/db/query/fle/range_predicate_test.cpp
+++ b/src/mongo/db/query/fle/range_predicate_test.cpp
@@ -27,9 +27,12 @@
* it in the license file.
*/
+#include "mongo/bson/bsonelement.h"
#include "mongo/crypto/fle_crypto.h"
+#include "mongo/crypto/fle_field_schema_gen.h"
#include "mongo/db/matcher/expression_expr.h"
#include "mongo/db/matcher/expression_leaf.h"
+#include "mongo/db/matcher/expression_parser.h"
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/pipeline/expression_context_for_test.h"
#include "mongo/db/query/fle/encrypted_predicate.h"
@@ -47,15 +50,10 @@ public:
MockRangePredicate(const QueryRewriterInterface* rewriter,
TagMap tags,
std::set<StringData> encryptedFields)
- : RangePredicate(rewriter), _tags(tags), _encryptedFields(encryptedFields) {}
-
- void setEncryptedTags(std::pair<StringData, int> fieldvalue, std::vector<PrfBlock> tags) {
- _encryptedFields.insert(fieldvalue.first);
- _tags[fieldvalue] = tags;
- }
-
+ : RangePredicate(rewriter) {}
bool payloadValid = true;
+ bool isStubPayload = false;
protected:
bool isPayload(const BSONElement& elt) const override {
@@ -66,27 +64,23 @@ protected:
return payloadValid;
}
+ bool isStub(BSONElement elt) const override {
+ return isStubPayload;
+ }
+
std::vector<PrfBlock> generateTags(BSONValue payload) const {
return stdx::visit(
OverloadedVisitor{[&](BSONElement p) {
- auto parsedPayload = p.Obj().firstElement();
- auto fieldName = parsedPayload.fieldNameStringData();
-
- std::vector<BSONElement> range;
- auto payloadAsArray = parsedPayload.Array();
- for (auto&& elt : payloadAsArray) {
- range.push_back(elt);
- }
-
- std::vector<PrfBlock> allTags;
- for (auto i = range[0].Number(); i <= range[1].Number(); i++) {
- ASSERT(_tags.find({fieldName, i}) != _tags.end());
- auto temp = _tags.find({fieldName, i})->second;
- for (auto tag : temp) {
- allTags.push_back(tag);
+ if (p.isABSONObj()) {
+ std::vector<PrfBlock> allTags;
+ for (auto& val : p.Array()) {
+ allTags.push_back(PrfBlock(
+ {static_cast<unsigned char>(val.safeNumberInt())}));
}
+ return allTags;
+ } else {
+ return std::vector<PrfBlock>{};
}
- return allTags;
},
[&](std::reference_wrapper<Value> v) {
if (v.get().isArray()) {
@@ -103,10 +97,6 @@ protected:
}},
payload);
}
-
-private:
- TagMap _tags;
- std::set<StringData> _encryptedFields;
};
class RangePredicateRewriteTest : public EncryptedPredicateRewriteTest {
public:
@@ -116,31 +106,57 @@ protected:
MockRangePredicate _predicate;
};
-TEST_F(RangePredicateRewriteTest, MatchRangeRewrite) {
+TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_NoStub) {
RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
- int start = 1;
- int end = 3;
- StringData encField = "ssn";
+ std::vector<PrfBlock> allTags = {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}};
+
+ auto expCtx = make_intrusive<ExpressionContextForTest>();
- std::vector<PrfBlock> tags1 = {{1}, {2}, {3}};
- std::vector<PrfBlock> tags2 = {{4}, {5}, {6}};
- std::vector<PrfBlock> tags3 = {{7}, {8}, {9}};
+ std::vector<StringData> operators = {"$between", "$gt", "$gte", "$lte", "$lt"};
+ auto payload = fromjson("{x: [1, 2, 3, 4, 5, 6, 7, 8, 9]}");
- _predicate.setEncryptedTags({encField, 1}, tags1);
- _predicate.setEncryptedTags({encField, 2}, tags2);
- _predicate.setEncryptedTags({encField, 3}, tags3);
+ assertRewriteForOp<BetweenMatchExpression>(_predicate, payload.firstElement(), allTags);
+ assertRewriteForOp<GTMatchExpression>(_predicate, payload.firstElement(), allTags);
+ assertRewriteForOp<GTEMatchExpression>(_predicate, payload.firstElement(), allTags);
+ assertRewriteForOp<LTMatchExpression>(_predicate, payload.firstElement(), allTags);
+ assertRewriteForOp<LTEMatchExpression>(_predicate, payload.firstElement(), allTags);
+}
+
+TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_Stub) {
+ RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true);
std::vector<PrfBlock> allTags = {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}};
- // The field redundancy is so that we can pull out the field
- // name in the mock version of rewriteRangePayloadAsTags.
- BSONObj query =
- BSON(encField << BSON("$between" << BSON(encField << BSON_ARRAY(start << end))));
+ auto expCtx = make_intrusive<ExpressionContextForTest>();
+
+ std::vector<StringData> operators = {"$between", "$gt", "$gte", "$lte", "$lt"};
+ auto payload = fromjson("{x: [1, 2, 3, 4, 5, 6, 7, 8, 9]}");
- auto inputExpr = BetweenMatchExpression(encField, query[encField]["$between"], nullptr);
+#define ASSERT_REWRITE_TO_TRUE(T) \
+ { \
+ std::unique_ptr<MatchExpression> inputExpr = std::make_unique<T>("age", Value(0)); \
+ _predicate.isStubPayload = true; \
+ auto rewrite = _predicate.rewrite(inputExpr.get()); \
+ ASSERT_EQ(rewrite->matchType(), MatchExpression::ALWAYS_TRUE); \
+ }
- assertRewriteToTags(_predicate, &inputExpr, toBSONArray(std::move(allTags)));
+ // Rewrites that would normally go to disjunctions.
+ {
+ ASSERT_REWRITE_TO_TRUE(GTMatchExpression);
+ ASSERT_REWRITE_TO_TRUE(GTEMatchExpression);
+ ASSERT_REWRITE_TO_TRUE(LTMatchExpression);
+ ASSERT_REWRITE_TO_TRUE(LTEMatchExpression);
+ }
+
+ // Rewrites that would normally go to $internalFleBetween.
+ {
+ _mock.setForceEncryptedCollScanForTest();
+ ASSERT_REWRITE_TO_TRUE(GTMatchExpression);
+ ASSERT_REWRITE_TO_TRUE(GTEMatchExpression);
+ ASSERT_REWRITE_TO_TRUE(LTMatchExpression);
+ ASSERT_REWRITE_TO_TRUE(LTEMatchExpression);
+ }
}
TEST_F(RangePredicateRewriteTest, AggRangeRewrite) {
@@ -175,16 +191,19 @@ BSONObj generateFFP(StringData path, int lb, int ub, int min, int max) {
FLEUserKeyAndId userKeyAndId(userKey.data, indexKeyId);
auto edges = minCoverInt32(lb, true, ub, true, min, max, 1);
- auto ffp = FLEClientCrypto::serializeFindRangePayload(indexKeyAndId, userKeyAndId, edges, 0);
+ FLE2RangeFindSpec spec(0, Fle2RangeOperator::kGt);
+ auto ffp =
+ FLEClientCrypto::serializeFindRangePayload(indexKeyAndId, userKeyAndId, edges, 0, spec);
BSONObjBuilder builder;
toEncryptedBinData(path, EncryptedBinDataType::kFLE2FindRangePayload, ffp, &builder);
return builder.obj();
}
-std::unique_ptr<MatchExpression> generateBetweenWithFFP(StringData path, int lb, int ub) {
+template <typename T>
+std::unique_ptr<MatchExpression> generateOpWithFFP(StringData path, int lb, int ub) {
auto ffp = generateFFP(path, lb, ub, 0, 255);
- return std::make_unique<BetweenMatchExpression>(path, ffp.firstElement());
+ return std::make_unique<T>(path, ffp.firstElement());
}
std::unique_ptr<Expression> generateBetweenWithFFP(ExpressionContext* expCtx,
@@ -202,12 +221,6 @@ std::unique_ptr<Expression> generateBetweenWithFFP(ExpressionContext* expCtx,
TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) {
_mock.setForceEncryptedCollScanForTest();
- auto input = generateBetweenWithFFP("age", 23, 35);
- auto result = _predicate.rewrite(input.get());
- ASSERT(result);
- ASSERT_EQ(result->matchType(), MatchExpression::EXPRESSION);
- auto* expr = static_cast<ExprMatchExpression*>(result.get());
- auto aggExpr = expr->getExpression();
auto expected = fromjson(R"({
"$_internalFleBetween": {
"field": "$age",
@@ -241,7 +254,21 @@ TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) {
}
}
})");
- ASSERT_BSONOBJ_EQ(aggExpr->serialize(false).getDocument().toBson(), expected);
+#define ASSERT_REWRITE_TO_INTERNAL_BETWEEN(T) \
+ { \
+ auto input = generateOpWithFFP<T>("age", 23, 35); \
+ auto result = _predicate.rewrite(input.get()); \
+ ASSERT(result); \
+ ASSERT_EQ(result->matchType(), MatchExpression::EXPRESSION); \
+ auto* expr = static_cast<ExprMatchExpression*>(result.get()); \
+ auto aggExpr = expr->getExpression(); \
+ ASSERT_BSONOBJ_EQ(aggExpr->serialize(false).getDocument().toBson(), expected); \
+ }
+ ASSERT_REWRITE_TO_INTERNAL_BETWEEN(BetweenMatchExpression);
+ ASSERT_REWRITE_TO_INTERNAL_BETWEEN(GTMatchExpression);
+ ASSERT_REWRITE_TO_INTERNAL_BETWEEN(GTEMatchExpression);
+ ASSERT_REWRITE_TO_INTERNAL_BETWEEN(LTMatchExpression);
+ ASSERT_REWRITE_TO_INTERNAL_BETWEEN(LTEMatchExpression);
}
TEST_F(RangePredicateRewriteTest, CollScanRewriteAgg) {