From 8af29f897d967f540c60ca8fb6f38f65e6fc9620 Mon Sep 17 00:00:00 2001 From: Davis Haupt Date: Thu, 20 Oct 2022 17:31:26 +0000 Subject: SERVER-70305 support encrypted range payloads under gt/lt expressions --- src/mongo/crypto/fle_crypto.cpp | 27 +++-- src/mongo/crypto/fle_crypto.h | 12 +- src/mongo/crypto/fle_crypto_test.cpp | 13 +- src/mongo/crypto/fle_field_schema.idl | 29 ++++- .../query/fle/encrypted_predicate_test_fixtures.h | 8 ++ src/mongo/db/query/fle/range_predicate.cpp | 46 ++++++- src/mongo/db/query/fle/range_predicate.h | 5 + src/mongo/db/query/fle/range_predicate_test.cpp | 133 +++++++++++++-------- 8 files changed, 189 insertions(+), 84 deletions(-) diff --git a/src/mongo/crypto/fle_crypto.cpp b/src/mongo/crypto/fle_crypto.cpp index 366d0a85432..e268701b450 100644 --- a/src/mongo/crypto/fle_crypto.cpp +++ b/src/mongo/crypto/fle_crypto.cpp @@ -1160,7 +1160,7 @@ void convertToFLE2Payload(FLEKeyVault* keyVault, auto edges = getMinCover(rangeFindSpec, ep.getSparsity().value()); auto findpayload = FLEClientCrypto::serializeFindRangePayload( - indexKey, userKey, edges, ep.getMaxContentionCounter()); + indexKey, userKey, edges, ep.getMaxContentionCounter(), rangeFindSpec); toEncryptedBinData(fieldNameToSerialize, EncryptedBinDataType::kFLE2FindRangePayload, @@ -2228,7 +2228,8 @@ FLE2FindRangePayload FLEClientCrypto::serializeFindRangePayload( FLEIndexKeyAndId indexKey, FLEUserKeyAndId userKey, const std::vector& edges, - uint64_t maxContentionFactor) { + uint64_t maxContentionFactor, + const FLE2RangeFindSpec& spec) { auto collectionToken = FLELevel1TokenGenerator::generateCollectionsLevel1Token(indexKey.key); auto serverToken = FLELevel1TokenGenerator::generateServerDataEncryptionLevel1Token(indexKey.key); @@ -2265,7 +2266,8 @@ FLE2FindRangePayload FLEClientCrypto::serializeFindRangePayload( edgesInfo.setServerEncryptionToken(serverToken.toCDR()); payload.setPayload(edgesInfo); - payload.setOperatorType(StringData("$gt")); // TODO: Change for SERVER-70305 + payload.setFirstOperator(spec.getFirstOperator()); + payload.setSecondOperator(spec.getSecondOperator()); payload.setPayloadId(1234); return payload; @@ -3225,14 +3227,20 @@ ParsedFindRangePayload::ParsedFindRangePayload(ConstDataRange cdr) { encryptedType == EncryptedBinDataType::kFLE2FindRangePayload); auto payload = parseFromCDR(subCdr); + payloadId = payload.getPayloadId(); + firstOp = payload.getFirstOperator(); + secondOp = payload.getSecondOperator(); if (!payload.getPayload()) { return; } - auto& stub = payload.getPayload().get(); + edges = std::vector(); + auto& edgesRef = edges.value(); + + auto& info = payload.getPayload().value(); - for (auto const& edge : stub.getEdges()) { + for (auto const& edge : info.getEdges()) { auto escToken = FLETokenFromCDR(edge.getEscDerivedToken()); @@ -3241,16 +3249,13 @@ ParsedFindRangePayload::ParsedFindRangePayload(ConstDataRange cdr) { auto edcToken = FLETokenFromCDR(edge.getEdcDerivedToken()); - edges.push_back({edcToken, escToken, eccToken}); + edgesRef.push_back({edcToken, escToken, eccToken}); } serverToken = FLETokenFromCDR( - stub.getServerEncryptionToken()); + info.getServerEncryptionToken()); - maxCounter = stub.getMaxCounter(); - - payloadId = payload.getPayloadId(); - operatorType = payload.getOperatorType().toString(); + maxCounter = info.getMaxCounter(); } diff --git a/src/mongo/crypto/fle_crypto.h b/src/mongo/crypto/fle_crypto.h index 879efcae1a9..d9409b4e036 100644 --- a/src/mongo/crypto/fle_crypto.h +++ b/src/mongo/crypto/fle_crypto.h @@ -583,7 +583,8 @@ public: static FLE2FindRangePayload serializeFindRangePayload(FLEIndexKeyAndId indexKey, FLEUserKeyAndId userKey, const std::vector& edges, - uint64_t maxContentionFactor); + uint64_t maxContentionFactor, + const FLE2RangeFindSpec& spec); /** * Generates a client-side payload that is sent to the server. @@ -1167,10 +1168,11 @@ struct FLEFindEdgeTokenSet { }; struct ParsedFindRangePayload { - std::vector edges; + boost::optional> edges; ServerDataEncryptionLevel1Token serverToken; - std::string operatorType; + Fle2RangeOperator firstOp; + boost::optional secondOp; std::int32_t payloadId; std::int64_t maxCounter; @@ -1178,6 +1180,10 @@ struct ParsedFindRangePayload { explicit ParsedFindRangePayload(BSONElement fleFindRangePayload); explicit ParsedFindRangePayload(const Value& fleFindRangePayload); explicit ParsedFindRangePayload(ConstDataRange cdr); + + bool isStub() { + return !edges.has_value(); + } }; diff --git a/src/mongo/crypto/fle_crypto_test.cpp b/src/mongo/crypto/fle_crypto_test.cpp index 5d5c676d22c..e99f4c854b9 100644 --- a/src/mongo/crypto/fle_crypto_test.cpp +++ b/src/mongo/crypto/fle_crypto_test.cpp @@ -773,8 +773,9 @@ std::vector generatePlaceholder( findSpec.setEdgesInfo(edgesInfo); - // TODO: change in SERVER-70305 - findSpec.setOperatorType(StringData("gt")); + // TODO: SERVER-70302 update query analysis to generate payloads in gt/lt pairs. + findSpec.setFirstOperator(Fle2RangeOperator::kGt); + findSpec.setPayloadId(1234); auto findDoc = BSON("s" << findSpec.toBSON()); @@ -2598,8 +2599,8 @@ void assertMinCoverResult(A lb, FLE2RangeFindSpec spec; spec.setEdgesInfo(edgesInfo); - // TODO: change in SERVER-70305 - spec.setOperatorType(StringData("gt")); + // TODO: SERVER-70302 update query analysis to generate payloads in gt/lt pairs. + spec.setFirstOperator(Fle2RangeOperator::kGt); spec.setPayloadId(1234); auto result = getMinCover(spec, sparsity); @@ -3530,8 +3531,8 @@ DEATH_TEST_REGEX(MinCoverInterfaceTest, Error_MinMaxTypeMismatch, "Tripwire asse FLE2RangeFindSpec spec; spec.setEdgesInfo(edgesInfo); - // TODO: change in SERVER-70305 - spec.setOperatorType(StringData("gt")); + // TODO: SERVER-70302 update query analysis to generate payloads in gt/lt pairs. + spec.setFirstOperator(Fle2RangeOperator::kGt); spec.setPayloadId(1234); diff --git a/src/mongo/crypto/fle_field_schema.idl b/src/mongo/crypto/fle_field_schema.idl index 34fe6211c8e..65b274dd3e7 100644 --- a/src/mongo/crypto/fle_field_schema.idl +++ b/src/mongo/crypto/fle_field_schema.idl @@ -91,6 +91,15 @@ enums: kInsert: 1 kFind: 2 + Fle2RangeOperator: + description: "Enum representing valid range operators that an encrypted payload can be under." + type: int + values: + kGt: 1 + kGte: 2 + kLt: 3 + kLte: 4 + types: encrypted_numeric_element: bson_serialization_type: @@ -317,10 +326,14 @@ structs: description: "Id of payload - must be paired with another payload" type: safeInt optional: false - operatorType: - description: "One of gt, lt, gte, lte" - type: string + firstOperator: + description: "First query operator for which this payload was generated." + type: Fle2RangeOperator optional: false + secondOperator: + description: "Second query operator for which this payload was generated. Only populated for two-sided ranges." + type: Fle2RangeOperator + optional: true EncryptionInformation: description: "Implements Encryption Information which includes the schema for Queryable Encryption that is consumed by query_analysis, queries and write_ops" @@ -396,10 +409,14 @@ structs: description: "Id of payload - must be paired with another payload" type: safeInt optional: false - operatorType: - description: "One of gt, lt, gte, lte" - type: string + firstOperator: + description: "First query operator for which this payload was generated." + type: Fle2RangeOperator optional: false + secondOperator: + description: "Second query operator for which this payload was generated. Only populated for two-sided ranges." + type: Fle2RangeOperator + optional: true FLE2RangeInsertSpec: description: "Range insert specification that is encoded inside of a FLE2EncryptionPlaceholder." diff --git a/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h index dfc6aa2b9b0..d521b6f9989 100644 --- a/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h +++ b/src/mongo/db/query/fle/encrypted_predicate_test_fixtures.h @@ -100,6 +100,14 @@ public: static_cast(expected.get())->serialize()); } + template + void assertRewriteForOp(const EncryptedPredicate& pred, + BSONElement rhs, + std::vector allTags) { + auto inputExpr = T("age", rhs); + assertRewriteToTags(pred, &inputExpr, toBSONArray(std::move(allTags))); + } + protected: MockServerRewrite _mock{}; ExpressionContextForTest _expCtx; diff --git a/src/mongo/db/query/fle/range_predicate.cpp b/src/mongo/db/query/fle/range_predicate.cpp index d4bfd1f6bc2..8c083b07dc1 100644 --- a/src/mongo/db/query/fle/range_predicate.cpp +++ b/src/mongo/db/query/fle/range_predicate.cpp @@ -34,6 +34,7 @@ #include "mongo/crypto/encryption_fields_gen.h" #include "mongo/crypto/fle_crypto.h" #include "mongo/crypto/fle_tags.h" +#include "mongo/db/matcher/expression_always_boolean.h" #include "mongo/db/matcher/expression_expr.h" #include "mongo/db/matcher/expression_leaf.h" #include "mongo/db/pipeline/expression.h" @@ -44,6 +45,12 @@ namespace mongo::fle { REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(BETWEEN, RangePredicate, gFeatureFlagFLE2Range); + +REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(GT, RangePredicate, gFeatureFlagFLE2Range); +REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(GTE, RangePredicate, gFeatureFlagFLE2Range); +REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(LT, RangePredicate, gFeatureFlagFLE2Range); +REGISTER_ENCRYPTED_MATCH_PREDICATE_REWRITE_WITH_FLAG(LTE, RangePredicate, gFeatureFlagFLE2Range); + REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween, RangePredicate, gFeatureFlagFLE2Range); @@ -51,7 +58,8 @@ REGISTER_ENCRYPTED_AGG_PREDICATE_REWRITE_WITH_FLAG(ExpressionBetween, std::vector RangePredicate::generateTags(BSONValue payload) const { auto parsedPayload = parseFindPayload(payload); std::vector tags; - for (auto& edge : parsedPayload.edges) { + tassert(7030500, "Must generate tags from a non-stub payload.", !parsedPayload.isStub()); + for (auto& edge : parsedPayload.edges.value()) { auto tagsForEdge = readTags(*_rewriter->getEscReader(), *_rewriter->getEccReader(), edge.esc, @@ -67,6 +75,18 @@ std::vector RangePredicate::generateTags(BSONValue payload) const { std::unique_ptr RangePredicate::rewriteToTagDisjunction( MatchExpression* expr) const { + if (auto compExpr = dynamic_cast(expr)) { + auto payload = compExpr->getData(); + if (!isPayload(payload)) { + return nullptr; + } + // If this is a stub expression, replace expression with $alwaysTrue. + if (isStub(payload)) { + return std::make_unique(); + } + return makeTagDisjunction(toBSONArray(generateTags(payload))); + } + tassert(6720900, "Range rewrite should only be called with $between operator.", expr->matchType() == MatchExpression::BETWEEN); @@ -113,11 +133,14 @@ std::unique_ptr RangePredicate::fleBetweenFromPayl std::unique_ptr RangePredicate::fleBetweenFromPayload( boost::intrusive_ptr fieldpath, ParsedFindRangePayload payload) const { + tassert(7030501, + "$internalFleBetween can only be generated from a non-stub payload.", + !payload.isStub()); auto cm = payload.maxCounter; ServerDataEncryptionLevel1Token serverToken = std::move(payload.serverToken); std::vector edcTokens; - std::transform(std::make_move_iterator(payload.edges.begin()), - std::make_move_iterator(payload.edges.end()), + std::transform(std::make_move_iterator(payload.edges.value().begin()), + std::make_move_iterator(payload.edges.value().end()), std::back_inserter(edcTokens), [](FLEFindEdgeTokenSet&& edge) { return edge.edc.toCDR(); }); @@ -128,8 +151,21 @@ std::unique_ptr RangePredicate::fleBetweenFromPayl std::unique_ptr RangePredicate::rewriteToRuntimeComparison( MatchExpression* expr) const { - auto between = static_cast(expr); - auto ffp = between->rhs(); + BSONElement ffp; + if (auto compExpr = dynamic_cast(expr)) { + auto payload = compExpr->getData(); + if (!isPayload(payload)) { + return nullptr; + } + // If this is a stub expression, replace expression with $alwaysTrue. + if (isStub(payload)) { + return std::make_unique(); + } + ffp = payload; + } else { + auto between = static_cast(expr); + ffp = between->rhs(); + } if (!isPayload(ffp)) { return nullptr; diff --git a/src/mongo/db/query/fle/range_predicate.h b/src/mongo/db/query/fle/range_predicate.h index 3fdc7db0b2b..bc2a47fa173 100644 --- a/src/mongo/db/query/fle/range_predicate.h +++ b/src/mongo/db/query/fle/range_predicate.h @@ -50,6 +50,11 @@ protected: MatchExpression* expr) const override; std::unique_ptr rewriteToRuntimeComparison(Expression* expr) const override; + virtual bool isStub(BSONElement elt) const { + auto parsedPayload = parseFindPayload(elt); + return parsedPayload.isStub(); + } + private: EncryptedBinDataType encryptedBinDataType() const override { return EncryptedBinDataType::kFLE2FindRangePayload; diff --git a/src/mongo/db/query/fle/range_predicate_test.cpp b/src/mongo/db/query/fle/range_predicate_test.cpp index fe100b6bc5a..38666ec93e6 100644 --- a/src/mongo/db/query/fle/range_predicate_test.cpp +++ b/src/mongo/db/query/fle/range_predicate_test.cpp @@ -27,9 +27,12 @@ * it in the license file. */ +#include "mongo/bson/bsonelement.h" #include "mongo/crypto/fle_crypto.h" +#include "mongo/crypto/fle_field_schema_gen.h" #include "mongo/db/matcher/expression_expr.h" #include "mongo/db/matcher/expression_leaf.h" +#include "mongo/db/matcher/expression_parser.h" #include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/fle/encrypted_predicate.h" @@ -47,15 +50,10 @@ public: MockRangePredicate(const QueryRewriterInterface* rewriter, TagMap tags, std::set encryptedFields) - : RangePredicate(rewriter), _tags(tags), _encryptedFields(encryptedFields) {} - - void setEncryptedTags(std::pair fieldvalue, std::vector tags) { - _encryptedFields.insert(fieldvalue.first); - _tags[fieldvalue] = tags; - } - + : RangePredicate(rewriter) {} bool payloadValid = true; + bool isStubPayload = false; protected: bool isPayload(const BSONElement& elt) const override { @@ -66,27 +64,23 @@ protected: return payloadValid; } + bool isStub(BSONElement elt) const override { + return isStubPayload; + } + std::vector generateTags(BSONValue payload) const { return stdx::visit( OverloadedVisitor{[&](BSONElement p) { - auto parsedPayload = p.Obj().firstElement(); - auto fieldName = parsedPayload.fieldNameStringData(); - - std::vector range; - auto payloadAsArray = parsedPayload.Array(); - for (auto&& elt : payloadAsArray) { - range.push_back(elt); - } - - std::vector allTags; - for (auto i = range[0].Number(); i <= range[1].Number(); i++) { - ASSERT(_tags.find({fieldName, i}) != _tags.end()); - auto temp = _tags.find({fieldName, i})->second; - for (auto tag : temp) { - allTags.push_back(tag); + if (p.isABSONObj()) { + std::vector allTags; + for (auto& val : p.Array()) { + allTags.push_back(PrfBlock( + {static_cast(val.safeNumberInt())})); } + return allTags; + } else { + return std::vector{}; } - return allTags; }, [&](std::reference_wrapper v) { if (v.get().isArray()) { @@ -103,10 +97,6 @@ protected: }}, payload); } - -private: - TagMap _tags; - std::set _encryptedFields; }; class RangePredicateRewriteTest : public EncryptedPredicateRewriteTest { public: @@ -116,31 +106,57 @@ protected: MockRangePredicate _predicate; }; -TEST_F(RangePredicateRewriteTest, MatchRangeRewrite) { +TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_NoStub) { RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); - int start = 1; - int end = 3; - StringData encField = "ssn"; + std::vector allTags = {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}}; + + auto expCtx = make_intrusive(); - std::vector tags1 = {{1}, {2}, {3}}; - std::vector tags2 = {{4}, {5}, {6}}; - std::vector tags3 = {{7}, {8}, {9}}; + std::vector operators = {"$between", "$gt", "$gte", "$lte", "$lt"}; + auto payload = fromjson("{x: [1, 2, 3, 4, 5, 6, 7, 8, 9]}"); - _predicate.setEncryptedTags({encField, 1}, tags1); - _predicate.setEncryptedTags({encField, 2}, tags2); - _predicate.setEncryptedTags({encField, 3}, tags3); + assertRewriteForOp(_predicate, payload.firstElement(), allTags); + assertRewriteForOp(_predicate, payload.firstElement(), allTags); + assertRewriteForOp(_predicate, payload.firstElement(), allTags); + assertRewriteForOp(_predicate, payload.firstElement(), allTags); + assertRewriteForOp(_predicate, payload.firstElement(), allTags); +} + +TEST_F(RangePredicateRewriteTest, MatchRangeRewrite_Stub) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2Range", true); std::vector allTags = {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}, {9}}; - // The field redundancy is so that we can pull out the field - // name in the mock version of rewriteRangePayloadAsTags. - BSONObj query = - BSON(encField << BSON("$between" << BSON(encField << BSON_ARRAY(start << end)))); + auto expCtx = make_intrusive(); + + std::vector operators = {"$between", "$gt", "$gte", "$lte", "$lt"}; + auto payload = fromjson("{x: [1, 2, 3, 4, 5, 6, 7, 8, 9]}"); - auto inputExpr = BetweenMatchExpression(encField, query[encField]["$between"], nullptr); +#define ASSERT_REWRITE_TO_TRUE(T) \ + { \ + std::unique_ptr inputExpr = std::make_unique("age", Value(0)); \ + _predicate.isStubPayload = true; \ + auto rewrite = _predicate.rewrite(inputExpr.get()); \ + ASSERT_EQ(rewrite->matchType(), MatchExpression::ALWAYS_TRUE); \ + } - assertRewriteToTags(_predicate, &inputExpr, toBSONArray(std::move(allTags))); + // Rewrites that would normally go to disjunctions. + { + ASSERT_REWRITE_TO_TRUE(GTMatchExpression); + ASSERT_REWRITE_TO_TRUE(GTEMatchExpression); + ASSERT_REWRITE_TO_TRUE(LTMatchExpression); + ASSERT_REWRITE_TO_TRUE(LTEMatchExpression); + } + + // Rewrites that would normally go to $internalFleBetween. + { + _mock.setForceEncryptedCollScanForTest(); + ASSERT_REWRITE_TO_TRUE(GTMatchExpression); + ASSERT_REWRITE_TO_TRUE(GTEMatchExpression); + ASSERT_REWRITE_TO_TRUE(LTMatchExpression); + ASSERT_REWRITE_TO_TRUE(LTEMatchExpression); + } } TEST_F(RangePredicateRewriteTest, AggRangeRewrite) { @@ -175,16 +191,19 @@ BSONObj generateFFP(StringData path, int lb, int ub, int min, int max) { FLEUserKeyAndId userKeyAndId(userKey.data, indexKeyId); auto edges = minCoverInt32(lb, true, ub, true, min, max, 1); - auto ffp = FLEClientCrypto::serializeFindRangePayload(indexKeyAndId, userKeyAndId, edges, 0); + FLE2RangeFindSpec spec(0, Fle2RangeOperator::kGt); + auto ffp = + FLEClientCrypto::serializeFindRangePayload(indexKeyAndId, userKeyAndId, edges, 0, spec); BSONObjBuilder builder; toEncryptedBinData(path, EncryptedBinDataType::kFLE2FindRangePayload, ffp, &builder); return builder.obj(); } -std::unique_ptr generateBetweenWithFFP(StringData path, int lb, int ub) { +template +std::unique_ptr generateOpWithFFP(StringData path, int lb, int ub) { auto ffp = generateFFP(path, lb, ub, 0, 255); - return std::make_unique(path, ffp.firstElement()); + return std::make_unique(path, ffp.firstElement()); } std::unique_ptr generateBetweenWithFFP(ExpressionContext* expCtx, @@ -202,12 +221,6 @@ std::unique_ptr generateBetweenWithFFP(ExpressionContext* expCtx, TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) { _mock.setForceEncryptedCollScanForTest(); - auto input = generateBetweenWithFFP("age", 23, 35); - auto result = _predicate.rewrite(input.get()); - ASSERT(result); - ASSERT_EQ(result->matchType(), MatchExpression::EXPRESSION); - auto* expr = static_cast(result.get()); - auto aggExpr = expr->getExpression(); auto expected = fromjson(R"({ "$_internalFleBetween": { "field": "$age", @@ -241,7 +254,21 @@ TEST_F(RangePredicateRewriteTest, CollScanRewriteMatch) { } } })"); - ASSERT_BSONOBJ_EQ(aggExpr->serialize(false).getDocument().toBson(), expected); +#define ASSERT_REWRITE_TO_INTERNAL_BETWEEN(T) \ + { \ + auto input = generateOpWithFFP("age", 23, 35); \ + auto result = _predicate.rewrite(input.get()); \ + ASSERT(result); \ + ASSERT_EQ(result->matchType(), MatchExpression::EXPRESSION); \ + auto* expr = static_cast(result.get()); \ + auto aggExpr = expr->getExpression(); \ + ASSERT_BSONOBJ_EQ(aggExpr->serialize(false).getDocument().toBson(), expected); \ + } + ASSERT_REWRITE_TO_INTERNAL_BETWEEN(BetweenMatchExpression); + ASSERT_REWRITE_TO_INTERNAL_BETWEEN(GTMatchExpression); + ASSERT_REWRITE_TO_INTERNAL_BETWEEN(GTEMatchExpression); + ASSERT_REWRITE_TO_INTERNAL_BETWEEN(LTMatchExpression); + ASSERT_REWRITE_TO_INTERNAL_BETWEEN(LTEMatchExpression); } TEST_F(RangePredicateRewriteTest, CollScanRewriteAgg) { -- cgit v1.2.1