summaryrefslogtreecommitdiff
path: root/src/mongo
diff options
context:
space:
mode:
authorMark Benvenuto <mark.benvenuto@mongodb.com>2023-03-14 16:34:58 -0400
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2023-03-15 14:19:40 +0000
commitb39c58a5326680ed06616359f69e654fdaf3b2ac (patch)
tree1bebab32b135f8ddc0cdd1339803318e9529cbd3 /src/mongo
parent971a3e0082b6469cade1a7e4e706bfe2545cce98 (diff)
downloadmongo-b39c58a5326680ed06616359f69e654fdaf3b2ac.tar.gz
SERVER-74151 Create a new reads tags command for QE tags
Diffstat (limited to 'src/mongo')
-rw-r--r--src/mongo/crypto/fle_crypto.cpp21
-rw-r--r--src/mongo/crypto/fle_crypto.h1
-rw-r--r--src/mongo/db/SConscript2
-rw-r--r--src/mongo/db/commands/SConscript5
-rw-r--r--src/mongo/db/commands/fle2_get_count_info_command.cpp178
-rw-r--r--src/mongo/db/commands/fle2_get_count_info_command.idl105
-rw-r--r--src/mongo/db/fle_crud.cpp101
-rw-r--r--src/mongo/db/fle_crud.h9
-rw-r--r--src/mongo/db/fle_crud_mongod.cpp69
-rw-r--r--src/mongo/db/fle_query_interface_mock.cpp6
-rw-r--r--src/mongo/s/commands/SConscript1
-rw-r--r--src/mongo/s/commands/cluster_fle2_get_count_info_cmd.cpp129
12 files changed, 605 insertions, 22 deletions
diff --git a/src/mongo/crypto/fle_crypto.cpp b/src/mongo/crypto/fle_crypto.cpp
index 259731c1043..77638d1bd52 100644
--- a/src/mongo/crypto/fle_crypto.cpp
+++ b/src/mongo/crypto/fle_crypto.cpp
@@ -450,13 +450,6 @@ void appendTag(PrfBlock block, BSONArrayBuilder* builder) {
builder->appendBinData(block.size(), BinDataType::BinDataGeneral, block.data());
}
-
-std::vector<uint8_t> vectorFromCDR(ConstDataRange cdr) {
- std::vector<uint8_t> buf(cdr.length());
- std::copy(cdr.data(), cdr.data() + cdr.length(), buf.data());
- return buf;
-}
-
template <typename T>
std::vector<uint8_t> toEncryptedVector(EncryptedBinDataType dt, T t) {
BSONObj obj = t.toBSON();
@@ -2401,6 +2394,12 @@ PrfBlock PrfBlockfromCDR(const ConstDataRange& block) {
return ret;
}
+std::vector<uint8_t> FLEUtil::vectorFromCDR(ConstDataRange cdr) {
+ std::vector<uint8_t> buf(cdr.length());
+ std::copy(cdr.data(), cdr.data() + cdr.length(), buf.data());
+ return buf;
+}
+
CollectionsLevel1Token FLELevel1TokenGenerator::generateCollectionsLevel1Token(
FLEIndexKey indexKey) {
return FLEUtil::prf(hmacKey(indexKey.data), kLevel1Collection);
@@ -3520,7 +3519,7 @@ FLE2IndexedEqualityEncryptedValue::FLE2IndexedEqualityEncryptedValue(
count(counter),
bsonType(static_cast<BSONType>(payload.getType())),
indexKeyId(payload.getIndexKeyId()),
- clientEncryptedValue(vectorFromCDR(payload.getValue())) {
+ clientEncryptedValue(FLEUtil::vectorFromCDR(payload.getValue())) {
uassert(6373508,
"Invalid BSON Type in Queryable Encryption InsertUpdatePayload",
isValidBSONType(payload.getType()));
@@ -3811,7 +3810,7 @@ FLE2IndexedEqualityEncryptedValueV2::FLE2IndexedEqualityEncryptedValueV2(
: FLE2IndexedEqualityEncryptedValueV2(
static_cast<BSONType>(payload.getType()),
payload.getIndexKeyId(),
- vectorFromCDR(payload.getValue()),
+ FLEUtil::vectorFromCDR(payload.getValue()),
FLE2TagAndEncryptedMetadataBlock(
counter, payload.getContentionFactor(), std::move(tag))) {}
@@ -4055,7 +4054,7 @@ FLE2IndexedRangeEncryptedValue::FLE2IndexedRangeEncryptedValue(FLE2InsertUpdateP
counters(std::move(countersParam)),
bsonType(static_cast<BSONType>(payload.getType())),
indexKeyId(payload.getIndexKeyId()),
- clientEncryptedValue(vectorFromCDR(payload.getValue())) {
+ clientEncryptedValue(FLEUtil::vectorFromCDR(payload.getValue())) {
uassert(6775312,
"Invalid BSON Type in Queryable Encryption InsertUpdatePayload",
isValidBSONType(payload.getType()));
@@ -4348,7 +4347,7 @@ FLE2IndexedRangeEncryptedValueV2::FLE2IndexedRangeEncryptedValueV2(
const std::vector<uint64_t>& counters)
: bsonType(static_cast<BSONType>(payload.getType())),
indexKeyId(payload.getIndexKeyId()),
- clientEncryptedValue(vectorFromCDR(payload.getValue())) {
+ clientEncryptedValue(FLEUtil::vectorFromCDR(payload.getValue())) {
uassert(7290900,
"Tags and counters parameters must be non-zero and of the same length",
diff --git a/src/mongo/crypto/fle_crypto.h b/src/mongo/crypto/fle_crypto.h
index 1409a4d50fe..64febeeb6cd 100644
--- a/src/mongo/crypto/fle_crypto.h
+++ b/src/mongo/crypto/fle_crypto.h
@@ -1824,6 +1824,7 @@ std::vector<std::string> minCoverDecimal128(Decimal128 lowerBound,
class FLEUtil {
public:
+ static std::vector<uint8_t> vectorFromCDR(ConstDataRange cdr);
static PrfBlock blockToArray(const SHA256Block& block);
/**
diff --git a/src/mongo/db/SConscript b/src/mongo/db/SConscript
index ea24dcdd773..1063f38a494 100644
--- a/src/mongo/db/SConscript
+++ b/src/mongo/db/SConscript
@@ -878,6 +878,7 @@ env.Library(
env.Library(
target='fle_crud',
source=[
+ 'commands/fle2_get_count_info_command.idl',
'fle_crud.cpp',
'query/fle/encrypted_predicate.cpp',
'query/fle/equality_predicate.cpp',
@@ -937,6 +938,7 @@ env.Library(
'$BUILD_DIR/mongo/db/ops/write_ops',
'$BUILD_DIR/mongo/db/repl/storage_interface_impl',
'fle_crud',
+ 'fle_crud_mongod',
],
)
diff --git a/src/mongo/db/commands/SConscript b/src/mongo/db/commands/SConscript
index b3e9df206ad..6adca1738ac 100644
--- a/src/mongo/db/commands/SConscript
+++ b/src/mongo/db/commands/SConscript
@@ -85,9 +85,12 @@ env.Library(
'$BUILD_DIR/mongo/db/session/logical_session_cache_impl',
'$BUILD_DIR/mongo/db/session/logical_session_id',
'$BUILD_DIR/mongo/db/session/logical_session_id_helpers',
+ '$BUILD_DIR/mongo/db/stats/counters',
+ '$BUILD_DIR/mongo/db/transaction/transaction_api',
'$BUILD_DIR/mongo/logv2/logv2_options',
'$BUILD_DIR/mongo/rpc/message',
'$BUILD_DIR/mongo/util/net/http_client',
+ 'server_status_core',
'test_commands_enabled',
],
)
@@ -345,6 +348,7 @@ env.Library(
"explain_cmd.cpp",
"find_and_modify.cpp",
"find_cmd.cpp",
+ 'fle2_get_count_info_command.cpp',
"getmore_cmd.cpp",
"http_client.cpp",
'http_client.idl',
@@ -389,6 +393,7 @@ env.Library(
'$BUILD_DIR/mongo/db/curop_failpoint_helpers',
'$BUILD_DIR/mongo/db/dbcommands_idl',
'$BUILD_DIR/mongo/db/exec/sbe/query_sbe_abt',
+ '$BUILD_DIR/mongo/db/fle_crud',
'$BUILD_DIR/mongo/db/fle_crud_mongod',
'$BUILD_DIR/mongo/db/index_builds_coordinator_interface',
'$BUILD_DIR/mongo/db/index_commands_idl',
diff --git a/src/mongo/db/commands/fle2_get_count_info_command.cpp b/src/mongo/db/commands/fle2_get_count_info_command.cpp
new file mode 100644
index 00000000000..1c30c4bb4cc
--- /dev/null
+++ b/src/mongo/db/commands/fle2_get_count_info_command.cpp
@@ -0,0 +1,178 @@
+/**
+ * Copyright (C) 2023-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/bson/bsonobj.h"
+#include "mongo/crypto/fle_crypto.h"
+#include "mongo/db/auth/authorization_session.h"
+#include "mongo/db/commands.h"
+#include "mongo/db/commands/fle2_get_count_info_command_gen.h"
+#include "mongo/db/fle_crud.h"
+#include "mongo/db/operation_context.h"
+#include "mongo/util/assert_util.h"
+
+
+namespace mongo {
+namespace {
+
+std::vector<std::vector<FLEEdgePrfBlock>> toNestedTokens(
+ const std::vector<mongo::QECountInfoRequestTokenSet>& tagSets) {
+
+ std::vector<std::vector<FLEEdgePrfBlock>> nestedBlocks;
+ nestedBlocks.reserve(tagSets.size());
+
+ for (const auto& tagset : tagSets) {
+ std::vector<FLEEdgePrfBlock> blocks;
+
+ const auto& tags = tagset.getTokens();
+
+ blocks.reserve(tags.size());
+
+ for (auto& tag : tags) {
+ blocks.emplace_back();
+ auto& block = blocks.back();
+
+ block.esc = PrfBlockfromCDR(tag.getESCDerivedFromDataTokenAndContentionFactorToken());
+ block.edc =
+ tag.getEDCDerivedFromDataTokenAndContentionFactorToken().map(PrfBlockfromCDR);
+ }
+
+ nestedBlocks.emplace_back(std::move(blocks));
+ }
+
+ return nestedBlocks;
+}
+
+std::vector<QECountInfoReplyTokenSet> toGetTagRequestTupleSet(
+ const std::vector<std::vector<FLEEdgeCountInfo>>& countInfoSets) {
+
+ std::vector<QECountInfoReplyTokenSet> nestedBlocks;
+ nestedBlocks.reserve(countInfoSets.size());
+
+ for (const auto& countInfos : countInfoSets) {
+ std::vector<QECountInfoReplyTokens> tokens;
+
+ tokens.reserve(countInfos.size());
+
+ for (auto& countInfo : countInfos) {
+ tokens.emplace_back(FLEUtil::vectorFromCDR(countInfo.tagToken.toCDR()),
+ countInfo.count);
+
+ if (countInfo.edc.has_value()) {
+ auto& replyTuple = tokens.back();
+ replyTuple.setEDCDerivedFromDataTokenAndContentionFactorToken(
+ countInfo.edc.value().toCDR());
+ }
+ }
+
+ nestedBlocks.emplace_back(std::move(tokens));
+ }
+
+ return nestedBlocks;
+}
+
+QECountInfosReply getTagsLocal(OperationContext* opCtx,
+ const GetQueryableEncryptionCountInfo& request) {
+
+ uassert(741503,
+ "FeatureFlagFLE2ProtocolVersion2 is not enabled",
+ gFeatureFlagFLE2ProtocolVersion2.isEnabled(serverGlobalParams.featureCompatibility));
+
+ auto nestedTokens = toNestedTokens(request.getTokens());
+
+ auto countInfoSets =
+ getTagsFromStorage(opCtx,
+ request.getNamespace(),
+ nestedTokens,
+ request.getForInsert() ? FLETagQueryInterface::TagQueryType::kInsert
+ : FLETagQueryInterface::TagQueryType::kQuery);
+
+ QECountInfosReply reply;
+ reply.setCounts(toGetTagRequestTupleSet(countInfoSets));
+
+ return reply;
+}
+
+/**
+ * Retrieve a set of tags from ESC. Returns a count suitable for either insert or query.
+ */
+class GetQueryableEncryptionCountInfoCmd final
+ : public TypedCommand<GetQueryableEncryptionCountInfoCmd> {
+public:
+ using Request = GetQueryableEncryptionCountInfo;
+ using Reply = GetQueryableEncryptionCountInfo::Reply;
+
+ class Invocation final : public InvocationBase {
+ public:
+ using InvocationBase::InvocationBase;
+
+ Reply typedRun(OperationContext* opCtx) {
+ return Reply(getTagsLocal(opCtx, request()));
+ }
+
+ private:
+ bool supportsWriteConcern() const final {
+ return false;
+ }
+
+ ReadConcernSupportResult supportsReadConcern(repl::ReadConcernLevel level,
+ bool isImplicitDefault) const final {
+ return ReadConcernSupportResult::allSupportedAndDefaultPermitted();
+ }
+
+ void doCheckAuthorization(OperationContext* opCtx) const final {
+ auto* as = AuthorizationSession::get(opCtx->getClient());
+ uassert(ErrorCodes::Unauthorized,
+ "Not authorized to read tags",
+ as->isAuthorizedForActionsOnResource(ResourcePattern::forClusterResource(),
+ ActionType::internal));
+ }
+
+ NamespaceString ns() const final {
+ return request().getNamespace();
+ }
+ };
+
+ AllowedOnSecondary secondaryAllowed(ServiceContext*) const final {
+ // Restrict to primary for now to allow future possibilities of caching on primary
+ return BasicCommand::AllowedOnSecondary::kNever;
+ }
+
+ bool adminOnly() const final {
+ return false;
+ }
+
+ bool allowedInTransactions() const final {
+ return true;
+ }
+} getQueryableEncryptionCountInfoCmd;
+
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/db/commands/fle2_get_count_info_command.idl b/src/mongo/db/commands/fle2_get_count_info_command.idl
new file mode 100644
index 00000000000..93a1a6df499
--- /dev/null
+++ b/src/mongo/db/commands/fle2_get_count_info_command.idl
@@ -0,0 +1,105 @@
+# Copyright (C) 2023-present MongoDB, Inc.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the Server Side Public License, version 1,
+# as published by MongoDB, Inc.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# Server Side Public License for more details.
+#
+# You should have received a copy of the Server Side Public License
+# along with this program. If not, see
+# <http://www.mongodb.com/licensing/server-side-public-license>.
+#
+# As a special exception, the copyright holders give permission to link the
+# code of portions of this program with the OpenSSL library under certain
+# conditions as described in each individual source file and distribute
+# linked combinations including the program with the OpenSSL library. You
+# must comply with the Server Side Public License in all respects for
+# all of the code used other than as permitted herein. If you modify file(s)
+# with this exception, you may extend this exception to your version of the
+# file(s), but you are not obligated to do so. If you do not wish to do so,
+# delete this exception statement from your version. If you delete this
+# exception statement from all source files in the program, then also delete
+# it in the license file.
+#
+
+global:
+ cpp_namespace: "mongo"
+
+imports:
+ - "mongo/db/basic_types.idl"
+
+structs:
+ QECountInfoReplyTokens:
+ description: "A tokens of ESC information"
+ strict: true
+ fields:
+ t:
+ description: "ESCTwiceDerivedTagToken"
+ type: bindata_generic
+ cpp_name: ESCTwiceDerivedTagToken
+ c:
+ description: "ESC count"
+ type: long
+ cpp_name: count
+ td:
+ description: "EDCDerivedFromDataTokenAndContentionFactorToken"
+ type: bindata_generic
+ cpp_name: EDCDerivedFromDataTokenAndContentionFactorToken
+ optional: true
+
+ QECountInfoReplyTokenSet:
+ description: "Array of tokens sets"
+ strict: true
+ fields:
+ tokens: array<QECountInfoReplyTokens>
+
+
+ QECountInfoRequestTokens:
+ description: "A ESC token to lookup in ESC"
+ strict: true
+ fields:
+ d:
+ description: "EDCDerivedFromDataTokenAndContentionFactorToken"
+ type: bindata_generic
+ cpp_name: EDCDerivedFromDataTokenAndContentionFactorToken
+ optional: true
+ s:
+ description: "ESCDerivedFromDataTokenAndContentionFactorToken"
+ type: bindata_generic
+ cpp_name: ESCDerivedFromDataTokenAndContentionFactorToken
+
+ QECountInfoRequestTokenSet:
+ description: An array of tag sets to lookup
+ strict: true
+ fields:
+ tokens: array<QECountInfoRequestTokens>
+
+
+ QECountInfosReply:
+ description: "Reply from the {getQueryableEncryptionCountInfo: ...} command"
+ # MongoS/Txn add fields to the reply we want to ignore
+ strict: false
+ is_command_reply: true
+ fields:
+ counts: array<QECountInfoReplyTokenSet>
+
+
+commands:
+ getQueryableEncryptionCountInfo:
+ description: "Parser for the 'getQueryableEncryptionCountInfo' command"
+ command_name: getQueryableEncryptionCountInfo
+ api_version: ""
+ namespace: concatenate_with_db
+ strict: true
+ reply_type: QECountInfosReply
+ fields:
+ tokens:
+ description: "Array of tokens to fetch"
+ type: array<QECountInfoRequestTokenSet>
+ forInsert:
+ description: Whether to return a count for insert or query
+ type: bool
diff --git a/src/mongo/db/fle_crud.cpp b/src/mongo/db/fle_crud.cpp
index 4c378bf9308..c44c8b801c4 100644
--- a/src/mongo/db/fle_crud.cpp
+++ b/src/mongo/db/fle_crud.cpp
@@ -40,6 +40,7 @@
#include "mongo/crypto/encryption_fields_gen.h"
#include "mongo/crypto/fle_crypto.h"
#include "mongo/db/auth/authorization_session.h"
+#include "mongo/db/commands/fle2_get_count_info_command_gen.h"
#include "mongo/db/dbdirectclient.h"
#include "mongo/db/fle_crud.h"
#include "mongo/db/namespace_string.h"
@@ -150,6 +151,73 @@ boost::optional<BSONObj> mergeLetAndCVariables(const boost::optional<BSONObj>& l
}
return c;
}
+
+template <FLETokenType TokenT>
+FLEToken<TokenT> FLETokenFromCDR(ConstDataRange cdr) {
+ auto block = PrfBlockfromCDR(cdr);
+ return FLEToken<TokenT>(block);
+}
+
+std::vector<QECountInfoRequestTokenSet> toTagSets(
+ const std::vector<std::vector<FLEEdgePrfBlock>>& blockSets) {
+
+ std::vector<QECountInfoRequestTokenSet> nestedBlocks;
+ nestedBlocks.reserve(blockSets.size());
+
+ for (const auto& tags : blockSets) {
+ std::vector<QECountInfoRequestTokens> tagsets;
+
+ tagsets.reserve(tags.size());
+
+ for (auto& tag : tags) {
+ tagsets.emplace_back(FLEUtil::vectorFromCDR(tag.esc));
+ auto& tokenSet = tagsets.back();
+
+ if (tag.edc.has_value()) {
+ tokenSet.setEDCDerivedFromDataTokenAndContentionFactorToken(
+ ConstDataRange(tag.edc.value()));
+ }
+ }
+
+ nestedBlocks.emplace_back();
+ nestedBlocks.back().setTokens(std::move(tagsets));
+ }
+
+ return nestedBlocks;
+}
+
+std::vector<std::vector<FLEEdgeCountInfo>> toEdgeCounts(
+ const std::vector<QECountInfoReplyTokenSet>& tupleSet) {
+
+ std::vector<std::vector<FLEEdgeCountInfo>> nestedBlocks;
+ nestedBlocks.reserve(tupleSet.size());
+
+ for (const auto& tuple : tupleSet) {
+ std::vector<FLEEdgeCountInfo> blocks;
+
+ const auto& tuples = tuple.getTokens();
+
+ blocks.reserve(tuples.size());
+
+ for (auto& tuple : tuples) {
+ blocks.emplace_back(tuple.getCount(),
+ FLETokenFromCDR<FLETokenType::ESCTwiceDerivedTagToken>(
+ tuple.getESCTwiceDerivedTagToken()));
+ auto& p = blocks.back();
+
+ if (tuple.getEDCDerivedFromDataTokenAndContentionFactorToken().has_value()) {
+ p.edc =
+ FLETokenFromCDR<FLETokenType::EDCDerivedFromDataTokenAndContentionFactorToken>(
+ tuple.getEDCDerivedFromDataTokenAndContentionFactorToken().value());
+ }
+ }
+
+ nestedBlocks.emplace_back(std::move(blocks));
+ }
+
+ return nestedBlocks;
+}
+
} // namespace
std::shared_ptr<txn_api::SyncTransactionWithRetries> getTransactionWithRetriesForMongoS(
@@ -638,12 +706,12 @@ void processFieldsForInsertV2(FLEQueryInterface* queryImpl,
int32_t* pStmtId,
bool bypassDocumentValidation) {
- const NamespaceString nssEsc(edcNss.dbName(), efc.getEscCollection().value());
-
if (serverPayload.empty()) {
return;
}
+ const NamespaceString nssEsc(edcNss.dbName(), efc.getEscCollection().value());
+
uint32_t totalTokens = 0;
std::vector<std::vector<FLEEdgePrfBlock>> tokensSets;
@@ -678,12 +746,20 @@ void processFieldsForInsertV2(FLEQueryInterface* queryImpl,
auto countInfoSets =
queryImpl->getTags(nssEsc, tokensSets, FLETagQueryInterface::TagQueryType::kInsert);
+ uassert(7415101,
+ "Mismatch in the number of expected tokens",
+ countInfoSets.size() == serverPayload.size());
+
std::vector<BSONObj> escDocuments;
escDocuments.reserve(totalTokens);
for (size_t i = 0; i < countInfoSets.size(); i++) {
auto& countInfos = countInfoSets[i];
+ uassert(7415104,
+ "Mismatch in the number of expected counts for a token",
+ countInfos.size() == tokensSets[i].size());
+
for (auto const& countInfo : countInfos) {
serverPayload[i].counts.push_back(countInfo.count);
@@ -1737,13 +1813,24 @@ std::vector<std::vector<FLEEdgeCountInfo>> FLEQueryInterfaceImpl::getTags(
const std::vector<std::vector<FLEEdgePrfBlock>>& tokensSets,
FLEQueryInterface::TagQueryType type) {
- auto docCount = countDocuments(nss);
+ GetQueryableEncryptionCountInfo getCountsCmd(nss);
- TxnCollectionReader reader(docCount, this, nss);
+ const auto tenantId = nss.tenantId();
+ if (tenantId && gMultitenancySupport) {
+ getCountsCmd.setDollarTenant(tenantId);
+ }
- return ESCCollection::getTags(reader, tokensSets, type);
-}
+ getCountsCmd.setTokens(toTagSets(tokensSets));
+ getCountsCmd.setForInsert(type == FLEQueryInterface::TagQueryType::kInsert);
+ auto response = _txnClient.runCommand(nss.db(), getCountsCmd.toBSON({})).get();
+ auto status = getStatusFromWriteCommandReply(response);
+ uassertStatusOK(status);
+
+ auto reply = QECountInfosReply::parse(IDLParserContext("reply"), response);
+
+ return toEdgeCounts(reply.getCounts());
+}
StatusWith<write_ops::InsertCommandReply> FLEQueryInterfaceImpl::insertDocuments(
const NamespaceString& nss,
@@ -1753,6 +1840,7 @@ StatusWith<write_ops::InsertCommandReply> FLEQueryInterfaceImpl::insertDocuments
bool bypassDocumentValidation) {
write_ops::InsertCommandRequest insertRequest(nss);
auto documentCount = objs.size();
+ dassert(documentCount > 0);
insertRequest.setDocuments(std::move(objs));
const auto tenantId = nss.tenantId();
@@ -1993,4 +2081,5 @@ std::unique_ptr<Pipeline, PipelineDeleter> processFLEPipelineS(
return fle::processPipeline(
opCtx, nss, encryptInfo, std::move(toRewrite), &getTransactionWithRetriesForMongoS);
}
+
} // namespace mongo
diff --git a/src/mongo/db/fle_crud.h b/src/mongo/db/fle_crud.h
index 45973a0e5c0..e954ff7c90c 100644
--- a/src/mongo/db/fle_crud.h
+++ b/src/mongo/db/fle_crud.h
@@ -490,4 +490,13 @@ write_ops::UpdateCommandReply processUpdate(OperationContext* opCtx,
void validateInsertUpdatePayloads(const std::vector<EncryptedField>& fields,
const std::vector<EDCServerPayloadInfo>& payload);
+/**
+ * Get the tags from local storage.
+ */
+std::vector<std::vector<FLEEdgeCountInfo>> getTagsFromStorage(
+ OperationContext* opCtx,
+ const NamespaceStringOrUUID& nsOrUUID,
+ const std::vector<std::vector<FLEEdgePrfBlock>>& escDerivedFromDataTokens,
+ FLETagQueryInterface::TagQueryType type);
+
} // namespace mongo
diff --git a/src/mongo/db/fle_crud_mongod.cpp b/src/mongo/db/fle_crud_mongod.cpp
index 2e18d831ec6..df9c1a2788b 100644
--- a/src/mongo/db/fle_crud_mongod.cpp
+++ b/src/mongo/db/fle_crud_mongod.cpp
@@ -46,6 +46,8 @@
#include "mongo/db/query/find_command_gen.h"
#include "mongo/db/query/fle/server_rewrite.h"
#include "mongo/db/repl/repl_client_info.h"
+#include "mongo/db/repl/replication_coordinator.h"
+#include "mongo/db/repl/storage_interface.h"
#include "mongo/db/session/session.h"
#include "mongo/db/session/session_catalog.h"
#include "mongo/db/session/session_catalog_mongod.h"
@@ -144,6 +146,53 @@ private:
bool _yielded = false;
};
+void toBinData(StringData field, PrfBlock block, BSONObjBuilder* builder) {
+ builder->appendBinData(field, block.size(), BinDataType::BinDataGeneral, block.data());
+}
+
+/**
+ * Read from ESC via StorageInterface
+ */
+class StorageInterfaceCollectionReader : public FLEStateCollectionReader {
+public:
+ StorageInterfaceCollectionReader(OperationContext* opCtx,
+ uint64_t count,
+ const NamespaceStringOrUUID& nsOrUUID,
+ repl::StorageInterface* storageInterface)
+ : _opCtx(opCtx), _count(count), _nssOrUUID(nsOrUUID), _storageInterface(storageInterface) {}
+
+ uint64_t getDocumentCount() const override {
+ return _count;
+ }
+
+ BSONObj getById(PrfBlock block) const override {
+
+ // Check for interruption so we can be killed
+ _opCtx->checkForInterrupt();
+
+ BSONObjBuilder builder;
+ toBinData("_id", block, &builder);
+ auto id = builder.obj();
+
+ auto swDoc = _storageInterface->findById(_opCtx, _nssOrUUID, id.firstElement());
+
+ if (swDoc.getStatus() == ErrorCodes::NoSuchKey ||
+ swDoc.getStatus() == ErrorCodes::NamespaceNotFound) {
+ return BSONObj();
+ }
+
+ uassertStatusOK(swDoc);
+
+ return swDoc.getValue();
+ }
+
+private:
+ OperationContext* _opCtx;
+ uint64_t _count;
+ const NamespaceStringOrUUID& _nssOrUUID;
+ repl::StorageInterface* _storageInterface;
+};
+
} // namespace
std::shared_ptr<txn_api::SyncTransactionWithRetries> getTransactionWithRetriesForMongoD(
@@ -290,4 +339,24 @@ processFLEFindAndModifyExplainMongod(OperationContext* opCtx,
opCtx, request, &getTransactionWithRetriesForMongoD, processFindAndModifyExplain));
}
+std::vector<std::vector<FLEEdgeCountInfo>> getTagsFromStorage(
+ OperationContext* opCtx,
+ const NamespaceStringOrUUID& nsOrUUID,
+ const std::vector<std::vector<FLEEdgePrfBlock>>& escDerivedFromDataTokens,
+ FLETagQueryInterface::TagQueryType type) {
+
+ auto storageInterface = repl::StorageInterface::get(opCtx);
+
+ auto swDocCount = storageInterface->getCollectionCount(opCtx, nsOrUUID);
+
+ uint64_t docCount = 0;
+ if (swDocCount.getStatus() != ErrorCodes::NamespaceNotFound) {
+ docCount = uassertStatusOK(swDocCount);
+ }
+
+ StorageInterfaceCollectionReader reader(opCtx, docCount, nsOrUUID, storageInterface);
+
+ return ESCCollection::getTags(reader, escDerivedFromDataTokens, type);
+}
+
} // namespace mongo
diff --git a/src/mongo/db/fle_query_interface_mock.cpp b/src/mongo/db/fle_query_interface_mock.cpp
index 51284c4257b..e54f18b7b15 100644
--- a/src/mongo/db/fle_query_interface_mock.cpp
+++ b/src/mongo/db/fle_query_interface_mock.cpp
@@ -58,11 +58,7 @@ std::vector<std::vector<FLEEdgeCountInfo>> FLEQueryInterfaceMock::getTags(
const std::vector<std::vector<FLEEdgePrfBlock>>& tokensSets,
FLETagQueryInterface::TagQueryType type) {
- auto docCount = countDocuments(nss);
-
- TxnCollectionReader reader(docCount, this, nss);
-
- return ESCCollection::getTags(reader, tokensSets, type);
+ return getTagsFromStorage(_opCtx, nss, tokensSets, type);
}
StatusWith<write_ops::InsertCommandReply> FLEQueryInterfaceMock::insertDocuments(
diff --git a/src/mongo/s/commands/SConscript b/src/mongo/s/commands/SConscript
index 1ce5225789f..3e3272329a3 100644
--- a/src/mongo/s/commands/SConscript
+++ b/src/mongo/s/commands/SConscript
@@ -59,6 +59,7 @@ env.Library(
'cluster_find_and_modify_cmd.cpp',
'cluster_find_cmd_s.cpp',
'cluster_fle2_compact_cmd.cpp',
+ 'cluster_fle2_get_count_info_cmd.cpp',
'cluster_fsync_cmd.cpp',
'cluster_ftdc_commands.cpp',
'cluster_get_cluster_parameter_cmd.cpp',
diff --git a/src/mongo/s/commands/cluster_fle2_get_count_info_cmd.cpp b/src/mongo/s/commands/cluster_fle2_get_count_info_cmd.cpp
new file mode 100644
index 00000000000..8d69a923af5
--- /dev/null
+++ b/src/mongo/s/commands/cluster_fle2_get_count_info_cmd.cpp
@@ -0,0 +1,129 @@
+/**
+ * Copyright (C) 2023-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/auth/authorization_session.h"
+#include "mongo/db/commands.h"
+#include "mongo/db/commands/fle2_get_count_info_command_gen.h"
+#include "mongo/s/cluster_commands_helpers.h"
+#include "mongo/s/grid.h"
+
+namespace mongo {
+namespace {
+
+/**
+ * Retrieve a set of tags from ESC. Returns a count suitable for either insert or query.
+ *
+ * Always routes to primary shard for a database because ESC is pinned to primary shard and ESC is
+ * not sharded.
+ */
+class ClusterGetQueryableEncryptionCountInfoCmd final
+ : public TypedCommand<ClusterGetQueryableEncryptionCountInfoCmd> {
+public:
+ using Request = GetQueryableEncryptionCountInfo;
+ using Reply = GetQueryableEncryptionCountInfo::Reply;
+
+ AllowedOnSecondary secondaryAllowed(ServiceContext*) const final {
+ return BasicCommand::AllowedOnSecondary::kAlways;
+ }
+
+ bool adminOnly() const final {
+ return false;
+ }
+
+ bool allowedInTransactions() const final {
+ return true;
+ }
+
+ class Invocation final : public InvocationBase {
+ public:
+ using InvocationBase::InvocationBase;
+
+ Reply typedRun(OperationContext* opCtx);
+
+ private:
+ bool supportsWriteConcern() const final {
+ return false;
+ }
+
+ ReadConcernSupportResult supportsReadConcern(repl::ReadConcernLevel level,
+ bool isImplicitDefault) const final {
+ return ReadConcernSupportResult::allSupportedAndDefaultPermitted();
+ }
+
+ void doCheckAuthorization(OperationContext* opCtx) const final {
+ auto* as = AuthorizationSession::get(opCtx->getClient());
+ uassert(ErrorCodes::Unauthorized,
+ "Not authorized to read tags",
+ as->isAuthorizedForActionsOnResource(ResourcePattern::forClusterResource(),
+ ActionType::internal));
+ }
+
+ NamespaceString ns() const final {
+ return request().getNamespace();
+ }
+ };
+
+} ClusterGetQueryableEncryptionCountInfoCmd;
+
+ClusterGetQueryableEncryptionCountInfoCmd::Reply
+ClusterGetQueryableEncryptionCountInfoCmd::Invocation::typedRun(OperationContext* opCtx) {
+
+ uassert(741502,
+ "FeatureFlagFLE2ProtocolVersion2 is not enabled",
+ gFeatureFlagFLE2ProtocolVersion2.isEnabled(serverGlobalParams.featureCompatibility));
+
+ auto nss = request().getNamespace();
+ const auto dbInfo =
+ uassertStatusOK(Grid::get(opCtx)->catalogCache()->getDatabase(opCtx, nss.db()));
+
+ auto response = uassertStatusOK(
+ executeCommandAgainstDatabasePrimary(
+ opCtx,
+ nss.db(),
+ dbInfo,
+ applyReadWriteConcern(
+ opCtx,
+ this,
+ CommandHelpers::filterCommandRequestForPassthrough(unparsedRequest().body)),
+ ReadPreferenceSetting(ReadPreference::PrimaryOnly),
+ Shard::RetryPolicy::kIdempotent)
+ .swResponse);
+
+ BSONObjBuilder result;
+ CommandHelpers::filterCommandReplyForPassthrough(response.data, &result);
+
+ auto reply = result.obj();
+ uassertStatusOK(getStatusFromCommandResult(reply));
+ return Reply::parse(IDLParserContext{Request::kCommandName}, reply.removeField("ok"_sd));
+}
+
+} // namespace
+} // namespace mongo