diff options
author | Erwin Pe <erwin.pe@mongodb.com> | 2023-02-14 14:04:29 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2023-02-14 16:13:43 +0000 |
commit | cfbe26d87f74d128a33003bdd3aafd5a97a4a46e (patch) | |
tree | 13fc67ec9dfa7cae7eb115e61bc1b88cc57965ea | |
parent | 7c98e8b07b47e57f7863d6b2b0d23516a7cbe8a1 (diff) | |
download | mongo-cfbe26d87f74d128a33003bdd3aafd5a97a4a46e.tar.gz |
SERVER-72919 Implement v2 changes to QE inserts
-rw-r--r-- | buildscripts/resmokelib/core/programs.py | 5 | ||||
-rw-r--r-- | jstests/fle2/libs/encrypted_client_util.js | 56 | ||||
-rw-r--r-- | src/mongo/crypto/fle_crypto.cpp | 542 | ||||
-rw-r--r-- | src/mongo/crypto/fle_crypto.h | 64 | ||||
-rw-r--r-- | src/mongo/db/fle_crud.cpp | 117 | ||||
-rw-r--r-- | src/mongo/db/fle_crud_test.cpp | 179 | ||||
-rw-r--r-- | src/mongo/db/matcher/expression_type.h | 2 | ||||
-rw-r--r-- | src/mongo/db/matcher/expression_type_test.cpp | 2 | ||||
-rw-r--r-- | src/mongo/shell/shell_options.cpp | 1 |
9 files changed, 921 insertions, 47 deletions
diff --git a/buildscripts/resmokelib/core/programs.py b/buildscripts/resmokelib/core/programs.py index 696a8d2e922..1b4dae7213a 100644 --- a/buildscripts/resmokelib/core/programs.py +++ b/buildscripts/resmokelib/core/programs.py @@ -287,6 +287,11 @@ def mongo_shell_program(logger, executable=None, connection_string=None, filenam if "host" in kwargs: kwargs.pop("host") + # if featureFlagFLE2ProtocolVersion2 is enabled in setParameter, enable it in the shell also + # TODO: SERVER-73303 remove once v2 is enabled by default + if mongod_set_parameters.get("featureFlagFLE2ProtocolVersion2"): + args.append("--setShellParameter=featureFlagFLE2ProtocolVersion2=true") + # Apply the rest of the command line arguments. _apply_kwargs(args, kwargs) diff --git a/jstests/fle2/libs/encrypted_client_util.js b/jstests/fle2/libs/encrypted_client_util.js index 1141ce6af61..bdd27804862 100644 --- a/jstests/fle2/libs/encrypted_client_util.js +++ b/jstests/fle2/libs/encrypted_client_util.js @@ -178,9 +178,12 @@ var EncryptedClient = class { Object.extend(tenantOption, {"$tenant": dollarTenant}); } assert.commandWorked(this._edb.createCollection(ef.escCollection, tenantOption)); - assert.commandWorked(this._edb.createCollection(ef.eccCollection, tenantOption)); assert.commandWorked(this._edb.createCollection(ef.ecocCollection, tenantOption)); + // TODO: SERVER-73303 remove once v2 is enabled by default + if (!isFLE2ProtocolVersion2Enabled()) { + assert.commandWorked(this._edb.createCollection(ef.eccCollection, tenantOption)); + } return res; } @@ -234,10 +237,13 @@ var EncryptedClient = class { expectedEsc, `ESC document count is wrong: Actual ${actualEsc} vs Expected ${expectedEsc}`); - const actualEcc = countDocuments(sessionDB, ef.eccCollection, tenantId); - assert.eq(actualEcc, - expectedEcc, - `ECC document count is wrong: Actual ${actualEcc} vs Expected ${expectedEcc}`); + if (!isFLE2ProtocolVersion2Enabled()) { + const actualEcc = countDocuments(sessionDB, ef.eccCollection, tenantId); + assert.eq( + actualEcc, + expectedEcc, + `ECC document count is wrong: Actual ${actualEcc} vs Expected ${expectedEcc}`); + } const actualEcoc = countDocuments(sessionDB, ef.ecocCollection, tenantId); assert.eq(actualEcoc, @@ -431,6 +437,15 @@ function isFLE2RangeEnabled() { return typeof (testingFLE2Range) !== "undefined" && testingFLE2Range && (TestData == undefined || TestData.setParameters.featureFlagFLE2Range); } + +/** + * @returns Returns true if featureFlagFLE2ProtocolVersion2 is enabled + */ +function isFLE2ProtocolVersion2Enabled() { + return typeof (TestData) !== "undefined" && + TestData.setParameters.featureFlagFLE2ProtocolVersion2; +} + /** * Assert a field is an indexed encrypted field. That includes both * equality and range @@ -440,8 +455,15 @@ function isFLE2RangeEnabled() { function assertIsIndexedEncryptedField(value) { assert(value instanceof BinData, "Expected BinData, found: " + value); assert.eq(value.subtype(), 6, "Expected Encrypted bindata: " + value); - assert(value.hex().startsWith("07") || value.hex().startsWith("09"), - "Expected subtype 7 or 9 but found the wrong type: " + value.hex()); + + // TODO: SERVER-73303 remove once v2 is enabled by default + if (!isFLE2ProtocolVersion2Enabled()) { + assert(value.hex().startsWith("07") || value.hex().startsWith("09"), + "Expected subtype 7 or 9 but found the wrong type: " + value.hex()); + return; + } + assert(value.hex().startsWith("0e") || value.hex().startsWith("0f"), + "Expected subtype 14 or 15 but found the wrong type: " + value.hex()); } /** @@ -452,8 +474,14 @@ function assertIsIndexedEncryptedField(value) { function assertIsEqualityIndexedEncryptedField(value) { assert(value instanceof BinData, "Expected BinData, found: " + value); assert.eq(value.subtype(), 6, "Expected Encrypted bindata: " + value); - assert(value.hex().startsWith("07"), - "Expected subtype 7 but found the wrong type: " + value.hex()); + // TODO: SERVER-73303 remove once v2 is enabled by default + if (!isFLE2ProtocolVersion2Enabled()) { + assert(value.hex().startsWith("07"), + "Expected subtype 7 but found the wrong type: " + value.hex()); + return; + } + assert(value.hex().startsWith("0e"), + "Expected subtype 14 but found the wrong type: " + value.hex()); } /** @@ -464,8 +492,14 @@ function assertIsEqualityIndexedEncryptedField(value) { function assertIsRangeIndexedEncryptedField(value) { assert(value instanceof BinData, "Expected BinData, found: " + value); assert.eq(value.subtype(), 6, "Expected Encrypted bindata: " + value); - assert(value.hex().startsWith("09"), - "Expected subtype 9 but found the wrong type: " + value.hex()); + // TODO: SERVER-73303 remove once v2 is enabled by default + if (!isFLE2ProtocolVersion2Enabled()) { + assert(value.hex().startsWith("09"), + "Expected subtype 9 but found the wrong type: " + value.hex()); + return; + } + assert(value.hex().startsWith("0f"), + "Expected subtype 15 but found the wrong type: " + value.hex()); } /** diff --git a/src/mongo/crypto/fle_crypto.cpp b/src/mongo/crypto/fle_crypto.cpp index 37fb30b1532..55e179008ee 100644 --- a/src/mongo/crypto/fle_crypto.cpp +++ b/src/mongo/crypto/fle_crypto.cpp @@ -856,6 +856,7 @@ StatusWith<std::vector<uint8_t>> KeyIdAndValue::decrypt(FLEUserKey userKey, */ class EDCClientPayload { public: + // TODO: SERVER-73303 delete v1 functions when v2 is enabled by default static FLE2InsertUpdatePayload parseInsertUpdatePayload(ConstDataRange cdr); static FLE2InsertUpdatePayload serializeInsertUpdatePayload(FLEIndexKeyAndId indexKey, @@ -869,6 +870,18 @@ public: FLE2RangeInsertSpec spec, uint8_t sparsity, uint64_t contentionFactor); + + static FLE2InsertUpdatePayloadV2 parseInsertUpdatePayloadV2(ConstDataRange cdr); + static FLE2InsertUpdatePayloadV2 serializeInsertUpdatePayloadV2(FLEIndexKeyAndId indexKey, + FLEUserKeyAndId userKey, + BSONElement element, + uint64_t contentionFactor); + static FLE2InsertUpdatePayloadV2 serializeInsertUpdatePayloadV2ForRange( + FLEIndexKeyAndId indexKey, + FLEUserKeyAndId userKey, + FLE2RangeInsertSpec spec, + uint8_t sparsity, + uint64_t contentionFactor); }; @@ -876,6 +889,10 @@ FLE2InsertUpdatePayload EDCClientPayload::parseInsertUpdatePayload(ConstDataRang return parseFromCDR<FLE2InsertUpdatePayload>(cdr); } +FLE2InsertUpdatePayloadV2 EDCClientPayload::parseInsertUpdatePayloadV2(ConstDataRange cdr) { + return parseFromCDR<FLE2InsertUpdatePayloadV2>(cdr); +} + FLE2InsertUpdatePayload EDCClientPayload::serializeInsertUpdatePayload(FLEIndexKeyAndId indexKey, FLEUserKeyAndId userKey, BSONElement element, @@ -931,6 +948,58 @@ FLE2InsertUpdatePayload EDCClientPayload::serializeInsertUpdatePayload(FLEIndexK return iupayload; } +FLE2InsertUpdatePayloadV2 EDCClientPayload::serializeInsertUpdatePayloadV2( + FLEIndexKeyAndId indexKey, + FLEUserKeyAndId userKey, + BSONElement element, + uint64_t contentionFactor) { + auto value = ConstDataRange(element.value(), element.value() + element.valuesize()); + + auto collectionToken = FLELevel1TokenGenerator::generateCollectionsLevel1Token(indexKey.key); + auto serverEncryptToken = + FLELevel1TokenGenerator::generateServerDataEncryptionLevel1Token(indexKey.key); + auto serverDerivationToken = + FLELevel1TokenGenerator::generateServerTokenDerivationLevel1Token(indexKey.key); + + auto edcToken = FLECollectionTokenGenerator::generateEDCToken(collectionToken); + auto escToken = FLECollectionTokenGenerator::generateESCToken(collectionToken); + auto ecocToken = FLECollectionTokenGenerator::generateECOCToken(collectionToken); + auto serverDerivedFromDataToken = + FLEDerivedFromDataTokenGenerator::generateServerDerivedFromDataToken(serverDerivationToken, + value); + EDCDerivedFromDataToken edcDataToken = + FLEDerivedFromDataTokenGenerator::generateEDCDerivedFromDataToken(edcToken, value); + ESCDerivedFromDataToken escDataToken = + FLEDerivedFromDataTokenGenerator::generateESCDerivedFromDataToken(escToken, value); + + EDCDerivedFromDataTokenAndContentionFactorToken edcDataCounterToken = + FLEDerivedFromDataTokenAndContentionFactorTokenGenerator:: + generateEDCDerivedFromDataTokenAndContentionFactorToken(edcDataToken, contentionFactor); + ESCDerivedFromDataTokenAndContentionFactorToken escDataCounterToken = + FLEDerivedFromDataTokenAndContentionFactorTokenGenerator:: + generateESCDerivedFromDataTokenAndContentionFactorToken(escDataToken, contentionFactor); + + FLE2InsertUpdatePayloadV2 iupayload; + + iupayload.setEdcDerivedToken(edcDataCounterToken.toCDR()); + iupayload.setEscDerivedToken(escDataCounterToken.toCDR()); + iupayload.setServerEncryptionToken(serverEncryptToken.toCDR()); + iupayload.setServerDerivedFromDataToken(serverDerivedFromDataToken.toCDR()); + + auto swEncryptedTokens = + EncryptedStateCollectionTokensV2(escDataCounterToken).serialize(ecocToken); + uassertStatusOK(swEncryptedTokens); + iupayload.setEncryptedTokens(swEncryptedTokens.getValue()); + + auto swCipherText = KeyIdAndValue::serialize(userKey, value); + uassertStatusOK(swCipherText); + iupayload.setValue(swCipherText.getValue()); + iupayload.setType(element.type()); + iupayload.setIndexKeyId(indexKey.keyId); + iupayload.setContentionFactor(contentionFactor); + + return iupayload; +} std::unique_ptr<Edges> getEdges(FLE2RangeInsertSpec spec, int sparsity) { auto element = spec.getValue().getElement(); @@ -1007,6 +1076,7 @@ std::unique_ptr<Edges> getEdges(FLE2RangeInsertSpec spec, int sparsity) { } } +// TODO: SERVER-73303 delete when v2 is enabled by default std::vector<EdgeTokenSet> getEdgeTokenSet(FLE2RangeInsertSpec spec, int sparsity, uint64_t contentionFactor, @@ -1060,6 +1130,55 @@ std::vector<EdgeTokenSet> getEdgeTokenSet(FLE2RangeInsertSpec spec, return tokens; } +std::vector<EdgeTokenSetV2> getEdgeTokenSet( + FLE2RangeInsertSpec spec, + int sparsity, + uint64_t contentionFactor, + const EDCToken& edcToken, + const ESCToken& escToken, + const ECOCToken& ecocToken, + const ServerTokenDerivationLevel1Token& serverDerivationToken) { + const auto edges = getEdges(std::move(spec), sparsity); + const auto edgesList = edges->get(); + + std::vector<EdgeTokenSetV2> tokens; + + for (const auto& edge : edgesList) { + ConstDataRange cdr(edge.rawData(), edge.size()); + + EDCDerivedFromDataToken edcDatakey = + FLEDerivedFromDataTokenGenerator::generateEDCDerivedFromDataToken(edcToken, cdr); + ESCDerivedFromDataToken escDatakey = + FLEDerivedFromDataTokenGenerator::generateESCDerivedFromDataToken(escToken, cdr); + + EDCDerivedFromDataTokenAndContentionFactorToken edcDataCounterkey = + FLEDerivedFromDataTokenAndContentionFactorTokenGenerator:: + generateEDCDerivedFromDataTokenAndContentionFactorToken(edcDatakey, + contentionFactor); + ESCDerivedFromDataTokenAndContentionFactorToken escDataCounterkey = + FLEDerivedFromDataTokenAndContentionFactorTokenGenerator:: + generateESCDerivedFromDataTokenAndContentionFactorToken(escDatakey, + contentionFactor); + ServerDerivedFromDataToken serverDatakey = + FLEDerivedFromDataTokenGenerator::generateServerDerivedFromDataToken( + serverDerivationToken, cdr); + + EdgeTokenSetV2 ets; + + ets.setEdcDerivedToken(edcDataCounterkey.toCDR()); + ets.setEscDerivedToken(escDataCounterkey.toCDR()); + ets.setServerDerivedFromDataToken(serverDatakey.toCDR()); + + auto swEncryptedTokens = + EncryptedStateCollectionTokensV2(escDataCounterkey).serialize(ecocToken); + uassertStatusOK(swEncryptedTokens); + ets.setEncryptedTokens(swEncryptedTokens.getValue()); + + tokens.push_back(ets); + } + + return tokens; +} FLE2InsertUpdatePayload EDCClientPayload::serializeInsertUpdatePayloadForRange( FLEIndexKeyAndId indexKey, @@ -1125,6 +1244,69 @@ FLE2InsertUpdatePayload EDCClientPayload::serializeInsertUpdatePayloadForRange( return iupayload; } +FLE2InsertUpdatePayloadV2 EDCClientPayload::serializeInsertUpdatePayloadV2ForRange( + FLEIndexKeyAndId indexKey, + FLEUserKeyAndId userKey, + FLE2RangeInsertSpec spec, + uint8_t sparsity, + uint64_t contentionFactor) { + auto element = spec.getValue().getElement(); + auto value = ConstDataRange(element.value(), element.value() + element.valuesize()); + + auto collectionToken = FLELevel1TokenGenerator::generateCollectionsLevel1Token(indexKey.key); + auto serverEncryptToken = + FLELevel1TokenGenerator::generateServerDataEncryptionLevel1Token(indexKey.key); + auto serverDerivationToken = + FLELevel1TokenGenerator::generateServerTokenDerivationLevel1Token(indexKey.key); + + auto edcToken = FLECollectionTokenGenerator::generateEDCToken(collectionToken); + auto escToken = FLECollectionTokenGenerator::generateESCToken(collectionToken); + auto ecocToken = FLECollectionTokenGenerator::generateECOCToken(collectionToken); + auto serverDerivedFromDataToken = + FLEDerivedFromDataTokenGenerator::generateServerDerivedFromDataToken(serverDerivationToken, + value); + + EDCDerivedFromDataToken edcDatakey = + FLEDerivedFromDataTokenGenerator::generateEDCDerivedFromDataToken(edcToken, value); + ESCDerivedFromDataToken escDatakey = + FLEDerivedFromDataTokenGenerator::generateESCDerivedFromDataToken(escToken, value); + + EDCDerivedFromDataTokenAndContentionFactorToken edcDataCounterkey = + FLEDerivedFromDataTokenAndContentionFactorTokenGenerator:: + generateEDCDerivedFromDataTokenAndContentionFactorToken(edcDatakey, contentionFactor); + ESCDerivedFromDataTokenAndContentionFactorToken escDataCounterkey = + FLEDerivedFromDataTokenAndContentionFactorTokenGenerator:: + generateESCDerivedFromDataTokenAndContentionFactorToken(escDatakey, contentionFactor); + + FLE2InsertUpdatePayloadV2 iupayload; + + iupayload.setEdcDerivedToken(edcDataCounterkey.toCDR()); + iupayload.setEscDerivedToken(escDataCounterkey.toCDR()); + iupayload.setServerEncryptionToken(serverEncryptToken.toCDR()); + iupayload.setServerDerivedFromDataToken(serverDerivedFromDataToken.toCDR()); + + auto swEncryptedTokens = + EncryptedStateCollectionTokensV2(escDataCounterkey).serialize(ecocToken); + uassertStatusOK(swEncryptedTokens); + iupayload.setEncryptedTokens(swEncryptedTokens.getValue()); + + auto swCipherText = KeyIdAndValue::serialize(userKey, value); + uassertStatusOK(swCipherText); + iupayload.setValue(swCipherText.getValue()); + iupayload.setType(element.type()); + iupayload.setIndexKeyId(indexKey.keyId); + iupayload.setContentionFactor(contentionFactor); + + auto edgeTokenSet = getEdgeTokenSet( + spec, sparsity, contentionFactor, edcToken, escToken, ecocToken, serverDerivationToken); + + if (!edgeTokenSet.empty()) { + iupayload.setEdgeTokenSet(edgeTokenSet); + } + + return iupayload; +} + /** * Lightweight class to build a singly linked list of field names to represent the current field * name @@ -1329,6 +1511,19 @@ void convertToFLE2Payload(FLEKeyVault* keyVault, isFLE2EqualityIndexedSupportedType(el.type())); if (ep.getType() == Fle2PlaceholderType::kInsert) { + + if (gFeatureFlagFLE2ProtocolVersion2.isEnabled( + serverGlobalParams.featureCompatibility)) { + auto iupayload = EDCClientPayload::serializeInsertUpdatePayloadV2( + indexKey, userKey, el, contentionFactor(ep)); + toEncryptedBinData(fieldNameToSerialize, + EncryptedBinDataType::kFLE2InsertUpdatePayloadV2, + iupayload, + builder); + return; + } + + // TODO: SERVER-73303 delete when v2 is enabled by default auto iupayload = EDCClientPayload::serializeInsertUpdatePayload( indexKey, userKey, el, contentionFactor(ep)); @@ -1361,6 +1556,22 @@ void convertToFLE2Payload(FLEKeyVault* keyVault, << "' is not a valid type for Queryable Encryption Range", isFLE2RangeIndexedSupportedType(elRange.type())); + if (gFeatureFlagFLE2ProtocolVersion2.isEnabled( + serverGlobalParams.featureCompatibility)) { + auto iupayload = EDCClientPayload::serializeInsertUpdatePayloadV2ForRange( + indexKey, + userKey, + rangeInsertSpec, + ep.getSparsity().value(), // Enforced as non-optional in this case in IDL + contentionFactor(ep)); + toEncryptedBinData(fieldNameToSerialize, + EncryptedBinDataType::kFLE2InsertUpdatePayloadV2, + iupayload, + builder); + return; + } + + // TODO: SERVER-73303 delete when v2 is enabled by default auto iupayload = EDCClientPayload::serializeInsertUpdatePayloadForRange( indexKey, userKey, @@ -1425,25 +1636,38 @@ void parseAndVerifyInsertUpdatePayload(std::vector<EDCServerPayloadInfo>* pField StringData fieldPath, EncryptedBinDataType type, ConstDataRange subCdr) { - auto iupayload = EDCClientPayload::parseInsertUpdatePayload(subCdr); + EDCServerPayloadInfo payloadInfo; + payloadInfo.fieldPathName = fieldPath.toString(); + + if (gFeatureFlagFLE2ProtocolVersion2.isEnabled(serverGlobalParams.featureCompatibility)) { + uassert(7291901, + "Encountered a Queryable Encryption insert/update payload type that is no " + "longer supported", + type == EncryptedBinDataType::kFLE2InsertUpdatePayloadV2); + auto iupayload = EDCClientPayload::parseInsertUpdatePayloadV2(subCdr); + payloadInfo.payload = VersionedInsertUpdatePayload(std::move(iupayload)); + } else { + auto iupayload = EDCClientPayload::parseInsertUpdatePayload(subCdr); + payloadInfo.payload = VersionedInsertUpdatePayload(std::move(iupayload)); + } - bool isRangePayload = iupayload.getEdgeTokenSet().has_value(); + auto bsonType = static_cast<BSONType>(payloadInfo.payload.getType()); - if (isRangePayload) { + if (payloadInfo.isRangePayload()) { uassert(6775305, - str::stream() << "Type '" << typeName(static_cast<BSONType>(iupayload.getType())) + str::stream() << "Type '" << typeName(bsonType) << "' is not a valid type for Queryable Encryption Range", - isValidBSONType(iupayload.getType()) && - isFLE2RangeIndexedSupportedType(static_cast<BSONType>(iupayload.getType()))); + isValidBSONType(payloadInfo.payload.getType()) && + isFLE2RangeIndexedSupportedType(bsonType)); } else { uassert(6373504, - str::stream() << "Type '" << typeName(static_cast<BSONType>(iupayload.getType())) + str::stream() << "Type '" << typeName(bsonType) << "' is not a valid type for Queryable Encryption Equality", - isValidBSONType(iupayload.getType()) && - isFLE2EqualityIndexedSupportedType(static_cast<BSONType>(iupayload.getType()))); + isValidBSONType(payloadInfo.payload.getType()) && + isFLE2EqualityIndexedSupportedType(bsonType)); } - pFields->push_back({std::move(iupayload), fieldPath.toString(), {}}); + pFields->push_back(std::move(payloadInfo)); } void collectEDCServerInfo(std::vector<EDCServerPayloadInfo>* pFields, @@ -1455,13 +1679,16 @@ void collectEDCServerInfo(std::vector<EDCServerPayloadInfo>* pFields, auto [encryptedTypeBinding, subCdr] = fromEncryptedConstDataRange(cdr); auto encryptedType = encryptedTypeBinding; - if (encryptedType == EncryptedBinDataType::kFLE2InsertUpdatePayload) { + if (encryptedType == EncryptedBinDataType::kFLE2InsertUpdatePayload || + encryptedType == EncryptedBinDataType::kFLE2InsertUpdatePayloadV2) { parseAndVerifyInsertUpdatePayload(pFields, fieldPath, encryptedType, subCdr); return; - } else if (encryptedType == EncryptedBinDataType::kFLE2FindEqualityPayload) { + } else if (encryptedType == EncryptedBinDataType::kFLE2FindEqualityPayload || + encryptedType == EncryptedBinDataType::kFLE2FindEqualityPayloadV2) { // No-op return; - } else if (encryptedType == EncryptedBinDataType::kFLE2FindRangePayload) { + } else if (encryptedType == EncryptedBinDataType::kFLE2FindRangePayload || + encryptedType == EncryptedBinDataType::kFLE2FindRangePayloadV2) { // No-op return; } else if (encryptedType == EncryptedBinDataType::kFLE2UnindexedEncryptedValue) { @@ -1492,21 +1719,66 @@ void convertServerPayload(ConstDataRange cdr, StringData fieldPath) { auto [encryptedTypeBinding, subCdr] = fromEncryptedConstDataRange(cdr); if (encryptedTypeBinding == EncryptedBinDataType::kFLE2FindEqualityPayload || - encryptedTypeBinding == EncryptedBinDataType::kFLE2FindRangePayload) { + encryptedTypeBinding == EncryptedBinDataType::kFLE2FindRangePayload || + encryptedTypeBinding == EncryptedBinDataType::kFLE2FindEqualityPayloadV2 || + encryptedTypeBinding == EncryptedBinDataType::kFLE2FindRangePayloadV2) { builder->appendBinData(fieldPath, cdr.length(), BinDataType::Encrypt, cdr.data<char>()); return; - } else if (encryptedTypeBinding == EncryptedBinDataType::kFLE2InsertUpdatePayload) { - - if (it.it == it.end) { - return; - } + } else if (encryptedTypeBinding == EncryptedBinDataType::kFLE2InsertUpdatePayload || + encryptedTypeBinding == EncryptedBinDataType::kFLE2InsertUpdatePayloadV2) { + + // TODO: SERVER-73303 set to just kFLE2InsertUpdatePayloadV2 once is enabled by default + auto validVersionedTypeBinding = + gFeatureFlagFLE2ProtocolVersion2.isEnabled(serverGlobalParams.featureCompatibility) + ? EncryptedBinDataType::kFLE2InsertUpdatePayloadV2 + : EncryptedBinDataType::kFLE2InsertUpdatePayload; + uassert(7291907, + "Encountered a Queryable Encryption insert/update payload type that is no longer " + "supported", + encryptedTypeBinding == validVersionedTypeBinding); uassert(6373505, "Unexpected end of iterator", it.it != it.end); const auto payload = it.it; // TODO - validate field is actually indexed in the schema? - if (payload->payload.getEdgeTokenSet().has_value()) { - FLE2IndexedRangeEncryptedValue sp(payload->payload, payload->counts); + if (payload->isRangePayload()) { + if (gFeatureFlagFLE2ProtocolVersion2.isEnabled( + serverGlobalParams.featureCompatibility)) { + auto& v2Payload = payload->payload.getInsertUpdatePayloadVersion2(); + + FLE2IndexedRangeEncryptedValueV2 sp( + v2Payload, EDCServerCollection::generateTags(*payload), payload->counts); + + uassert(7291908, + str::stream() << "Type '" << typeName(sp.bsonType) + << "' is not a valid type for Queryable Encryption Range", + isFLE2RangeIndexedSupportedType(sp.bsonType)); + + std::vector<ServerDerivedFromDataToken> edgeDerivedTokens; + auto serverToken = FLETokenFromCDR<FLETokenType::ServerDataEncryptionLevel1Token>( + v2Payload.getServerEncryptionToken()); + for (auto& ets : v2Payload.getEdgeTokenSet().value()) { + edgeDerivedTokens.push_back( + FLETokenFromCDR<FLETokenType::ServerDerivedFromDataToken>( + ets.getServerDerivedFromDataToken())); + } + + auto swEncrypted = sp.serialize(serverToken, edgeDerivedTokens); + uassertStatusOK(swEncrypted); + toEncryptedBinData(fieldPath, + EncryptedBinDataType::kFLE2RangeIndexedValueV2, + ConstDataRange(swEncrypted.getValue()), + builder); + + for (auto& mblock : sp.metadataBlocks) { + pTags->push_back({mblock.tag}); + } + it.it++; + return; + } + // TODO: SERVER-73303 delete below once is enabled by default + auto& v1Payload = payload->payload.getInsertUpdatePayloadVersion1(); + FLE2IndexedRangeEncryptedValue sp(v1Payload, payload->counts); uassert(6775311, str::stream() << "Type '" << typeName(sp.bsonType) @@ -1530,7 +1802,36 @@ void convertServerPayload(ConstDataRange cdr, } else { dassert(payload->counts.size() == 1); - FLE2IndexedEqualityEncryptedValue sp(payload->payload, payload->counts[0]); + if (gFeatureFlagFLE2ProtocolVersion2.isEnabled( + serverGlobalParams.featureCompatibility)) { + auto tag = EDCServerCollection::generateTag(*payload); + auto& v2Payload = payload->payload.getInsertUpdatePayloadVersion2(); + FLE2IndexedEqualityEncryptedValueV2 sp(v2Payload, tag, payload->counts[0]); + + uassert(7291906, + str::stream() << "Type '" << typeName(sp.bsonType) + << "' is not a valid type for Queryable Encryption Equality", + isFLE2EqualityIndexedSupportedType(sp.bsonType)); + + auto swEncrypted = + sp.serialize(FLETokenFromCDR<FLETokenType::ServerDataEncryptionLevel1Token>( + v2Payload.getServerEncryptionToken()), + FLETokenFromCDR<FLETokenType::ServerDerivedFromDataToken>( + v2Payload.getServerDerivedFromDataToken())); + uassertStatusOK(swEncrypted); + toEncryptedBinData(fieldPath, + EncryptedBinDataType::kFLE2EqualityIndexedValueV2, + ConstDataRange(swEncrypted.getValue()), + builder); + pTags->push_back({tag}); + + it.it++; + return; + } + + // TODO: SERVER-73303 delete below once v2 is enabled by default + auto& v1Payload = payload->payload.getInsertUpdatePayloadVersion1(); + FLE2IndexedEqualityEncryptedValue sp(v1Payload, payload->counts[0]); uassert(6373506, str::stream() << "Type '" << typeName(sp.bsonType) @@ -2176,6 +2477,61 @@ BSONObj FLEClientCrypto::generateCompactionTokens(const EncryptedFieldConfig& cf } BSONObj FLEClientCrypto::decryptDocument(BSONObj& doc, FLEKeyVault* keyVault) { + // TODO: SERVER-73851 remove once libmongocrypt supports parsing v2 payloads + if (gFeatureFlagFLE2ProtocolVersion2.isEnabled(serverGlobalParams.featureCompatibility)) { + BSONObjBuilder builder; + + auto obj = transformBSON( + doc, [keyVault](ConstDataRange cdr, BSONObjBuilder* builder, StringData fieldPath) { + auto [encryptedType, subCdr] = fromEncryptedConstDataRange(cdr); + if (encryptedType == EncryptedBinDataType::kFLE2EqualityIndexedValueV2 || + encryptedType == EncryptedBinDataType::kFLE2RangeIndexedValueV2) { + std::vector<uint8_t> userCipherText; + BSONType type; + if (encryptedType == EncryptedBinDataType::kFLE2EqualityIndexedValueV2) { + auto indexKeyId = + uassertStatusOK(FLE2IndexedEqualityEncryptedValueV2::readKeyId(subCdr)); + auto indexKey = keyVault->getIndexKeyById(indexKeyId); + auto serverToken = + FLELevel1TokenGenerator::generateServerDataEncryptionLevel1Token( + indexKey.key); + userCipherText = uassertStatusOK( + FLE2IndexedEqualityEncryptedValueV2::parseAndDecryptCiphertext( + serverToken, subCdr)); + type = uassertStatusOK( + FLE2IndexedEqualityEncryptedValueV2::readBsonType(subCdr)); + } else { + auto indexKeyId = + uassertStatusOK(FLE2IndexedRangeEncryptedValueV2::readKeyId(subCdr)); + auto indexKey = keyVault->getIndexKeyById(indexKeyId); + auto serverToken = + FLELevel1TokenGenerator::generateServerDataEncryptionLevel1Token( + indexKey.key); + userCipherText = uassertStatusOK( + FLE2IndexedRangeEncryptedValueV2::parseAndDecryptCiphertext(serverToken, + subCdr)); + type = + uassertStatusOK(FLE2IndexedRangeEncryptedValueV2::readBsonType(subCdr)); + } + + auto userKeyId = uassertStatusOK(KeyIdAndValue::readKeyId(userCipherText)); + auto userKey = keyVault->getUserKeyById(userKeyId); + auto userData = + uassertStatusOK(KeyIdAndValue::decrypt(userKey.key, userCipherText)); + BSONObj obj = toBSON(type, userData); + builder->appendAs(obj.firstElement(), fieldPath); + } else if (encryptedType == EncryptedBinDataType::kFLE2UnindexedEncryptedValue) { + auto [type, userData] = FLE2UnindexedEncryptedValue::deserialize(keyVault, cdr); + BSONObj obj = toBSON(type, userData); + builder->appendAs(obj.firstElement(), fieldPath); + } else { + builder->appendBinData( + fieldPath, cdr.length(), BinDataType::Encrypt, cdr.data()); + } + }); + builder.appendElements(obj); + return builder.obj(); + } auto crypt = createMongoCrypt(); @@ -3575,6 +3931,111 @@ StatusWith<std::vector<uint8_t>> FLE2IndexedRangeEncryptedValue::serialize( return serializedServerValue; } +VersionedEdgeTokenSet::VersionedEdgeTokenSet(EdgeTokenSet ets) : edgeTokenSet(std::move(ets)) {} + +VersionedEdgeTokenSet::VersionedEdgeTokenSet(EdgeTokenSetV2 ets) : edgeTokenSet(std::move(ets)) {} + +ConstDataRange VersionedEdgeTokenSet::getEscDerivedToken() const { + return stdx::visit( + OverloadedVisitor{[](const EdgeTokenSet& ets) { return ets.getEscDerivedToken(); }, + [](const EdgeTokenSetV2& ets) { + return ets.getEscDerivedToken(); + }}, + edgeTokenSet); +} + +ConstDataRange VersionedEdgeTokenSet::getEncryptedTokens() const { + return stdx::visit( + OverloadedVisitor{[](const EdgeTokenSet& ets) { return ets.getEncryptedTokens(); }, + [](const EdgeTokenSetV2& ets) { + return ets.getEncryptedTokens(); + }}, + edgeTokenSet); +} + +VersionedInsertUpdatePayload::VersionedInsertUpdatePayload(FLE2InsertUpdatePayload iup) + : iupayload(std::move(iup)), edgeTokenSet(convertPayloadEdgeTokenSet<decltype(iup)>()) {} + +VersionedInsertUpdatePayload::VersionedInsertUpdatePayload(FLE2InsertUpdatePayloadV2 iup) + : iupayload(std::move(iup)), edgeTokenSet(convertPayloadEdgeTokenSet<decltype(iup)>()) {} + +const FLE2InsertUpdatePayload& VersionedInsertUpdatePayload::getInsertUpdatePayloadVersion1() + const { + auto payloadPtr = stdx::get_if<FLE2InsertUpdatePayload>(&iupayload); + uassert( + 7291904, "Attempted to retrieve invalid version of FLE2InsertUpdatePayload", payloadPtr); + return *payloadPtr; +} + +const FLE2InsertUpdatePayloadV2& VersionedInsertUpdatePayload::getInsertUpdatePayloadVersion2() + const { + auto payloadPtr = stdx::get_if<FLE2InsertUpdatePayloadV2>(&iupayload); + uassert( + 7291905, "Attempted to retrieve invalid version of FLE2InsertUpdatePayload", payloadPtr); + return *payloadPtr; +} + +const mongo::UUID& VersionedInsertUpdatePayload::getIndexKeyId() const { + return stdx::visit( + OverloadedVisitor{[](const FLE2InsertUpdatePayload& v1) -> const mongo::UUID& { + return v1.getIndexKeyId(); + }, + [](const FLE2InsertUpdatePayloadV2& v2) -> const mongo::UUID& { + return v2.getIndexKeyId(); + }}, + iupayload); +} + +int VersionedInsertUpdatePayload::getType() const { + return stdx::visit( + OverloadedVisitor{[](const FLE2InsertUpdatePayload& v1) { return v1.getType(); }, + [](const FLE2InsertUpdatePayloadV2& v2) { + return v2.getType(); + }}, + iupayload); +} + +ConstDataRange VersionedInsertUpdatePayload::getEncryptedTokens() const { + return stdx::visit( + OverloadedVisitor{[](const FLE2InsertUpdatePayload& v1) { return v1.getEncryptedTokens(); }, + [](const FLE2InsertUpdatePayloadV2& v2) { + return v2.getEncryptedTokens(); + }}, + iupayload); +} + +ConstDataRange VersionedInsertUpdatePayload::getEscDerivedToken() const { + return stdx::visit( + OverloadedVisitor{[](const FLE2InsertUpdatePayload& v1) { return v1.getEscDerivedToken(); }, + [](const FLE2InsertUpdatePayloadV2& v2) { + return v2.getEscDerivedToken(); + }}, + iupayload); +} + +ConstDataRange VersionedInsertUpdatePayload::getEdcDerivedToken() const { + return stdx::visit( + OverloadedVisitor{[](const FLE2InsertUpdatePayload& v1) { return v1.getEdcDerivedToken(); }, + [](const FLE2InsertUpdatePayloadV2& v2) { + return v2.getEdcDerivedToken(); + }}, + iupayload); +} + +ConstDataRange VersionedInsertUpdatePayload::getServerEncryptionToken() const { + return stdx::visit(OverloadedVisitor{[](const FLE2InsertUpdatePayload& v1) { + return v1.getServerEncryptionToken(); + }, + [](const FLE2InsertUpdatePayloadV2& v2) { + return v2.getServerEncryptionToken(); + }}, + iupayload); +} + +const boost::optional<std::vector<VersionedEdgeTokenSet>>& +VersionedInsertUpdatePayload::getEdgeTokenSet() const { + return edgeTokenSet; +} FLE2IndexedRangeEncryptedValueV2::FLE2IndexedRangeEncryptedValueV2( const FLE2InsertUpdatePayloadV2& payload, @@ -3820,7 +4281,13 @@ void EDCServerCollection::validateEncryptedFieldInfo(BSONObj& obj, visitEncryptedBSON(obj, [&indexedFields](ConstDataRange cdr, StringData fieldPath) { auto [encryptedTypeBinding, subCdr] = fromEncryptedConstDataRange(cdr); - if (encryptedTypeBinding == EncryptedBinDataType::kFLE2InsertUpdatePayload) { + auto expectedPayloadType = EncryptedBinDataType::kFLE2InsertUpdatePayload; + + if (gFeatureFlagFLE2ProtocolVersion2.isEnabled(serverGlobalParams.featureCompatibility)) { + expectedPayloadType = EncryptedBinDataType::kFLE2InsertUpdatePayloadV2; + } + + if (encryptedTypeBinding == expectedPayloadType) { uassert(6373601, str::stream() << "Field '" << fieldPath << "' is encrypted, but absent from schema", @@ -3845,7 +4312,7 @@ std::vector<EDCServerPayloadInfo> EDCServerCollection::getEncryptedFieldInfo(BSO // We check here at runtime that all fields index keys are unique. stdx::unordered_set<UUID, UUID::Hash> indexKeyIds; for (const auto& field : fields) { - auto indexKeyId = field.payload.getIndexKeyId(); + auto& indexKeyId = field.payload.getIndexKeyId(); uassert(6371407, "Index key ids must be unique across fields in a document", !indexKeyIds.contains(indexKeyId)); @@ -3863,17 +4330,19 @@ PrfBlock EDCServerCollection::generateTag(const EDCServerPayloadInfo& payload) { auto token = FLETokenFromCDR<FLETokenType::EDCDerivedFromDataTokenAndContentionFactorToken>( payload.payload.getEdcDerivedToken()); auto edcTwiceDerived = FLETwiceDerivedTokenGenerator::generateEDCTwiceDerivedToken(token); - dassert(payload.payload.getEdgeTokenSet().has_value() == false); + dassert(payload.isRangePayload() == false); dassert(payload.counts.size() == 1); return generateTag(edcTwiceDerived, payload.counts[0]); } +// TODO: SERVER-73303 delete when v2 is enabled by default PrfBlock EDCServerCollection::generateTag(const FLE2IndexedEqualityEncryptedValue& indexedValue) { auto edcTwiceDerived = FLETwiceDerivedTokenGenerator::generateEDCTwiceDerivedToken(indexedValue.edc); return generateTag(edcTwiceDerived, indexedValue.count); } +// TODO: SERVER-73303 delete when v2 is enabled by default PrfBlock EDCServerCollection::generateTag(const FLEEdgeToken& token, FLECounter count) { auto edcTwiceDerived = FLETwiceDerivedTokenGenerator::generateEDCTwiceDerivedToken(token.edc); return generateTag(edcTwiceDerived, count); @@ -3896,6 +4365,29 @@ std::vector<PrfBlock> EDCServerCollection::generateTags( return tags; } +std::vector<PrfBlock> EDCServerCollection::generateTags(const EDCServerPayloadInfo& rangePayload) { + // throws if EDCServerPayloadInfo has invalid payload version + auto& v2Payload = rangePayload.payload.getInsertUpdatePayloadVersion2(); + + uassert(7291909, + "InsertUpdatePayload must have an edge token set", + v2Payload.getEdgeTokenSet().has_value()); + uassert(7291910, + "Mismatch between edge token set and counters lengths", + v2Payload.getEdgeTokenSet()->size() == rangePayload.counts.size()); + + auto& edgeTokenSets = v2Payload.getEdgeTokenSet().value(); + std::vector<PrfBlock> tags; + tags.reserve(edgeTokenSets.size()); + + for (size_t i = 0; i < edgeTokenSets.size(); i++) { + auto edcTwiceDerived = FLETwiceDerivedTokenGenerator::generateEDCTwiceDerivedToken( + FLETokenFromCDR<FLETokenType::EDCDerivedFromDataTokenAndContentionFactorToken>( + edgeTokenSets[i].getEdcDerivedToken())); + tags.push_back(EDCServerCollection::generateTag(edcTwiceDerived, rangePayload.counts[i])); + } + return tags; +} StatusWith<FLE2IndexedEqualityEncryptedValue> EDCServerCollection::decryptAndParse( ServerDataEncryptionLevel1Token token, ConstDataRange serializedServerValue) { diff --git a/src/mongo/crypto/fle_crypto.h b/src/mongo/crypto/fle_crypto.h index 9bebdc1d4f6..f12f39c73de 100644 --- a/src/mongo/crypto/fle_crypto.h +++ b/src/mongo/crypto/fle_crypto.h @@ -1190,10 +1190,71 @@ struct FLE2IndexedRangeEncryptedValueV2 { std::vector<FLE2TagAndEncryptedMetadataBlock> metadataBlocks; }; +// TODO: SERVER-73303 delete when v2 is enabled by default +/* + * Shim layer for EdgeTokenSet types with different protocol versions. + */ +class VersionedEdgeTokenSet { +public: + VersionedEdgeTokenSet() = default; + VersionedEdgeTokenSet(EdgeTokenSet ets); + VersionedEdgeTokenSet(EdgeTokenSetV2 ets); + + ConstDataRange getEscDerivedToken() const; + ConstDataRange getEncryptedTokens() const; + +private: + stdx::variant<EdgeTokenSet, EdgeTokenSetV2> edgeTokenSet; +}; + +// TODO: SERVER-73303 delete when v2 is enabled by default +/* + * Shim layer for FLE2InsertUpdatePayload types with different protocol versions. + */ +class VersionedInsertUpdatePayload { +public: + VersionedInsertUpdatePayload() = default; + VersionedInsertUpdatePayload(FLE2InsertUpdatePayload iup); + VersionedInsertUpdatePayload(FLE2InsertUpdatePayloadV2 iup); + + const FLE2InsertUpdatePayload& getInsertUpdatePayloadVersion1() const; + const FLE2InsertUpdatePayloadV2& getInsertUpdatePayloadVersion2() const; + + const mongo::UUID& getIndexKeyId() const; + int getType() const; + ConstDataRange getEncryptedTokens() const; + ConstDataRange getEscDerivedToken() const; + ConstDataRange getEdcDerivedToken() const; + ConstDataRange getServerEncryptionToken() const; + const boost::optional<std::vector<VersionedEdgeTokenSet>>& getEdgeTokenSet() const; + +private: + template <class Payload> + boost::optional<std::vector<VersionedEdgeTokenSet>> convertPayloadEdgeTokenSet() { + boost::optional<std::vector<VersionedEdgeTokenSet>> converted; + auto& payload = stdx::get<Payload>(iupayload); + if (payload.getEdgeTokenSet().has_value()) { + auto& etsList = payload.getEdgeTokenSet().value(); + converted = std::vector<VersionedEdgeTokenSet>(etsList.size()); + std::transform(etsList.begin(), etsList.end(), edgeTokenSet->begin(), [](auto& ets) { + return VersionedEdgeTokenSet(ets); + }); + } + return converted; + } + stdx::variant<FLE2InsertUpdatePayload, FLE2InsertUpdatePayloadV2> iupayload; + boost::optional<std::vector<VersionedEdgeTokenSet>> edgeTokenSet; +}; + struct EDCServerPayloadInfo { static ESCDerivedFromDataTokenAndContentionFactorToken getESCToken(ConstDataRange cdr); - FLE2InsertUpdatePayload payload; + bool isRangePayload() const { + return payload.getEdgeTokenSet().has_value(); + } + + // TODO: SERVER-73303 change type to FLE2InsertUpdatePayloadV2 when v2 is enabled by default + VersionedInsertUpdatePayload payload; std::string fieldPathName; std::vector<uint64_t> counts; }; @@ -1266,6 +1327,7 @@ public: static PrfBlock generateTag(const FLE2IndexedEqualityEncryptedValue& indexedValue); static PrfBlock generateTag(const FLEEdgeToken& token, FLECounter count); static std::vector<PrfBlock> generateTags(const FLE2IndexedRangeEncryptedValue& indexedValue); + static std::vector<PrfBlock> generateTags(const EDCServerPayloadInfo& rangePayload); /** * Generate all the EDC tokens diff --git a/src/mongo/db/fle_crud.cpp b/src/mongo/db/fle_crud.cpp index 3a4d0d1fb02..647537523f7 100644 --- a/src/mongo/db/fle_crud.cpp +++ b/src/mongo/db/fle_crud.cpp @@ -202,12 +202,12 @@ void validateInsertUpdatePayloads(const std::vector<EncryptedField>& fields, } for (const auto& field : payload) { - auto fieldPath = field.fieldPathName; + auto& fieldPath = field.fieldPathName; auto expect = pathToKeyIdMap.find(fieldPath); uassert(6726300, str::stream() << "Field '" << fieldPath << "' is unexpectedly encrypted", expect != pathToKeyIdMap.end()); - auto indexKeyId = field.payload.getIndexKeyId(); + auto& indexKeyId = field.payload.getIndexKeyId(); uassert(6726301, str::stream() << "Mismatched keyId for field '" << fieldPath << "' expected " << expect->second << ", found " << indexKeyId, @@ -532,12 +532,14 @@ write_ops::UpdateCommandReply processUpdate(OperationContext* opCtx, namespace { -void processFieldsForInsert(FLEQueryInterface* queryImpl, - const NamespaceString& edcNss, - std::vector<EDCServerPayloadInfo>& serverPayload, - const EncryptedFieldConfig& efc, - int32_t* pStmtId, - bool bypassDocumentValidation) { +// TODO: SERVER-73303 delete when v2 is enabled by default +void processFieldsForInsertV1(FLEQueryInterface* queryImpl, + const NamespaceString& edcNss, + std::vector<EDCServerPayloadInfo>& serverPayload, + const EncryptedFieldConfig& efc, + int32_t* pStmtId, + bool bypassDocumentValidation) { + const NamespaceString nssEsc(edcNss.dbName(), efc.getEscCollection().value()); auto docCount = queryImpl->countDocuments(nssEsc); @@ -626,6 +628,105 @@ void processFieldsForInsert(FLEQueryInterface* queryImpl, } } +void processFieldsForInsertV2(FLEQueryInterface* queryImpl, + const NamespaceString& edcNss, + std::vector<EDCServerPayloadInfo>& serverPayload, + const EncryptedFieldConfig& efc, + int32_t* pStmtId, + bool bypassDocumentValidation) { + + const NamespaceString nssEsc(edcNss.dbName(), efc.getEscCollection().value()); + + auto docCount = queryImpl->countDocuments(nssEsc); + + TxnCollectionReader reader(docCount, queryImpl, nssEsc); + + for (auto& payload : serverPayload) { + + const auto insertTokens = [&](ConstDataRange encryptedTokens, + ConstDataRange escDerivedToken) { + uint64_t count; + + auto escToken = EDCServerPayloadInfo::getESCToken(escDerivedToken); + auto tagToken = + FLETwiceDerivedTokenGenerator::generateESCTwiceDerivedTagToken(escToken); + auto valueToken = + FLETwiceDerivedTokenGenerator::generateESCTwiceDerivedValueToken(escToken); + + auto positions = ESCCollection::emuBinaryV2(reader, tagToken, valueToken); + + if (positions.cpos.has_value()) { + // Either no ESC documents exist yet (cpos == 0), OR new non-anchors + // have been inserted since the last compact/cleanup (cpos > 0). + count = positions.cpos.value() + 1; + } else { + // No new non-anchors since the last compact/cleanup. + // There must be at least one anchor. + uassert(7291902, + "An ESC anchor document is expected but none is found", + !positions.apos.has_value() || positions.apos.value() > 0); + + PrfBlock anchorId; + if (!positions.apos.has_value()) { + anchorId = ESCCollection::generateNullAnchorId(tagToken); + } else { + anchorId = ESCCollection::generateAnchorId(tagToken, positions.apos.value()); + } + + BSONObj anchorDoc = reader.getById(anchorId); + uassert(7291903, "ESC anchor document not found", !anchorDoc.isEmpty()); + + auto escAnchor = + uassertStatusOK(ESCCollection::decryptAnchorDocument(valueToken, anchorDoc)); + count = escAnchor.count + 1; + } + + payload.counts.push_back(count); + + auto escInsertReply = uassertStatusOK(queryImpl->insertDocument( + nssEsc, ESCCollection::generateNonAnchorDocument(tagToken, count), pStmtId, true)); + checkWriteErrors(escInsertReply); + + const NamespaceString nssEcoc(edcNss.dbName(), efc.getEcocCollection().value()); + + // TODO - should we make this a batch of ECOC updates? + auto ecocInsertReply = uassertStatusOK(queryImpl->insertDocument( + nssEcoc, + ECOCCollection::generateDocument(payload.fieldPathName, encryptedTokens), + pStmtId, + false, + bypassDocumentValidation)); + checkWriteErrors(ecocInsertReply); + }; + + payload.counts.clear(); + if (payload.payload.getEdgeTokenSet().has_value()) { + const auto& ets = payload.payload.getEdgeTokenSet().get(); + for (size_t i = 0; i < ets.size(); ++i) { + insertTokens(ets[i].getEncryptedTokens(), ets[i].getEscDerivedToken()); + } + } else { + insertTokens(payload.payload.getEncryptedTokens(), + payload.payload.getEscDerivedToken()); + } + } +} + +void processFieldsForInsert(FLEQueryInterface* queryImpl, + const NamespaceString& edcNss, + std::vector<EDCServerPayloadInfo>& serverPayload, + const EncryptedFieldConfig& efc, + int32_t* pStmtId, + bool bypassDocumentValidation) { + if (gFeatureFlagFLE2ProtocolVersion2.isEnabled(serverGlobalParams.featureCompatibility)) { + processFieldsForInsertV2( + queryImpl, edcNss, serverPayload, efc, pStmtId, bypassDocumentValidation); + } else { + processFieldsForInsertV1( + queryImpl, edcNss, serverPayload, efc, pStmtId, bypassDocumentValidation); + } +} + void processRemovedFieldsHelper(FLEQueryInterface* queryImpl, const EncryptedFieldConfig& efc, const ESCDerivedFromDataTokenAndContentionFactorToken& esc, diff --git a/src/mongo/db/fle_crud_test.cpp b/src/mongo/db/fle_crud_test.cpp index 0f112b4f162..03b8f260354 100644 --- a/src/mongo/db/fle_crud_test.cpp +++ b/src/mongo/db/fle_crud_test.cpp @@ -207,8 +207,6 @@ protected: bool bypassDocumentValidation = false); void doSingleInsert(int id, BSONObj obj, bool bypassDocumentValidation = false); - void doSingleRangeInsert(int id, BSONElement element); - void doSingleInsertWithContention( int id, BSONElement element, int64_t cm, uint64_t cf, EncryptedFieldConfig efc); void doSingleInsertWithContention( @@ -779,6 +777,32 @@ TEST_F(FleCrudTest, InsertOne) { .isEmpty()); } +TEST_F(FleCrudTest, InsertOneV2) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2ProtocolVersion2", true); + auto doc = BSON("encrypted" + << "secret"); + auto element = doc.firstElement(); + + doSingleInsert(1, element, Fle2AlgorithmInt::kEquality); + + assertDocumentCounts(1, 1, 0, 1); + assertECOCDocumentCountByField("encrypted", 1); + + ASSERT_FALSE( + _queryImpl->getById(_escNs, ESCCollection::generateNonAnchorId(getTestESCToken(element), 1)) + .isEmpty()); +} + +TEST_F(FleCrudTest, InsertOneRangeV2) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2ProtocolVersion2", true); + auto doc = BSON("encrypted" << 5); + auto element = doc.firstElement(); + + doSingleInsert(1, element, Fle2AlgorithmInt::kRange); + assertDocumentCounts(1, 5, 0, 5); + assertECOCDocumentCountByField("encrypted", 5); +} + // Insert two documents with same values TEST_F(FleCrudTest, InsertTwoSame) { @@ -797,6 +821,25 @@ TEST_F(FleCrudTest, InsertTwoSame) { .isEmpty()); } +// Insert two documents with same values +TEST_F(FleCrudTest, InsertTwoSameV2) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2ProtocolVersion2", true); + auto doc = BSON("encrypted" + << "secret"); + auto element = doc.firstElement(); + doSingleInsert(1, element, Fle2AlgorithmInt::kEquality); + doSingleInsert(2, element, Fle2AlgorithmInt::kEquality); + + assertDocumentCounts(2, 2, 0, 2); + assertECOCDocumentCountByField("encrypted", 2); + + auto escTagToken = getTestESCToken(element); + ASSERT_FALSE( + _queryImpl->getById(_escNs, ESCCollection::generateNonAnchorId(escTagToken, 1)).isEmpty()); + ASSERT_FALSE( + _queryImpl->getById(_escNs, ESCCollection::generateNonAnchorId(escTagToken, 2)).isEmpty()); +} + // Insert two documents with different values TEST_F(FleCrudTest, InsertTwoDifferent) { @@ -824,6 +867,34 @@ TEST_F(FleCrudTest, InsertTwoDifferent) { .isEmpty()); } +TEST_F(FleCrudTest, InsertTwoDifferentV2) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2ProtocolVersion2", true); + doSingleInsert(1, + BSON("encrypted" + << "secret")); + doSingleInsert(2, + BSON("encrypted" + << "topsecret")); + + assertDocumentCounts(2, 2, 0, 2); + assertECOCDocumentCountByField("encrypted", 2); + + ASSERT_FALSE( + _queryImpl + ->getById(_escNs, + ESCCollection::generateNonAnchorId(getTestESCToken(BSON("encrypted" + << "secret")), + 1)) + .isEmpty()); + ASSERT_FALSE( + _queryImpl + ->getById(_escNs, + ESCCollection::generateNonAnchorId(getTestESCToken(BSON("encrypted" + << "topsecret")), + 1)) + .isEmpty()); +} + // Insert 1 document with 100 fields TEST_F(FleCrudTest, Insert100Fields) { @@ -850,6 +921,33 @@ TEST_F(FleCrudTest, Insert100Fields) { } } +// Insert 1 document with 100 fields +TEST_F(FleCrudTest, Insert100FieldsV2) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2ProtocolVersion2", true); + + uint64_t fieldCount = 100; + ValueGenerator valueGenerator = [](StringData fieldName, uint64_t row) { + return fieldName.toString(); + }; + doSingleWideInsert(1, fieldCount, valueGenerator); + + assertDocumentCounts(1, fieldCount, 0, fieldCount); + + for (uint64_t field = 0; field < fieldCount; field++) { + auto fieldName = fieldNameFromInt(field); + + assertECOCDocumentCountByField(fieldName, 1); + + ASSERT_FALSE( + _queryImpl + ->getById( + _escNs, + ESCCollection::generateNonAnchorId( + getTestESCToken(fieldName, valueGenerator(fieldNameFromInt(field), 0)), 1)) + .isEmpty()); + } +} + // Insert 100 documents each with 20 fields with 7 distinct values per field TEST_F(FleCrudTest, Insert20Fields50Rows) { @@ -886,6 +984,83 @@ TEST_F(FleCrudTest, Insert20Fields50Rows) { } } +// Insert 100 documents each with 20 fields with 7 distinct values per field +TEST_F(FleCrudTest, Insert20Fields50RowsV2) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2ProtocolVersion2", true); + uint64_t fieldCount = 20; + uint64_t rowCount = 50; + + ValueGenerator valueGenerator = [](StringData fieldName, uint64_t row) { + return fieldName.toString() + std::to_string(row % 7); + }; + + + for (uint64_t row = 0; row < rowCount; row++) { + doSingleWideInsert(row, fieldCount, valueGenerator); + } + + assertDocumentCounts(rowCount, rowCount * fieldCount, 0, rowCount * fieldCount); + + for (uint64_t row = 0; row < rowCount; row++) { + for (uint64_t field = 0; field < fieldCount; field++) { + auto fieldName = fieldNameFromInt(field); + + int count = (row / 7) + 1; + + assertECOCDocumentCountByField(fieldName, rowCount); + ASSERT_FALSE( + _queryImpl + ->getById(_escNs, + ESCCollection::generateNonAnchorId( + getTestESCToken(fieldName, + valueGenerator(fieldNameFromInt(field), row)), + count)) + .isEmpty()); + } + } +} + +// Test v1 FLE2InsertUpdatePayload is rejected if v2 is enabled. +// There are 2 places where the payload version compatibility is checked: +// 1. When parsing the InsertUpdatePayload in EDCServerCollection::getEncryptedFieldInfo() +// 2. When transforming the InsertUpdatePayload to the on-disk format in processInsert() +TEST_F(FleCrudTest, InsertV1PayloadAgainstV2Protocol) { + RAIIServerParameterControllerForTest controller("featureFlagFLE2ProtocolVersion2", true); + + std::vector<uint8_t> buf(64); + buf[0] = static_cast<uint8_t>(EncryptedBinDataType::kFLE2InsertUpdatePayload); + + BSONObjBuilder builder; + builder.append("_id", 1); + builder.append("counter", 1); + builder.append("plainText", "sample"); + builder.appendBinData("encrypted", buf.size(), BinDataType::Encrypt, buf.data()); + + BSONObj document = builder.obj(); + ASSERT_THROWS_CODE(EDCServerCollection::getEncryptedFieldInfo(document), DBException, 7291901); + + FLE2InsertUpdatePayloadV2 payload; + PrfBlock dummyToken; + payload.setEdcDerivedToken(dummyToken); + payload.setEscDerivedToken(dummyToken); + payload.setServerDerivedFromDataToken(dummyToken); + payload.setServerEncryptionToken(dummyToken); + payload.setEncryptedTokens(buf); + payload.setValue(buf); + payload.setType(BSONType::String); + + std::vector<EDCServerPayloadInfo> serverPayload(1); + serverPayload.front().fieldPathName = "encrypted"; + serverPayload.front().counts = {1}; + serverPayload.front().payload = std::move(payload); + + auto efc = getTestEncryptedFieldConfig(); + ASSERT_THROWS_CODE( + processInsert(_queryImpl.get(), _edcNs, serverPayload, efc, 0, document, false), + DBException, + 7291907); +} + #define ASSERT_ECC_DOC(assertElement, assertPosition, assertStart, assertEnd) \ { \ auto _eccDoc = getECCDocument(getTestECCToken((assertElement)), assertPosition); \ diff --git a/src/mongo/db/matcher/expression_type.h b/src/mongo/db/matcher/expression_type.h index 201c0799935..b45b0e90774 100644 --- a/src/mongo/db/matcher/expression_type.h +++ b/src/mongo/db/matcher/expression_type.h @@ -418,6 +418,8 @@ public: switch (subTypeByte) { case EncryptedBinDataType::kFLE2EqualityIndexedValue: case EncryptedBinDataType::kFLE2RangeIndexedValue: + case EncryptedBinDataType::kFLE2EqualityIndexedValueV2: + case EncryptedBinDataType::kFLE2RangeIndexedValueV2: case EncryptedBinDataType::kFLE2UnindexedEncryptedValue: { // Verify the type of the encrypted data. if (typeSet().isEmpty()) { diff --git a/src/mongo/db/matcher/expression_type_test.cpp b/src/mongo/db/matcher/expression_type_test.cpp index d89ab9bbc39..6404ff0e32f 100644 --- a/src/mongo/db/matcher/expression_type_test.cpp +++ b/src/mongo/db/matcher/expression_type_test.cpp @@ -387,6 +387,8 @@ TEST(InternalSchemaBinDataFLE2EncryptedTypeTest, MatchesOnlyFLE2ServerSubtypes) if (i == static_cast<uint8_t>(EncryptedBinDataType::kFLE2EqualityIndexedValue) || i == static_cast<uint8_t>(EncryptedBinDataType::kFLE2RangeIndexedValue) || + i == static_cast<uint8_t>(EncryptedBinDataType::kFLE2EqualityIndexedValueV2) || + i == static_cast<uint8_t>(EncryptedBinDataType::kFLE2RangeIndexedValueV2) || i == static_cast<uint8_t>(EncryptedBinDataType::kFLE2UnindexedEncryptedValue)) { ASSERT_TRUE(expr.matchesBSON(BSON("a" << binData))); } else { diff --git a/src/mongo/shell/shell_options.cpp b/src/mongo/shell/shell_options.cpp index 4060b157fd7..eeb06779abc 100644 --- a/src/mongo/shell/shell_options.cpp +++ b/src/mongo/shell/shell_options.cpp @@ -68,6 +68,7 @@ const std::set<std::string> kSetShellParameterAllowlist = { "awsECSInstanceMetadataUrl", "disabledSecureAllocatorDomains", "featureFlagFLE2Range", + "featureFlagFLE2ProtocolVersion2", "newLineAfterPasswordPromptForTest", "ocspClientHttpTimeoutSecs", "ocspEnabled", |