Diffstat (limited to 'src/mongo/db/s/query_analysis_writer_test.cpp')
-rw-r--r--  src/mongo/db/s/query_analysis_writer_test.cpp | 1281
1 files changed, 1281 insertions, 0 deletions
diff --git a/src/mongo/db/s/query_analysis_writer_test.cpp b/src/mongo/db/s/query_analysis_writer_test.cpp
new file mode 100644
index 00000000000..b17f39c82df
--- /dev/null
+++ b/src/mongo/db/s/query_analysis_writer_test.cpp
@@ -0,0 +1,1281 @@
+/**
+ * Copyright (C) 2022-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/s/query_analysis_writer.h"
+
+#include "mongo/bson/unordered_fields_bsonobj_comparator.h"
+#include "mongo/db/db_raii.h"
+#include "mongo/db/dbdirectclient.h"
+#include "mongo/db/s/shard_server_test_fixture.h"
+#include "mongo/db/update/document_diff_calculator.h"
+#include "mongo/idl/server_parameter_test_util.h"
+#include "mongo/logv2/log.h"
+#include "mongo/s/analyze_shard_key_documents_gen.h"
+#include "mongo/unittest/death_test.h"
+#include "mongo/util/fail_point.h"
+
+#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest
+
+namespace mongo {
+namespace analyze_shard_key {
+namespace {
+
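+// The tests below exercise QueryAnalysisWriter::Buffer, the in-memory staging area
+// that accumulates sampled query and diff documents until they are flushed. add()
+// appends a document and updates the tracked count and size; truncate(index, size)
+// discards the documents at positions 'index' onwards, where 'size' is the total
+// size of the discarded documents.
+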
+TEST(QueryAnalysisWriterBufferTest, AddBasic) {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc0 = BSON("a" << 0);
+ buffer.add(doc0);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc0.objsize());
+
+ auto doc1 = BSON("a" << BSON_ARRAY(0 << 1 << 2));
+ buffer.add(doc1);
+ ASSERT_EQ(buffer.getCount(), 2);
+ ASSERT_EQ(buffer.getSize(), doc0.objsize() + doc1.objsize());
+
+ ASSERT_BSONOBJ_EQ(buffer.at(0), doc0);
+ ASSERT_BSONOBJ_EQ(buffer.at(1), doc1);
+}
+
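+// add() is expected to reject a document that is too large to insert (at or above
+// BSONObjMaxUserSize), leaving the buffer's count and size unchanged.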
+TEST(QueryAnalysisWriterBufferTest, AddTooLarge) {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON(std::string(BSONObjMaxUserSize, 'a') << 1);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 0);
+ ASSERT_EQ(buffer.getSize(), 0);
+}
+
+TEST(QueryAnalysisWriterBufferTest, TruncateBasic) {
+ auto testTruncateCommon = [](int oldCount, int newCount) {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ std::vector<BSONObj> docs;
+ for (auto i = 0; i < oldCount; i++) {
+ docs.push_back(BSON("a" << i));
+ }
+ // The documents have the same size.
+ auto docSize = docs.back().objsize();
+
+ for (const auto& doc : docs) {
+ buffer.add(doc);
+ }
+ ASSERT_EQ(buffer.getCount(), oldCount);
+ ASSERT_EQ(buffer.getSize(), oldCount * docSize);
+
+ buffer.truncate(newCount, (oldCount - newCount) * docSize);
+ ASSERT_EQ(buffer.getCount(), newCount);
+ ASSERT_EQ(buffer.getSize(), newCount * docSize);
+ for (auto i = 0; i < newCount; i++) {
+ ASSERT_BSONOBJ_EQ(buffer.at(i), docs[i]);
+ }
+ };
+
+ testTruncateCommon(10 /* oldCount */, 6 /* newCount */);
+ testTruncateCommon(10 /* oldCount */, 0 /* newCount */); // Truncate all.
+ testTruncateCommon(10 /* oldCount */, 9 /* newCount */); // Truncate one.
+}
+
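+// truncate(index, size) requires 0 <= index < getCount() and 0 < size <= getSize();
+// each death test below violates exactly one of these invariants.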
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidIndex_Negative, "invariant") {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON("a" << 0);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc.objsize());
+
+ buffer.truncate(-1, doc.objsize());
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidIndex_Positive, "invariant") {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON("a" << 0);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc.objsize());
+
+ buffer.truncate(2, doc.objsize());
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Negative, "invariant") {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON("a" << 0);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc.objsize());
+
+ buffer.truncate(0, -doc.objsize());
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Zero, "invariant") {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON("a" << 0);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc.objsize());
+
+ buffer.truncate(0, 0);
+}
+
+DEATH_TEST(QueryAnalysisWriterBufferTest, TruncateInvalidSize_Positive, "invariant") {
+ auto buffer = QueryAnalysisWriter::Buffer();
+
+ auto doc = BSON("a" << 0);
+ buffer.add(doc);
+ ASSERT_EQ(buffer.getCount(), 1);
+ ASSERT_EQ(buffer.getSize(), doc.objsize());
+
+ buffer.truncate(0, doc.objsize() * 2);
+}
+
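+// Compares two BSONObjs ignoring field order; used for diff documents, whose field
+// order is presumably not guaranteed to match the expected diff exactly.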
+void assertBsonObjEqualUnordered(const BSONObj& lhs, const BSONObj& rhs) {
+ UnorderedFieldsBSONObjComparator comparator;
+ ASSERT_EQ(comparator.compare(lhs, rhs), 0);
+}
+
+struct QueryAnalysisWriterTest : public ShardServerTestFixture {
+public:
+ void setUp() override {
+ ShardServerTestFixture::setUp();
+ QueryAnalysisWriter::get(operationContext()).onStartup();
+
+ DBDirectClient client(operationContext());
+ client.createCollection(nss0.toString());
+ client.createCollection(nss1.toString());
+ }
+
+ void tearDown() override {
+ QueryAnalysisWriter::get(operationContext()).onShutdown();
+ ShardServerTestFixture::tearDown();
+ }
+
+protected:
+ UUID getCollectionUUID(const NamespaceString& nss) const {
+ auto collectionCatalog = CollectionCatalog::get(operationContext());
+ return *collectionCatalog->lookupUUIDByNSS(operationContext(), nss);
+ }
+
+ BSONObj makeNonEmptyFilter() {
+ return BSON("_id" << UUID::gen());
+ }
+
+ BSONObj makeNonEmptyCollation() {
+ int strength = rand() % 5 + 1;
+ return BSON("locale"
+ << "en_US"
+ << "strength" << strength);
+ }
+
+ /*
+ * Makes an UpdateCommandRequest for the collection 'nss' such that the command contains
+ * 'numOps' updates and the ones whose indices are in 'markForSampling' are marked for sampling.
+ * Returns the UpdateCommandRequest and the map storing the expected sampled
+ * UpdateCommandRequests by sample id, if any.
+ */
+ std::pair<write_ops::UpdateCommandRequest, std::map<UUID, write_ops::UpdateCommandRequest>>
+ makeUpdateCommandRequest(const NamespaceString& nss,
+ int numOps,
+ std::set<int> markForSampling,
+ std::string filterFieldName = "a") {
+ write_ops::UpdateCommandRequest originalCmd(nss);
+ std::vector<write_ops::UpdateOpEntry> updateOps; // populated below.
+ originalCmd.setLet(let);
+ originalCmd.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation);
+
+ std::map<UUID, write_ops::UpdateCommandRequest> expectedSampledCmds;
+
+ for (auto i = 0; i < numOps; i++) {
+ auto updateOp = write_ops::UpdateOpEntry(
+ BSON(filterFieldName << i),
+ write_ops::UpdateModification(BSON("$set" << BSON("b.$[element]" << i))));
+ updateOp.setC(BSON("x" << i));
+ updateOp.setArrayFilters(std::vector<BSONObj>{BSON("element" << BSON("$gt" << i))});
+ updateOp.setMulti(_getRandomBool());
+ updateOp.setUpsert(_getRandomBool());
+ updateOp.setUpsertSupplied(_getRandomBool());
+ updateOp.setCollation(makeNonEmptyCollation());
+
+ if (markForSampling.find(i) != markForSampling.end()) {
+ auto sampleId = UUID::gen();
+ updateOp.setSampleId(sampleId);
+
+ write_ops::UpdateCommandRequest expectedSampledCmd = originalCmd;
+ expectedSampledCmd.setUpdates({updateOp});
+ expectedSampledCmd.getWriteCommandRequestBase().setEncryptionInformation(
+ boost::none);
+ expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd));
+ }
+ updateOps.push_back(updateOp);
+ }
+ originalCmd.setUpdates(updateOps);
+
+ return {originalCmd, expectedSampledCmds};
+ }
+
+ /*
+ * Makes a DeleteCommandRequest for the collection 'nss' such that the command contains
+ * 'numOps' deletes and the ones whose indices are in 'markForSampling' are marked for sampling.
+ * Returns the DeleteCommandRequest and the map storing the expected sampled
+ * DeleteCommandRequests by sample id, if any.
+ */
+ std::pair<write_ops::DeleteCommandRequest, std::map<UUID, write_ops::DeleteCommandRequest>>
+ makeDeleteCommandRequest(const NamespaceString& nss,
+ int numOps,
+ std::set<int> markForSampling,
+ std::string filterFieldName = "a") {
+ write_ops::DeleteCommandRequest originalCmd(nss);
+ std::vector<write_ops::DeleteOpEntry> deleteOps; // populated and set below.
+ originalCmd.setLet(let);
+ originalCmd.getWriteCommandRequestBase().setEncryptionInformation(encryptionInformation);
+
+ std::map<UUID, write_ops::DeleteCommandRequest> expectedSampledCmds;
+
+ for (auto i = 0; i < numOps; i++) {
+ auto deleteOp =
+ write_ops::DeleteOpEntry(BSON(filterFieldName << i), _getRandomBool() /* multi */);
+ deleteOp.setCollation(makeNonEmptyCollation());
+
+ if (markForSampling.find(i) != markForSampling.end()) {
+ auto sampleId = UUID::gen();
+ deleteOp.setSampleId(sampleId);
+
+ write_ops::DeleteCommandRequest expectedSampledCmd = originalCmd;
+ expectedSampledCmd.setDeletes({deleteOp});
+ expectedSampledCmd.getWriteCommandRequestBase().setEncryptionInformation(
+ boost::none);
+ expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd));
+ }
+ deleteOps.push_back(deleteOp);
+ }
+ originalCmd.setDeletes(deleteOps);
+
+ return {originalCmd, expectedSampledCmds};
+ }
+
+ /*
+ * Makes a FindAndModifyCommandRequest for the collection 'nss'. The findAndModify is an update
+ * if 'isUpdate' is true, and a remove otherwise. If 'markForSampling' is true, it is marked for
+ * sampling. Returns the FindAndModifyCommandRequest and the map storing the expected sampled
+ * FindAndModifyCommandRequests by sample id, if any.
+ */
+ std::pair<write_ops::FindAndModifyCommandRequest,
+ std::map<UUID, write_ops::FindAndModifyCommandRequest>>
+ makeFindAndModifyCommandRequest(const NamespaceString& nss,
+ bool isUpdate,
+ bool markForSampling,
+ std::string filterFieldName = "a") {
+ write_ops::FindAndModifyCommandRequest originalCmd(nss);
+ originalCmd.setQuery(BSON(filterFieldName << 0));
+ originalCmd.setUpdate(
+ write_ops::UpdateModification(BSON("$set" << BSON("b.$[element]" << 0))));
+ originalCmd.setArrayFilters(std::vector<BSONObj>{BSON("element" << BSON("$gt" << 10))});
+ originalCmd.setSort(BSON("_id" << 1));
+ if (isUpdate) {
+ originalCmd.setUpsert(_getRandomBool());
+ originalCmd.setNew(_getRandomBool());
+ }
+ originalCmd.setCollation(makeNonEmptyCollation());
+ originalCmd.setLet(let);
+ originalCmd.setEncryptionInformation(encryptionInformation);
+
+ std::map<UUID, write_ops::FindAndModifyCommandRequest> expectedSampledCmds;
+ if (markForSampling) {
+ auto sampleId = UUID::gen();
+ originalCmd.setSampleId(sampleId);
+
+ auto expectedSampledCmd = originalCmd;
+ expectedSampledCmd.setEncryptionInformation(boost::none);
+ expectedSampledCmds.emplace(sampleId, std::move(expectedSampledCmd));
+ }
+
+ return {originalCmd, expectedSampledCmds};
+ }
+
+ void deleteSampledQueryDocuments() const {
+ DBDirectClient client(operationContext());
+ client.remove(NamespaceString::kConfigSampledQueriesNamespace.toString(), BSONObj());
+ }
+
+ /**
+ * Returns the number of documents for the collection 'nss' in the config.sampledQueries
+ * collection.
+ */
+ int getSampledQueryDocumentsCount(const NamespaceString& nss) {
+ return _getConfigDocumentsCount(NamespaceString::kConfigSampledQueriesNamespace, nss);
+ }
+
+ /*
+ * Asserts that there is a sampled read query document with the given sample id and that it has
+ * the given fields.
+ */
+ void assertSampledReadQueryDocument(const UUID& sampleId,
+ const NamespaceString& nss,
+ SampledReadCommandNameEnum cmdName,
+ const BSONObj& filter,
+ const BSONObj& collation) {
+ auto doc = _getConfigDocument(NamespaceString::kConfigSampledQueriesNamespace, sampleId);
+ auto parsedQueryDoc =
+ SampledReadQueryDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc);
+
+ ASSERT_EQ(parsedQueryDoc.getNs(), nss);
+ ASSERT_EQ(parsedQueryDoc.getCollectionUuid(), getCollectionUUID(nss));
+ ASSERT_EQ(parsedQueryDoc.getSampleId(), sampleId);
+ ASSERT(parsedQueryDoc.getCmdName() == cmdName);
+ auto parsedCmd = SampledReadCommand::parse(IDLParserContext("QueryAnalysisWriterTest"),
+ parsedQueryDoc.getCmd());
+ ASSERT_BSONOBJ_EQ(parsedCmd.getFilter(), filter);
+ ASSERT_BSONOBJ_EQ(parsedCmd.getCollation(), collation);
+ }
+
+ /*
+ * Asserts that there is a sampled write query document with the given sample id and that it has
+ * the given fields.
+ */
+ template <typename CommandRequestType>
+ void assertSampledWriteQueryDocument(const UUID& sampleId,
+ const NamespaceString& nss,
+ SampledWriteCommandNameEnum cmdName,
+ const CommandRequestType& expectedCmd) {
+ auto doc = _getConfigDocument(NamespaceString::kConfigSampledQueriesNamespace, sampleId);
+ auto parsedQueryDoc =
+ SampledWriteQueryDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc);
+
+ ASSERT_EQ(parsedQueryDoc.getNs(), nss);
+ ASSERT_EQ(parsedQueryDoc.getCollectionUuid(), getCollectionUUID(nss));
+ ASSERT_EQ(parsedQueryDoc.getSampleId(), sampleId);
+ ASSERT(parsedQueryDoc.getCmdName() == cmdName);
+ auto parsedCmd = CommandRequestType::parse(IDLParserContext("QueryAnalysisWriterTest"),
+ parsedQueryDoc.getCmd());
+ ASSERT_BSONOBJ_EQ(parsedCmd.toBSON({}), expectedCmd.toBSON({}));
+ }
+
+ /*
+ * Returns the number of documents for the collection 'nss' in the config.sampledQueriesDiff
+ * collection.
+ */
+ int getDiffDocumentsCount(const NamespaceString& nss) {
+ return _getConfigDocumentsCount(NamespaceString::kConfigSampledQueriesDiffNamespace, nss);
+ }
+
+ /*
+ * Asserts that there is a sampled diff document with the given sample id and that it has
+ * the given fields.
+ */
+ void assertDiffDocument(const UUID& sampleId,
+ const NamespaceString& nss,
+ const BSONObj& expectedDiff) {
+ auto doc =
+ _getConfigDocument(NamespaceString::kConfigSampledQueriesDiffNamespace, sampleId);
+ auto parsedDiffDoc =
+ SampledQueryDiffDocument::parse(IDLParserContext("QueryAnalysisWriterTest"), doc);
+
+ ASSERT_EQ(parsedDiffDoc.getNs(), nss);
+ ASSERT_EQ(parsedDiffDoc.getCollectionUuid(), getCollectionUUID(nss));
+ ASSERT_EQ(parsedDiffDoc.getSampleId(), sampleId);
+ assertBsonObjEqualUnordered(parsedDiffDoc.getDiff(), expectedDiff);
+ }
+
+ const NamespaceString nss0{"testDb", "testColl0"};
+ const NamespaceString nss1{"testDb", "testColl1"};
+
+ // Test with both empty and non-empty filter and collation to verify that the
+ // QueryAnalysisWriter doesn't require filter or collation to be non-empty.
+ const BSONObj emptyFilter{};
+ const BSONObj emptyCollation{};
+
+ const BSONObj let = BSON("x" << 1);
+ // Test with EncryptionInformation to verify that QueryAnalysisWriter does not persist the
+ // WriteCommandRequestBase fields, especially this sensitive field.
+ const EncryptionInformation encryptionInformation{BSON("foo"
+ << "bar")};
+
+private:
+ bool _getRandomBool() {
+ return rand() % 2 == 0;
+ }
+
+ /**
+ * Returns the number of documents for the collection 'collNss' in the config collection
+ * 'configNss'.
+ */
+ int _getConfigDocumentsCount(const NamespaceString& configNss,
+ const NamespaceString& collNss) const {
+ DBDirectClient client(operationContext());
+ return client.count(configNss, BSON("ns" << collNss.toString()));
+ }
+
+ /**
+ * Returns the document with the given _id in the config collection 'configNss'.
+ */
+ BSONObj _getConfigDocument(const NamespaceString& configNss, const UUID& id) const {
+ DBDirectClient client(operationContext());
+
+ FindCommandRequest findRequest{configNss};
+ findRequest.setFilter(BSON("_id" << id));
+ auto cursor = client.find(std::move(findRequest));
+ ASSERT(cursor->more());
+ return cursor->next();
+ }
+
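+ // Enable the feature flag for the duration of each test. The
+ // 'disableQueryAnalysisWriter' failpoint presumably keeps the writer's periodic
+ // background flushing disabled so that tests can drive flushing explicitly via
+ // flushQueriesForTest() and flushDiffsForTest().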
+ RAIIServerParameterControllerForTest _featureFlagController{"featureFlagAnalyzeShardKey", true};
+ FailPointEnableBlock _fp{"disableQueryAnalysisWriter"};
+};
+
+DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetIfFeatureFlagNotEnabled, "invariant") {
+ RAIIServerParameterControllerForTest _featureFlagController{"featureFlagAnalyzeShardKey",
+ false};
+ QueryAnalysisWriter::get(operationContext());
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetOnConfigServer, "invariant") {
+ serverGlobalParams.clusterRole = ClusterRole::ConfigServer;
+ QueryAnalysisWriter::get(operationContext());
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, CannotGetOnNonShardServer, "invariant") {
+ serverGlobalParams.clusterRole = ClusterRole::None;
+ QueryAnalysisWriter::get(operationContext());
+}
+
+TEST_F(QueryAnalysisWriterTest, NoQueries) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+ writer.flushQueriesForTest(operationContext());
+}
+
+TEST_F(QueryAnalysisWriterTest, FindQuery) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto testFindCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) {
+ auto sampleId = UUID::gen();
+
+ writer.addFindQuery(sampleId, nss0, filter, collation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kFind, filter, collation);
+
+ deleteSampledQueryDocuments();
+ };
+
+ testFindCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation());
+ testFindCmdCommon(makeNonEmptyFilter(), emptyCollation);
+ testFindCmdCommon(emptyFilter, makeNonEmptyCollation());
+ testFindCmdCommon(emptyFilter, emptyCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, CountQuery) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto testCountCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) {
+ auto sampleId = UUID::gen();
+
+ writer.addCountQuery(sampleId, nss0, filter, collation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kCount, filter, collation);
+
+ deleteSampledQueryDocuments();
+ };
+
+ testCountCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation());
+ testCountCmdCommon(makeNonEmptyFilter(), emptyCollation);
+ testCountCmdCommon(emptyFilter, makeNonEmptyCollation());
+ testCountCmdCommon(emptyFilter, emptyCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, DistinctQuery) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto testDistinctCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) {
+ auto sampleId = UUID::gen();
+
+ writer.addDistinctQuery(sampleId, nss0, filter, collation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kDistinct, filter, collation);
+
+ deleteSampledQueryDocuments();
+ };
+
+ testDistinctCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation());
+ testDistinctCmdCommon(makeNonEmptyFilter(), emptyCollation);
+ testDistinctCmdCommon(emptyFilter, makeNonEmptyCollation());
+ testDistinctCmdCommon(emptyFilter, emptyCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, AggregateQuery) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto testAggregateCmdCommon = [&](const BSONObj& filter, const BSONObj& collation) {
+ auto sampleId = UUID::gen();
+
+ writer.addAggregateQuery(sampleId, nss0, filter, collation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation);
+
+ deleteSampledQueryDocuments();
+ };
+
+ testAggregateCmdCommon(makeNonEmptyFilter(), makeNonEmptyCollation());
+ testAggregateCmdCommon(makeNonEmptyFilter(), emptyCollation);
+ testAggregateCmdCommon(emptyFilter, makeNonEmptyCollation());
+ testAggregateCmdCommon(emptyFilter, emptyCollation);
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, UpdateQueryNotMarkedForSampling, "invariant") {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+ auto [originalCmd, _] = makeUpdateCommandRequest(nss0, 1, {} /* markForSampling */);
+ writer.addUpdateQuery(originalCmd, 0).get();
+}
+
+TEST_F(QueryAnalysisWriterTest, UpdateQueriesMarkedForSampling) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto [originalCmd, expectedSampledCmds] =
+ makeUpdateCommandRequest(nss0, 3, {0, 2} /* markForSampling */);
+ ASSERT_EQ(expectedSampledCmds.size(), 2U);
+
+ writer.addUpdateQuery(originalCmd, 0).get();
+ writer.addUpdateQuery(originalCmd, 2).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 2);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+ for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) {
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kUpdate,
+ expectedSampledCmd);
+ }
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, DeleteQueryNotMarkedForSampling, "invariant") {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+ auto [originalCmd, _] = makeDeleteCommandRequest(nss0, 1, {} /* markForSampling */);
+ writer.addDeleteQuery(originalCmd, 0).get();
+}
+
+TEST_F(QueryAnalysisWriterTest, DeleteQueriesMarkedForSampling) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto [originalCmd, expectedSampledCmds] =
+ makeDeleteCommandRequest(nss0, 3, {1, 2} /* markForSampling */);
+ ASSERT_EQ(expectedSampledCmds.size(), 2U);
+
+ writer.addDeleteQuery(originalCmd, 1).get();
+ writer.addDeleteQuery(originalCmd, 2).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 2);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+ for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) {
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kDelete,
+ expectedSampledCmd);
+ }
+}
+
+DEATH_TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryNotMarkedForSampling, "invariant") {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+ auto [originalCmd, _] =
+ makeFindAndModifyCommandRequest(nss0, true /* isUpdate */, false /* markForSampling */);
+ writer.addFindAndModifyQuery(originalCmd).get();
+}
+
+TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryUpdateMarkedForSampling) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto [originalCmd, expectedSampledCmds] =
+ makeFindAndModifyCommandRequest(nss0, true /* isUpdate */, true /* markForSampling */);
+ ASSERT_EQ(expectedSampledCmds.size(), 1U);
+ auto [sampleId, expectedSampledCmd] = *expectedSampledCmds.begin();
+
+ writer.addFindAndModifyQuery(originalCmd).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kFindAndModify,
+ expectedSampledCmd);
+}
+
+TEST_F(QueryAnalysisWriterTest, FindAndModifyQueryRemoveMarkedForSampling) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto [originalCmd, expectedSampledCmds] =
+ makeFindAndModifyCommandRequest(nss0, false /* isUpdate */, true /* markForSampling */);
+ ASSERT_EQ(expectedSampledCmds.size(), 1U);
+ auto [sampleId, expectedSampledCmd] = *expectedSampledCmds.begin();
+
+ writer.addFindAndModifyQuery(originalCmd).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kFindAndModify,
+ expectedSampledCmd);
+}
+
+TEST_F(QueryAnalysisWriterTest, MultipleQueriesAndCollections) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ // Make nss1 have two queries: the delete below and the count further below.
+ auto [originalDeleteCmd, expectedSampledDeleteCmds] =
+ makeDeleteCommandRequest(nss1, 3, {1} /* markForSampling */);
+ ASSERT_EQ(expectedSampledDeleteCmds.size(), 1U);
+ auto [deleteSampleId, expectedSampledDeleteCmd] = *expectedSampledDeleteCmds.begin();
+
+ // Make nss0 have one query: the update below.
+ auto [originalUpdateCmd, expectedSampledUpdateCmds] =
+ makeUpdateCommandRequest(nss0, 1, {0} /* markForSampling */);
+ ASSERT_EQ(expectedSampledUpdateCmds.size(), 1U);
+ auto [updateSampleId, expectedSampledUpdateCmd] = *expectedSampledUpdateCmds.begin();
+
+ auto countSampleId = UUID::gen();
+ auto originalCountFilter = makeNonEmptyFilter();
+ auto originalCountCollation = makeNonEmptyCollation();
+
+ writer.addDeleteQuery(originalDeleteCmd, 1).get();
+ writer.addUpdateQuery(originalUpdateCmd, 0).get();
+ writer.addCountQuery(countSampleId, nss1, originalCountFilter, originalCountCollation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 3);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledWriteQueryDocument(deleteSampleId,
+ expectedSampledDeleteCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kDelete,
+ expectedSampledDeleteCmd);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 2);
+ assertSampledWriteQueryDocument(updateSampleId,
+ expectedSampledUpdateCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kUpdate,
+ expectedSampledUpdateCmd);
+ assertSampledReadQueryDocument(countSampleId,
+ nss1,
+ SampledReadCommandNameEnum::kCount,
+ originalCountFilter,
+ originalCountCollation);
+}
+
+TEST_F(QueryAnalysisWriterTest, DuplicateQueries) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto findSampleId = UUID::gen();
+ auto originalFindFilter = makeNonEmptyFilter();
+ auto originalFindCollation = makeNonEmptyCollation();
+
+ auto [originalUpdateCmd, expectedSampledUpdateCmds] =
+ makeUpdateCommandRequest(nss0, 1, {0} /* markForSampling */);
+ ASSERT_EQ(expectedSampledUpdateCmds.size(), 1U);
+ auto [updateSampleId, expectedSampledUpdateCmd] = *expectedSampledUpdateCmds.begin();
+
+ auto countSampleId = UUID::gen();
+ auto originalCountFilter = makeNonEmptyFilter();
+ auto originalCountCollation = makeNonEmptyCollation();
+
+ writer.addFindQuery(findSampleId, nss0, originalFindFilter, originalFindCollation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(findSampleId,
+ nss0,
+ SampledReadCommandNameEnum::kFind,
+ originalFindFilter,
+ originalFindCollation);
+
+ writer.addUpdateQuery(originalUpdateCmd, 0).get();
+ writer.addFindQuery(findSampleId, nss0, originalFindFilter, originalFindCollation)
+ .get(); // This is a duplicate.
+ writer.addCountQuery(countSampleId, nss0, originalCountFilter, originalCountCollation).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 3);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 3);
+ assertSampledWriteQueryDocument(updateSampleId,
+ expectedSampledUpdateCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kUpdate,
+ expectedSampledUpdateCmd);
+ assertSampledReadQueryDocument(findSampleId,
+ nss0,
+ SampledReadCommandNameEnum::kFind,
+ originalFindFilter,
+ originalFindCollation);
+ assertSampledReadQueryDocument(countSampleId,
+ nss0,
+ SampledReadCommandNameEnum::kCount,
+ originalCountFilter,
+ originalCountCollation);
+}
+
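+// 'queryAnalysisWriterMaxBatchSize' appears to cap the number of documents per
+// insert command; lowering it to 2 forces the 5 buffered queries below to be
+// flushed across three separate insert batches.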
+TEST_F(QueryAnalysisWriterTest, QueriesMultipleBatches_MaxBatchSize) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 2};
+ auto numQueries = 5;
+
+ std::vector<std::tuple<UUID, BSONObj, BSONObj>> expectedSampledCmds;
+ for (auto i = 0; i < numQueries; i++) {
+ auto sampleId = UUID::gen();
+ auto filter = makeNonEmptyFilter();
+ auto collation = makeNonEmptyCollation();
+ writer.addAggregateQuery(sampleId, nss0, filter, collation).get();
+ expectedSampledCmds.push_back({sampleId, filter, collation});
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries);
+ for (const auto& [sampleId, filter, collation] : expectedSampledCmds) {
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, QueriesMultipleBatches_MaxBSONObjSize) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto numQueries = 3;
+ std::vector<std::tuple<UUID, BSONObj, BSONObj>> expectedSampledCmds;
+ for (auto i = 0; i < numQueries; i++) {
+ auto sampleId = UUID::gen();
+ auto filter = BSON(std::string(BSONObjMaxUserSize / 2, 'a') << 1);
+ auto collation = makeNonEmptyCollation();
+ writer.addAggregateQuery(sampleId, nss0, filter, collation).get();
+ expectedSampledCmds.push_back({sampleId, filter, collation});
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries);
+ for (const auto& [sampleId, filter, collation] : expectedSampledCmds) {
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kAggregate, filter, collation);
+ }
+}
+
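+// The tests below verify that once adding a query pushes the total buffered size
+// past 'queryAnalysisWriterMaxMemoryUsageBytes', the writer flushes right away
+// rather than waiting for an explicit flush*ForTest() call.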
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddReadIfExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto maxMemoryUsageBytes = 1024;
+ RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+ maxMemoryUsageBytes};
+
+ auto sampleId0 = UUID::gen();
+ auto filter0 = BSON(std::string(maxMemoryUsageBytes / 2, 'a') << 1);
+ auto collation0 = makeNonEmptyCollation();
+
+ auto sampleId1 = UUID::gen();
+ auto filter1 = BSON(std::string(maxMemoryUsageBytes / 2, 'b') << 1);
+ auto collation1 = makeNonEmptyCollation();
+
+ writer.addFindQuery(sampleId0, nss0, filter0, collation0).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ // Adding the next query causes the size to exceed the limit.
+ writer.addAggregateQuery(sampleId1, nss1, filter1, collation1).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledReadQueryDocument(
+ sampleId0, nss0, SampledReadCommandNameEnum::kFind, filter0, collation0);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 1);
+ assertSampledReadQueryDocument(
+ sampleId1, nss1, SampledReadCommandNameEnum::kAggregate, filter1, collation1);
+}
+
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddUpdateIfExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto maxMemoryUsageBytes = 1024;
+ RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+ maxMemoryUsageBytes};
+ auto [originalCmd, expectedSampledCmds] =
+ makeUpdateCommandRequest(nss0,
+ 3,
+ {0, 2} /* markForSampling */,
+ std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */);
+ ASSERT_EQ(expectedSampledCmds.size(), 2U);
+
+ writer.addUpdateQuery(originalCmd, 0).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ // Adding the next query causes the size to exceed the limit.
+ writer.addUpdateQuery(originalCmd, 2).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+ for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) {
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kUpdate,
+ expectedSampledCmd);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddDeleteIfExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto maxMemoryUsageBytes = 1024;
+ RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+ maxMemoryUsageBytes};
+ auto [originalCmd, expectedSampledCmds] =
+ makeDeleteCommandRequest(nss0,
+ 3,
+ {0, 1} /* markForSampling */,
+ std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */);
+ ASSERT_EQ(expectedSampledCmds.size(), 2U);
+
+ writer.addDeleteQuery(originalCmd, 0).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ // Adding the next query causes the size to exceed the limit.
+ writer.addDeleteQuery(originalCmd, 1).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 2);
+ for (const auto& [sampleId, expectedSampledCmd] : expectedSampledCmds) {
+ assertSampledWriteQueryDocument(sampleId,
+ expectedSampledCmd.getNamespace(),
+ SampledWriteCommandNameEnum::kDelete,
+ expectedSampledCmd);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddFindAndModifyIfExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto maxMemoryUsageBytes = 1024;
+ RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+ maxMemoryUsageBytes};
+
+ auto [originalCmd0, expectedSampledCmds0] = makeFindAndModifyCommandRequest(
+ nss0,
+ true /* isUpdate */,
+ true /* markForSampling */,
+ std::string(maxMemoryUsageBytes / 2, 'a') /* filterFieldName */);
+ ASSERT_EQ(expectedSampledCmds0.size(), 1U);
+ auto [sampleId0, expectedSampledCmd0] = *expectedSampledCmds0.begin();
+
+ auto [originalCmd1, expectedSampledCmds1] = makeFindAndModifyCommandRequest(
+ nss1,
+ false /* isUpdate */,
+ true /* markForSampling */,
+ std::string(maxMemoryUsageBytes / 2, 'b') /* filterFieldName */);
+ ASSERT_EQ(expectedSampledCmds1.size(), 1U);
+ auto [sampleId1, expectedSampledCmd1] = *expectedSampledCmds1.begin();
+
+ writer.addFindAndModifyQuery(originalCmd0).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 1);
+ // Adding the next query causes the size to exceed the limit.
+ writer.addFindAndModifyQuery(originalCmd1).get();
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 1);
+ assertSampledWriteQueryDocument(sampleId0,
+ expectedSampledCmd0.getNamespace(),
+ SampledWriteCommandNameEnum::kFindAndModify,
+ expectedSampledCmd0);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), 1);
+ assertSampledWriteQueryDocument(sampleId1,
+ expectedSampledCmd1.getNamespace(),
+ SampledWriteCommandNameEnum::kFindAndModify,
+ expectedSampledCmd1);
+}
+
+TEST_F(QueryAnalysisWriterTest, AddQueriesBackAfterWriteError) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto originalFilter = makeNonEmptyFilter();
+ auto originalCollation = makeNonEmptyCollation();
+ auto numQueries = 8;
+
+ std::vector<UUID> sampleIds0;
+ for (auto i = 0; i < numQueries; i++) {
+ sampleIds0.push_back(UUID::gen());
+ writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get();
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries);
+
+ // Force the documents to get inserted in three batches of size 3, 3 and 2, respectively.
+ RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 3};
+
+ // Hang after inserting the documents in the first batch.
+ auto hangFp = globalFailPointRegistry().find("hangAfterCollectionInserts");
+ auto hangTimesEntered = hangFp->setMode(FailPoint::alwaysOn, 0);
+
+ auto future = stdx::async(stdx::launch::async, [&] {
+ ThreadClient tc(getServiceContext());
+ auto opCtx = makeOperationContext();
+ writer.flushQueriesForTest(opCtx.get());
+ });
+
+ hangFp->waitForTimesEntered(hangTimesEntered + 1);
+ // Force the second batch to fail so that it falls back to inserting one document at a time
+ // in order, and then force the first and second documents in that batch to fail.
+ auto failFp = globalFailPointRegistry().find("failCollectionInserts");
+ failFp->setMode(FailPoint::nTimes, 3);
+ hangFp->setMode(FailPoint::off, 0);
+
+ future.get();
+ // Verify that all the documents other than the ones in the first batch got added back to
+ // the buffer after the error. The last document in the second batch was inserted
+ // successfully, but it also gets added back because the writer has no way to tell whether
+ // the error caused the entire command to fail early.
+ ASSERT_EQ(writer.getQueriesCountForTest(), 5);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), 4);
+
+ // Flush the remaining documents. If the documents were not added back correctly, some
+ // documents would be missing and the checks below would fail.
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries);
+ for (const auto& sampleId : sampleIds0) {
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, RemoveDuplicatesFromBufferAfterWriteError) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto originalFilter = makeNonEmptyFilter();
+ auto originalCollation = makeNonEmptyCollation();
+
+ auto numQueries0 = 3;
+
+ std::vector<UUID> sampleIds0;
+ for (auto i = 0; i < numQueries0; i++) {
+ sampleIds0.push_back(UUID::gen());
+ writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get();
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries0);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss0), numQueries0);
+ for (const auto& sampleId : sampleIds0) {
+ assertSampledReadQueryDocument(
+ sampleId, nss0, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation);
+ }
+
+ auto numQueries1 = 5;
+
+ std::vector<UUID> sampleIds1;
+ for (auto i = 0; i < numQueries1; i++) {
+ sampleIds1.push_back(UUID::gen());
+ writer.addFindQuery(sampleIds1[i], nss1, originalFilter, originalCollation).get();
+ // Re-add a duplicate of one of the nss0 queries that were already flushed above.
+ if (i < numQueries0) {
+ writer.addFindQuery(sampleIds0[i], nss0, originalFilter, originalCollation).get();
+ }
+ }
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries0 + numQueries1);
+
+ // Force the batch to fail so that it falls back to inserting one document at a time in order.
+ auto failFp = globalFailPointRegistry().find("failCollectionInserts");
+ failFp->setMode(FailPoint::nTimes, 1);
+
+ // Hang after inserting the first non-duplicate document.
+ auto hangFp = globalFailPointRegistry().find("hangAfterCollectionInserts");
+ auto hangTimesEntered = hangFp->setMode(FailPoint::alwaysOn, 0);
+
+ auto future = stdx::async(stdx::launch::async, [&] {
+ ThreadClient tc(getServiceContext());
+ auto opCtx = makeOperationContext();
+ writer.flushQueriesForTest(opCtx.get());
+ });
+
+ hangFp->waitForTimesEntered(hangTimesEntered + 1);
+ // Force the next non-duplicate document to fail to insert.
+ failFp->setMode(FailPoint::nTimes, 1);
+ hangFp->setMode(FailPoint::off, 0);
+
+ future.get();
+ // Verify that the duplicate documents did not get added back to the buffer after the error.
+ ASSERT_EQ(writer.getQueriesCountForTest(), numQueries1);
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), numQueries1 - 1);
+
+ // Flush the remaining documents. If the documents were not added back correctly, the
+ // document that previously failed to insert would be missing and the checks below would
+ // fail.
+ failFp->setMode(FailPoint::off, 0);
+ writer.flushQueriesForTest(operationContext());
+ ASSERT_EQ(writer.getQueriesCountForTest(), 0);
+
+ ASSERT_EQ(getSampledQueryDocumentsCount(nss1), numQueries1);
+ for (const auto& sampleId : sampleIds1) {
+ assertSampledReadQueryDocument(
+ sampleId, nss1, SampledReadCommandNameEnum::kFind, originalFilter, originalCollation);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, NoDiffs) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+ writer.flushDiffsForTest(operationContext());
+}
+
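+// The diff tests below use doc_diff::computeInlineDiff() to compute the expected
+// diff between each pre-image and post-image; addDiff() is expected to persist an
+// equivalent diff document in config.sampledQueriesDiff.
+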
+TEST_F(QueryAnalysisWriterTest, DiffsBasic) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto collUuid0 = getCollectionUUID(nss0);
+ auto sampleId = UUID::gen();
+ auto preImage = BSON("a" << 0);
+ auto postImage = BSON("a" << 1);
+
+ writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 1);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+ assertDiffDocument(sampleId, nss0, *doc_diff::computeInlineDiff(preImage, postImage));
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffsMultipleQueriesAndCollections) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ // Make nss0 have a diff for one query.
+ auto collUuid0 = getCollectionUUID(nss0);
+
+ auto sampleId0 = UUID::gen();
+ auto preImage0 = BSON("a" << 0 << "b" << 0 << "c" << 0);
+ auto postImage0 = BSON("a" << 0 << "b" << 1 << "d" << 1);
+
+ // Make nss1 have diffs for two queries.
+ auto collUuid1 = getCollectionUUID(nss1);
+
+ auto sampleId1 = UUID::gen();
+ auto preImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1) << "d" << BSON("e" << 1));
+ auto postImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1 << 2) << "d" << BSON("e" << 2));
+
+ auto sampleId2 = UUID::gen();
+ auto preImage2 = BSON("a" << BSONObj());
+ auto postImage2 = BSON("a" << BSON("b" << 2));
+
+ writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get();
+ writer.addDiff(sampleId1, nss1, collUuid1, preImage1, postImage1).get();
+ writer.addDiff(sampleId2, nss1, collUuid1, preImage2, postImage2).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 3);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+ assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0));
+
+ ASSERT_EQ(getDiffDocumentsCount(nss1), 2);
+ assertDiffDocument(sampleId1, nss1, *doc_diff::computeInlineDiff(preImage1, postImage1));
+ assertDiffDocument(sampleId2, nss1, *doc_diff::computeInlineDiff(preImage2, postImage2));
+}
+
+TEST_F(QueryAnalysisWriterTest, DuplicateDiffs) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto collUuid0 = getCollectionUUID(nss0);
+
+ auto sampleId0 = UUID::gen();
+ auto preImage0 = BSON("a" << 0);
+ auto postImage0 = BSON("a" << 1);
+
+ auto sampleId1 = UUID::gen();
+ auto preImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1));
+ auto postImage1 = BSON("a" << 1 << "b" << BSON_ARRAY(1 << 2));
+
+ auto sampleId2 = UUID::gen();
+ auto preImage2 = BSON("a" << BSONObj());
+ auto postImage2 = BSON("a" << BSON("b" << 2));
+
+ writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 1);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+ assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0));
+
+ writer.addDiff(sampleId1, nss0, collUuid0, preImage1, postImage1).get();
+ writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0)
+ .get(); // This is a duplicate.
+ writer.addDiff(sampleId2, nss0, collUuid0, preImage2, postImage2).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 3);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 3);
+ assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0));
+ assertDiffDocument(sampleId1, nss0, *doc_diff::computeInlineDiff(preImage1, postImage1));
+ assertDiffDocument(sampleId2, nss0, *doc_diff::computeInlineDiff(preImage2, postImage2));
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffsMultipleBatches_MaxBatchSize) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ RAIIServerParameterControllerForTest maxBatchSize{"queryAnalysisWriterMaxBatchSize", 2};
+ auto numDiffs = 5;
+ auto collUuid0 = getCollectionUUID(nss0);
+
+ std::vector<std::pair<UUID, BSONObj>> expectedSampledDiffs;
+ for (auto i = 0; i < numDiffs; i++) {
+ auto sampleId = UUID::gen();
+ auto preImage = BSON("a" << 0);
+ auto postImage = BSON(("a" + std::to_string(i)) << 1);
+ writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get();
+ expectedSampledDiffs.push_back(
+ {sampleId, *doc_diff::computeInlineDiff(preImage, postImage)});
+ }
+ ASSERT_EQ(writer.getDiffsCountForTest(), numDiffs);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), numDiffs);
+ for (const auto& [sampleId, diff] : expectedSampledDiffs) {
+ assertDiffDocument(sampleId, nss0, diff);
+ }
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffsMultipleBatches_MaxBSONObjSize) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto numDiffs = 3;
+ auto collUuid0 = getCollectionUUID(nss0);
+
+ std::vector<std::pair<UUID, BSONObj>> expectedSampledDiffs;
+ for (auto i = 0; i < numDiffs; i++) {
+ auto sampleId = UUID::gen();
+ auto preImage = BSON("a" << 0);
+ auto postImage = BSON(std::string(BSONObjMaxUserSize / 2, 'a') << 1);
+ writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get();
+ expectedSampledDiffs.push_back(
+ {sampleId, *doc_diff::computeInlineDiff(preImage, postImage)});
+ }
+ ASSERT_EQ(writer.getDiffsCountForTest(), numDiffs);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), numDiffs);
+ for (const auto& [sampleId, diff] : expectedSampledDiffs) {
+ assertDiffDocument(sampleId, nss0, diff);
+ }
+}
+
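+// The same flush-on-size-limit behavior verified above for queries applies to
+// buffered diffs.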
+TEST_F(QueryAnalysisWriterTest, FlushAfterAddDiffIfExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto maxMemoryUsageBytes = 1024;
+ RAIIServerParameterControllerForTest maxMemoryBytes{"queryAnalysisWriterMaxMemoryUsageBytes",
+ maxMemoryUsageBytes};
+
+ auto collUuid0 = getCollectionUUID(nss0);
+ auto sampleId0 = UUID::gen();
+ auto preImage0 = BSON("a" << 0);
+ auto postImage0 = BSON(std::string(maxMemoryUsageBytes / 2, 'a') << 1);
+
+ auto collUuid1 = getCollectionUUID(nss1);
+ auto sampleId1 = UUID::gen();
+ auto preImage1 = BSON("a" << 0);
+ auto postImage1 = BSON(std::string(maxMemoryUsageBytes / 2, 'b') << 1);
+
+ writer.addDiff(sampleId0, nss0, collUuid0, preImage0, postImage0).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 1);
+ // Adding the next diff causes the size to exceed the limit.
+ writer.addDiff(sampleId1, nss1, collUuid1, preImage1, postImage1).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 1);
+ assertDiffDocument(sampleId0, nss0, *doc_diff::computeInlineDiff(preImage0, postImage0));
+ ASSERT_EQ(getDiffDocumentsCount(nss1), 1);
+ assertDiffDocument(sampleId1, nss1, *doc_diff::computeInlineDiff(preImage1, postImage1));
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffEmpty) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto collUuid0 = getCollectionUUID(nss0);
+ auto sampleId = UUID::gen();
+ auto preImage = BSON("a" << 1);
+ auto postImage = preImage;
+
+ writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 0);
+}
+
+TEST_F(QueryAnalysisWriterTest, DiffExceedsSizeLimit) {
+ auto& writer = QueryAnalysisWriter::get(operationContext());
+
+ auto collUuid0 = getCollectionUUID(nss0);
+ auto sampleId = UUID::gen();
+ auto preImage = BSON(std::string(BSONObjMaxUserSize, 'a') << 1);
+ auto postImage = BSONObj();
+
+ writer.addDiff(sampleId, nss0, collUuid0, preImage, postImage).get();
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+ writer.flushDiffsForTest(operationContext());
+ ASSERT_EQ(writer.getDiffsCountForTest(), 0);
+
+ ASSERT_EQ(getDiffDocumentsCount(nss0), 0);
+}
+
+} // namespace
+} // namespace analyze_shard_key
+} // namespace mongo