summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMihai Andrei <mihai.andrei@10gen.com>2021-07-21 22:49:44 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-07-21 23:26:46 +0000
commit97e7a659d01f8f5ceef69a4f738cdf76396d99db (patch)
treee451243f4e503de6edac72f000f1f4a500c3bc62
parent3aba3a2f4f7af9fe3a56fa6e6e6f80a922c56594 (diff)
downloadmongo-97e7a659d01f8f5ceef69a4f738cdf76396d99db.tar.gz
SERVER-57879 Implement $minN and $maxN accumulators
-rw-r--r--jstests/aggregation/accumulators/min_n_max_n.js159
-rw-r--r--src/mongo/db/pipeline/SConscript1
-rw-r--r--src/mongo/db/pipeline/accumulation_statement.h2
-rw-r--r--src/mongo/db/pipeline/accumulator_multi.cpp211
-rw-r--r--src/mongo/db/pipeline/accumulator_multi.h134
-rw-r--r--src/mongo/db/pipeline/accumulator_test.cpp124
-rw-r--r--src/mongo/db/query/query_knobs.idl10
7 files changed, 633 insertions, 8 deletions
diff --git a/jstests/aggregation/accumulators/min_n_max_n.js b/jstests/aggregation/accumulators/min_n_max_n.js
new file mode 100644
index 00000000000..0394dcbb862
--- /dev/null
+++ b/jstests/aggregation/accumulators/min_n_max_n.js
@@ -0,0 +1,159 @@
+/**
+ * Basic tests for the $minN/$maxN accumulators.
+ */
+(function() {
+"use strict";
+
+const coll = db[jsTestName()];
+coll.drop();
+
+const isExactTopNEnabled = db.adminCommand({getParameter: 1, featureFlagExactTopNAccumulator: 1})
+ .featureFlagExactTopNAccumulator.value;
+
+if (!isExactTopNEnabled) {
+ // Verify that $minN/$maxN cannot be used if the feature flag is set to false and ignore the
+ // rest of the test.
+ assert.commandFailedWithCode(coll.runCommand("aggregate", {
+ pipeline: [{$group: {_id: {'st': '$state'}, minSales: {$minN: {output: '$sales', n: 2}}}}],
+ cursor: {}
+ }),
+ 5787909);
+ return;
+}
+
+// Basic correctness tests.
+let docs = [];
+const n = 4;
+const states = [{state: 'CA', sales: 10}, {state: 'NY', sales: 7}, {state: 'TX', sales: 4}];
+let expectedMinNResults = [];
+let expectedMaxNResults = [];
+for (const stateDoc of states) {
+ const state = stateDoc['state'];
+ const sales = stateDoc['sales'];
+ let minArr = [];
+ let maxArr = [];
+ for (let i = 1; i <= sales; ++i) {
+ const amount = i * 100;
+ docs.push({state: state, sales: amount});
+
+ // Record the lowest/highest 'n' values.
+ if (i < n + 1) {
+ minArr.push(amount);
+ }
+ if (sales - n < i) {
+ maxArr.push(amount);
+ }
+ }
+ expectedMinNResults.push({_id: state, minSales: minArr});
+
+ // Reverse 'maxArr' results since $maxN outputs results in descending order.
+ expectedMaxNResults.push({_id: state, maxSales: maxArr.reverse()});
+}
+
+assert.commandWorked(coll.insert(docs));
+
+// Note that the output documents are sorted by '_id' so that we can compare actual groups against
+// expected groups (we cannot perform unordered comparison because order matters for $minN/$maxN).
+const actualMinNResults =
+ coll.aggregate([
+ {$group: {_id: '$state', minSales: {$minN: {output: '$sales', n: n}}}},
+ {$sort: {_id: 1}}
+ ])
+ .toArray();
+assert.eq(expectedMinNResults, actualMinNResults);
+
+const actualMaxNResults =
+ coll.aggregate([
+ {$group: {_id: '$state', maxSales: {$maxN: {output: '$sales', n: n}}}},
+ {$sort: {_id: 1}}
+ ])
+ .toArray();
+assert.eq(expectedMaxNResults, actualMaxNResults);
+
+// Verify that we can dynamically compute 'n' based on the group key for $group.
+const groupKeyNExpr = {
+ $cond: {if: {$eq: ['$st', 'CA']}, then: 10, else: 4}
+};
+const dynamicMinNResults =
+ coll.aggregate([{
+ $group: {_id: {'st': '$state'}, minSales: {$minN: {output: '$sales', n: groupKeyNExpr}}}
+ }])
+ .toArray();
+
+// Verify that the 'CA' group has 10 results, while all others have only 4.
+for (const result of dynamicMinNResults) {
+ assert(result.hasOwnProperty('_id'), tojson(result));
+ const groupKey = result['_id'];
+ assert(groupKey.hasOwnProperty('st'), tojson(groupKey));
+ const state = groupKey['st'];
+ assert(result.hasOwnProperty('minSales'), tojson(result));
+ const salesArray = result['minSales'];
+ if (state === 'CA') {
+ assert.eq(salesArray.length, 10, tojson(salesArray));
+ } else {
+ assert.eq(salesArray.length, 4, tojson(salesArray));
+ }
+}
+
+// Error cases
+
+// Cannot reference the group key in $minN when using $bucketAuto.
+assert.commandFailedWithCode(coll.runCommand("aggregate", {
+ pipeline: [{
+ $bucketAuto: {
+ groupBy: "$state",
+ buckets: 2,
+ output: {minSales: {$minN: {output: '$sales', n: groupKeyNExpr}}}
+ }
+ }],
+ cursor: {}
+}),
+ 4544714);
+
+// Reject non-integral/negative values of n.
+assert.commandFailedWithCode(coll.runCommand("aggregate", {
+ pipeline:
+ [{$group: {_id: {'st': '$state'}, minSales: {$minN: {output: '$sales', n: 'string'}}}}],
+ cursor: {}
+}),
+ 5787902);
+
+assert.commandFailedWithCode(coll.runCommand("aggregate", {
+ pipeline: [{$group: {_id: {'st': '$state'}, minSales: {$minN: {output: '$sales', n: 3.2}}}}],
+ cursor: {}
+}),
+ 5787903);
+
+assert.commandFailedWithCode(coll.runCommand("aggregate", {
+ pipeline: [{$group: {_id: {'st': '$state'}, minSales: {$minN: {output: '$sales', n: -1}}}}],
+ cursor: {}
+}),
+ 5787908);
+
+// Reject invalid specifications.
+
+// Missing arguments.
+assert.commandFailedWithCode(coll.runCommand("aggregate", {
+ pipeline: [{$group: {_id: {'st': '$state'}, minSales: {$minN: {output: '$sales'}}}}],
+ cursor: {}
+}),
+ 5787906);
+
+assert.commandFailedWithCode(
+ coll.runCommand(
+ "aggregate",
+ {pipeline: [{$group: {_id: {'st': '$state'}, minSales: {$minN: {n: 2}}}}], cursor: {}}),
+ 5787907);
+
+// Extra field.
+assert.commandFailedWithCode(coll.runCommand("aggregate", {
+ pipeline: [{
+ $group: {
+ _id: {'st': '$state'},
+ minSales: {$minN: {output: '$sales', n: 2, randomField: "randomArg"}}
+ }
+ }],
+ cursor: {}
+}),
+ 5787901);
+})();
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index f0c66422019..b4ff7d76fcb 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -131,6 +131,7 @@ env.Library(
'accumulator_last.cpp',
'accumulator_merge_objects.cpp',
'accumulator_min_max.cpp',
+ 'accumulator_multi.cpp',
'accumulator_push.cpp',
'accumulator_rank.cpp',
'accumulator_std_dev.cpp',
diff --git a/src/mongo/db/pipeline/accumulation_statement.h b/src/mongo/db/pipeline/accumulation_statement.h
index 9ee7adb859f..d20bc1bb8a7 100644
--- a/src/mongo/db/pipeline/accumulation_statement.h
+++ b/src/mongo/db/pipeline/accumulation_statement.h
@@ -76,7 +76,7 @@ namespace mongo {
* contains intermediate values being accumulated.
*
* Like most accumulators, $sum does not require or accept an initializer Expression. At time of
- * writing, only user-defined accumulators accept an initializer.
+ * writing, only user-defined accumulators and the 'N' family of accumulators accept an initializer.
*
* For example, in:
* {$group: {
diff --git a/src/mongo/db/pipeline/accumulator_multi.cpp b/src/mongo/db/pipeline/accumulator_multi.cpp
new file mode 100644
index 00000000000..b3cbf26c924
--- /dev/null
+++ b/src/mongo/db/pipeline/accumulator_multi.cpp
@@ -0,0 +1,211 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/pipeline/accumulator_multi.h"
+
+namespace mongo {
+// TODO SERVER-58379 Update these macros once FCV constants are upgraded.
+REGISTER_ACCUMULATOR(maxN, AccumulatorMinMaxN::parseMinMaxN<Sense::kMax>);
+REGISTER_ACCUMULATOR(minN, AccumulatorMinMaxN::parseMinMaxN<Sense::kMin>);
+// TODO SERVER-57882 Add $minN/$maxN as expressions.
+// TODO SERVER-57885 Add $minN/$maxN as window functions.
+
+AccumulatorN::AccumulatorN(ExpressionContext* const expCtx)
+ : AccumulatorState(expCtx), _maxMemUsageBytes(internalQueryMaxNAccumulatorBytes.load()) {}
+
+void AccumulatorN::startNewGroup(const Value& input) {
+ // Obtain the value for 'n' and error if it's not a positive integral.
+ uassert(5787902,
+ str::stream() << "Value for 'n' must be of integral type, but found "
+ << input.toString(),
+ isNumericBSONType(input.getType()));
+ auto n = input.coerceToLong();
+ uassert(5787903,
+ str::stream() << "Value for 'n' must be of integral type, but found "
+ << input.toString(),
+ n == input.coerceToDouble());
+ uassert(5787908, str::stream() << "'n' must be greater than 0, found " << n, n > 0);
+ _n = n;
+}
+
+AccumulatorMinMaxN::AccumulatorMinMaxN(ExpressionContext* const expCtx, Sense sense)
+ : AccumulatorN(expCtx),
+ _set(expCtx->getValueComparator().makeOrderedValueMultiset()),
+ _sense(sense) {
+ _memUsageBytes = sizeof(*this);
+}
+
+const char* AccumulatorMinMaxN::getOpName() const {
+ if (_sense == Sense::kMin) {
+ return AccumulatorMinN::getName();
+ } else {
+ return AccumulatorMaxN::getName();
+ }
+}
+
+Document AccumulatorMinMaxN::serialize(boost::intrusive_ptr<Expression> initializer,
+ boost::intrusive_ptr<Expression> argument,
+ bool explain) const {
+ MutableDocument args;
+ AccumulatorN::serializeHelper(initializer, argument, explain, args);
+ return DOC(getOpName() << args.freeze());
+}
+
+std::tuple<boost::intrusive_ptr<Expression>, boost::intrusive_ptr<Expression>>
+AccumulatorN::parseArgs(ExpressionContext* const expCtx,
+ const BSONObj& args,
+ VariablesParseState vps) {
+ boost::intrusive_ptr<Expression> n;
+ boost::intrusive_ptr<Expression> output;
+ for (auto&& element : args) {
+ auto fieldName = element.fieldNameStringData();
+ if (fieldName == kFieldNameOutput) {
+ output = Expression::parseOperand(expCtx, element, vps);
+ } else if (fieldName == kFieldNameN) {
+ n = Expression::parseOperand(expCtx, element, vps);
+ } else {
+ uasserted(5787901, str::stream() << "Unknown argument to minN/maxN: " << fieldName);
+ }
+ }
+ uassert(5787906, "Missing value for 'n'", n);
+ uassert(5787907, "Missing value for 'output'", output);
+ return std::make_tuple(n, output);
+}
+
+void AccumulatorN::serializeHelper(const boost::intrusive_ptr<Expression>& initializer,
+ const boost::intrusive_ptr<Expression>& argument,
+ bool explain,
+ MutableDocument& md) {
+ md.addField(kFieldNameN, Value(initializer->serialize(explain)));
+ md.addField(kFieldNameOutput, Value(argument->serialize(explain)));
+}
+
+template <Sense s>
+AccumulationExpression AccumulatorMinMaxN::parseMinMaxN(ExpressionContext* const expCtx,
+ BSONElement elem,
+ VariablesParseState vps) {
+ auto name = [] {
+ if constexpr (s == Sense::kMin) {
+ return AccumulatorMinN::getName();
+ } else {
+ return AccumulatorMaxN::getName();
+ }
+ }();
+
+ // TODO SERVER-58379 Remove this uassert once the FCV constants are upgraded and the REGISTER
+ // macros above are updated accordingly.
+ uassert(5787909,
+ str::stream() << "Cannot create " << name << " accumulator if feature flag is disabled",
+ feature_flags::gFeatureFlagExactTopNAccumulator.isEnabledAndIgnoreFCV());
+ uassert(5787900,
+ str::stream() << "specification must be an object; found " << elem,
+ elem.type() == BSONType::Object);
+ BSONObj obj = elem.embeddedObject();
+
+ auto [n, output] = AccumulatorN::parseArgs(expCtx, obj, vps);
+
+ auto factory = [expCtx] {
+ if constexpr (s == Sense::kMin) {
+ return AccumulatorMinN::create(expCtx);
+ } else {
+ return AccumulatorMaxN::create(expCtx);
+ }
+ };
+
+ return {std::move(n), std::move(output), std::move(factory), name};
+}
+
+void AccumulatorMinMaxN::processValue(const Value& val) {
+ // Ignore nullish values.
+ if (val.nullish())
+ return;
+
+ // Only compare if we have 'n' elements.
+ if (static_cast<long long>(_set.size()) == *_n) {
+ // Get an iterator to the element we want to compare against.
+ auto cmpElem = _sense == Sense::kMin ? std::prev(_set.end()) : _set.begin();
+
+ auto cmp = getExpressionContext()->getValueComparator().compare(*cmpElem, val) * _sense;
+ if (cmp > 0) {
+ _memUsageBytes -= cmpElem->getApproximateSize();
+ _set.erase(cmpElem);
+ } else {
+ return;
+ }
+ }
+ _memUsageBytes += val.getApproximateSize();
+ uassert(ErrorCodes::ExceededMemoryLimit,
+ str::stream() << getOpName()
+ << " used too much memory and cannot spill to disk. Memory limit: "
+ << _maxMemUsageBytes << " bytes",
+ _memUsageBytes < _maxMemUsageBytes);
+ _set.emplace(val);
+}
+
+void AccumulatorMinMaxN::processInternal(const Value& input, bool merging) {
+ tassert(5787904, "'n' must be initialized", _n);
+
+ if (merging) {
+ tassert(5787905, "input must be an array when 'merging' is true", input.isArray());
+ auto array = input.getArray();
+ for (auto&& val : array) {
+ processValue(val);
+ }
+ } else {
+ processValue(input);
+ }
+}
+
+Value AccumulatorMinMaxN::getValue(bool toBeMerged) {
+ // Return the values in ascending order for 'kMin' and descending order for 'kMax'.
+ return Value(_sense == Sense::kMin ? std::vector<Value>(_set.begin(), _set.end())
+ : std::vector<Value>(_set.rbegin(), _set.rend()));
+}
+
+void AccumulatorMinMaxN::reset() {
+ _set = getExpressionContext()->getValueComparator().makeOrderedValueMultiset();
+ _memUsageBytes = sizeof(*this);
+}
+
+const char* AccumulatorMinN::getName() {
+ return kName.rawData();
+}
+
+boost::intrusive_ptr<AccumulatorState> AccumulatorMinN::create(ExpressionContext* const expCtx) {
+ return make_intrusive<AccumulatorMinN>(expCtx);
+}
+
+const char* AccumulatorMaxN::getName() {
+ return kName.rawData();
+}
+
+boost::intrusive_ptr<AccumulatorState> AccumulatorMaxN::create(ExpressionContext* const expCtx) {
+ return make_intrusive<AccumulatorMaxN>(expCtx);
+}
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/accumulator_multi.h b/src/mongo/db/pipeline/accumulator_multi.h
new file mode 100644
index 00000000000..aac2dc66c15
--- /dev/null
+++ b/src/mongo/db/pipeline/accumulator_multi.h
@@ -0,0 +1,134 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/pipeline/accumulation_statement.h"
+
+namespace mongo {
+using Sense = AccumulatorMinMax::Sense;
+
+/**
+ * An AccumulatorN picks 'n' of its input values and returns them in an array. Each derived class
+ * has different criteria for how to pick values and order the final array, but any common behavior
+ * shared by derived classes is implemented in this class. In particular:
+ * - Initializing 'n' during 'startNewGroup'.
+ * - Parsing the expressions for 'n' and 'output'.
+ */
+class AccumulatorN : public AccumulatorState {
+public:
+ AccumulatorN(ExpressionContext* const expCtx);
+
+protected:
+ // Initialize 'n' with 'input'. In particular, verifies that 'input' is a positive integer.
+ void startNewGroup(const Value& input) final;
+
+ // Parses 'args' for the 'n' and 'output' arguments that are common to the 'N' family of
+ // accumulators.
+ static std::tuple<boost::intrusive_ptr<Expression>, boost::intrusive_ptr<Expression>> parseArgs(
+ ExpressionContext* const expCtx, const BSONObj& args, VariablesParseState vps);
+
+ // Helper which appends the 'n' and 'output' fields to 'md'.
+ static void serializeHelper(const boost::intrusive_ptr<Expression>& initializer,
+ const boost::intrusive_ptr<Expression>& argument,
+ bool explain,
+ MutableDocument& md);
+
+ // Stores the limit of how many values we will return. This value is initialized to
+ // 'boost::none' on construction and is only set during 'startNewGroup'.
+ boost::optional<long long> _n;
+
+ int _maxMemUsageBytes = 0;
+
+private:
+ static constexpr auto kFieldNameN = "n"_sd;
+ static constexpr auto kFieldNameOutput = "output"_sd;
+};
+class AccumulatorMinMaxN : public AccumulatorN {
+public:
+ AccumulatorMinMaxN(ExpressionContext* const expCtx, Sense sense);
+
+ /**
+ * Verifies that 'elem' is an object, delegates argument parsing to 'AccumulatorN::parseArgs',
+ * and constructs an AccumulationExpression representing $minN or $maxN depending on 's'.
+ */
+ template <Sense s>
+ static AccumulationExpression parseMinMaxN(ExpressionContext* const expCtx,
+ BSONElement elem,
+ VariablesParseState vps);
+
+ void processInternal(const Value& input, bool merging) final;
+
+ Value getValue(bool toBeMerged) final;
+
+ const char* getOpName() const final;
+
+ Document serialize(boost::intrusive_ptr<Expression> initializer,
+ boost::intrusive_ptr<Expression> argument,
+ bool explain) const final;
+
+ void reset() final;
+
+ bool isAssociative() const final {
+ return true;
+ }
+
+ bool isCommutative() const final {
+ return true;
+ }
+
+private:
+ void processValue(const Value& val);
+
+ ValueMultiset _set;
+ Sense _sense;
+};
+
+class AccumulatorMinN : public AccumulatorMinMaxN {
+public:
+ static constexpr auto kName = "$minN"_sd;
+ explicit AccumulatorMinN(ExpressionContext* const expCtx)
+ : AccumulatorMinMaxN(expCtx, Sense::kMin) {}
+
+ static const char* getName();
+
+ static boost::intrusive_ptr<AccumulatorState> create(ExpressionContext* const expCtx);
+};
+
+class AccumulatorMaxN : public AccumulatorMinMaxN {
+public:
+ static constexpr auto kName = "$maxN"_sd;
+ explicit AccumulatorMaxN(ExpressionContext* const expCtx)
+ : AccumulatorMinMaxN(expCtx, Sense::kMax) {}
+
+ static const char* getName();
+
+ static boost::intrusive_ptr<AccumulatorState> create(ExpressionContext* const expCtx);
+};
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp
index d5a07d0b3be..50e32eb8ffa 100644
--- a/src/mongo/db/pipeline/accumulator_test.cpp
+++ b/src/mongo/db/pipeline/accumulator_test.cpp
@@ -39,9 +39,11 @@
#include "mongo/db/pipeline/accumulation_statement.h"
#include "mongo/db/pipeline/accumulator.h"
#include "mongo/db/pipeline/accumulator_for_window_functions.h"
+#include "mongo/db/pipeline/accumulator_multi.h"
#include "mongo/db/pipeline/expression_context_for_test.h"
#include "mongo/db/query/collation/collator_interface_mock.h"
#include "mongo/dbtests/dbtests.h"
+#include "mongo/idl/server_parameter_test_util.h"
#include "mongo/logv2/log.h"
namespace AccumulatorTests {
@@ -59,12 +61,21 @@ template <typename AccName>
static void assertExpectedResults(
ExpressionContext* const expCtx,
std::initializer_list<std::pair<std::vector<Value>, Value>> operations,
- bool skipMerging = false) {
+ bool skipMerging = false,
+ boost::optional<Value> newGroupValue = boost::none) {
+ auto initializeAccumulator = [&]() -> intrusive_ptr<AccumulatorState> {
+ auto accum = AccName::create(expCtx);
+ if (newGroupValue) {
+ accum->startNewGroup(*newGroupValue);
+ }
+ return accum;
+ };
+
for (auto&& op : operations) {
try {
// Asserts that result equals expected result when not sharded.
{
- auto accum = AccName::create(expCtx);
+ auto accum = initializeAccumulator();
for (auto&& val : op.first) {
accum->process(val, false);
}
@@ -75,12 +86,13 @@ static void assertExpectedResults(
// Asserts that result equals expected result when all input is on one shard.
if (!skipMerging) {
- auto accum = AccName::create(expCtx);
- auto shard = AccName::create(expCtx);
+ auto accum = initializeAccumulator();
+ auto shard = initializeAccumulator();
for (auto&& val : op.first) {
shard->process(val, false);
}
- accum->process(shard->getValue(true), true);
+ auto val = shard->getValue(true);
+ accum->process(val, true);
Value result = accum->getValue(false);
ASSERT_VALUE_EQ(op.second, result);
ASSERT_EQUALS(op.second.getType(), result.getType());
@@ -88,9 +100,9 @@ static void assertExpectedResults(
// Asserts that result equals expected result when each input is on a separate shard.
if (!skipMerging) {
- auto accum = AccName::create(expCtx);
+ auto accum = initializeAccumulator();
for (auto&& val : op.first) {
- auto shard = AccName::create(expCtx);
+ auto shard = initializeAccumulator();
shard->process(val, false);
accum->process(shard->getValue(true), true);
}
@@ -223,6 +235,104 @@ TEST(Accumulators, MinRespectsCollation) {
{{{Value("abc"_sd), Value("cba"_sd)}, Value("cba"_sd)}});
}
+TEST(Accumulators, MinN) {
+ RAIIServerParameterControllerForTest controller("featureFlagExactTopNAccumulator", true);
+ auto expCtx = ExpressionContextForTest{};
+ const auto n = Value(3);
+ assertExpectedResults<AccumulatorMinN>(
+ &expCtx,
+ {
+ // Basic tests.
+ {{Value(3), Value(4), Value(5), Value(100)},
+ {Value(std::vector<Value>{Value(3), Value(4), Value(5)})}},
+ {{Value(10), Value(8), Value(9), Value(7), Value(1)},
+ {Value(std::vector<Value>{Value(1), Value(7), Value(8)})}},
+ {{Value(11.32), Value(91.0), Value(2), Value(701), Value(101)},
+ {Value(std::vector<Value>{Value(2), Value(11.32), Value(91.0)})}},
+
+ // 3 or fewer values results in those values being returned.
+ {{Value(10), Value(8), Value(9)},
+ {Value(std::vector<Value>{Value(8), Value(9), Value(10)})}},
+ {{Value(10)}, {Value(std::vector<Value>{Value(10)})}},
+
+ // Ties are broken arbitrarily.
+ {{Value(10), Value(10), Value(1), Value(10), Value(1), Value(10)},
+ {Value(std::vector<Value>{Value(1), Value(1), Value(10)})}},
+
+ // Null/missing cases (missing and null both get ignored).
+ {{Value(100), Value(BSONNULL), Value(), Value(4), Value(3)},
+ {Value(std::vector<Value>{Value(3), Value(4), Value(100)})}},
+ {{Value(100), Value(), Value(BSONNULL), Value(), Value(3)},
+ {Value(std::vector<Value>{Value(3), Value(100)})}},
+ },
+ false /*skipMerging*/,
+ n);
+}
+
+TEST(Accumulators, MinNRespectsCollation) {
+ RAIIServerParameterControllerForTest controller("featureFlagExactTopNAccumulator", true);
+ auto expCtx = ExpressionContextForTest{};
+ auto collator =
+ std::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kReverseString);
+ expCtx.setCollator(std::move(collator));
+ const auto n = Value(2);
+ assertExpectedResults<AccumulatorMinN>(
+ &expCtx,
+ {{{Value("abc"_sd), Value("cba"_sd), Value("cca"_sd)},
+ Value(std::vector<Value>{Value("cba"_sd), Value("cca"_sd)})}},
+ false /* skipMerging */,
+ n);
+}
+
+TEST(Accumulators, MaxN) {
+ RAIIServerParameterControllerForTest controller("featureFlagExactTopNAccumulator", true);
+ auto expCtx = ExpressionContextForTest{};
+ const auto n = Value(3);
+ assertExpectedResults<AccumulatorMaxN>(
+ &expCtx,
+ {
+ // Basic tests.
+ {{Value(3), Value(4), Value(5), Value(100)},
+ {Value(std::vector<Value>{Value(100), Value(5), Value(4)})}},
+ {{Value(10), Value(8), Value(9), Value(7), Value(1)},
+ {Value(std::vector<Value>{Value(10), Value(9), Value(8)})}},
+ {{Value(11.32), Value(91.0), Value(2), Value(701), Value(101)},
+ {Value(std::vector<Value>{Value(701), Value(101), Value(91.0)})}},
+
+ // 3 or fewer values results in those values being returned.
+ {{Value(10), Value(8), Value(9)},
+ {Value(std::vector<Value>{Value(10), Value(9), Value(8)})}},
+ {{Value(10)}, {Value(std::vector<Value>{Value(10)})}},
+
+ // Ties are broken arbitrarily.
+ {{Value(1), Value(1), Value(1), Value(10), Value(1), Value(10)},
+ {Value(std::vector<Value>{Value(10), Value(10), Value(1)})}},
+
+ // Null/missing cases (missing and null both get ignored).
+ {{Value(100), Value(BSONNULL), Value(), Value(4), Value(3)},
+ {Value(std::vector<Value>{Value(100), Value(4), Value(3)})}},
+ {{Value(100), Value(), Value(BSONNULL), Value(), Value(3)},
+ {Value(std::vector<Value>{Value(100), Value(3)})}},
+ },
+ false /*skipMerging*/,
+ n);
+}
+
+TEST(Accumulators, MaxNRespectsCollation) {
+ RAIIServerParameterControllerForTest controller("featureFlagExactTopNAccumulator", true);
+ auto expCtx = ExpressionContextForTest{};
+ auto collator =
+ std::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kReverseString);
+ expCtx.setCollator(std::move(collator));
+ const auto n = Value(2);
+ assertExpectedResults<AccumulatorMaxN>(
+ &expCtx,
+ {{{Value("abc"_sd), Value("cba"_sd), Value("cca"_sd)},
+ Value(std::vector<Value>{Value("abc"_sd), Value("cca"_sd)})}},
+ false /* skipMerging */,
+ n);
+}
+
TEST(Accumulators, Max) {
auto expCtx = ExpressionContextForTest{};
assertExpectedResults<AccumulatorMax>(
diff --git a/src/mongo/db/query/query_knobs.idl b/src/mongo/db/query/query_knobs.idl
index 24e265a7877..5474ead61e4 100644
--- a/src/mongo/db/query/query_knobs.idl
+++ b/src/mongo/db/query/query_knobs.idl
@@ -479,3 +479,13 @@ server_parameters:
cpp_varname: "internalQueryAppendIdToSetWindowFieldsSort"
cpp_vartype: AtomicWord<bool>
default: false
+
+ internalQueryMaxNAccumulatorBytes:
+ description: "Limits the vector of values pushed into a single array while grouping with the 'N' family of accumulators."
+ set_at: [ startup, runtime ]
+ cpp_varname: "internalQueryMaxNAccumulatorBytes"
+ cpp_vartype: AtomicWord<int>
+ default:
+ expr: 100 * 1024 * 1024
+ validator:
+ gt: 0 \ No newline at end of file