summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Mihai Andrei <mihai.andrei@10gen.com> 2021-08-23 19:40:31 +0000
committer Evergreen Agent <no-reply@evergreen.mongodb.com> 2021-08-23 20:17:21 +0000
commit ddbc84d44d03216cc3268f4354acda389a2c8c32 (patch)
tree f1181d0ff079c04d6148d85912dc21cdf561e08d
parent 0f3cd5391f1b67f0f2ed23c77a8bd217b407d8a2 (diff)
download mongo-ddbc84d44d03216cc3268f4354acda389a2c8c32.tar.gz
SERVER-57885 Implement $minN and $maxN as window functions
-rw-r--r-- jstests/aggregation/extras/window_function_helpers.js 44
-rw-r--r-- jstests/aggregation/sources/setWindowFields/avg.js 4
-rw-r--r-- jstests/aggregation/sources/setWindowFields/n_accumulators.js 79
-rw-r--r-- src/mongo/db/pipeline/SConscript 1
-rw-r--r-- src/mongo/db/pipeline/accumulator_multi.cpp 10
-rw-r--r-- src/mongo/db/pipeline/accumulator_multi.h 9
-rw-r--r-- src/mongo/db/pipeline/window_function/window_function_expression.cpp 97
-rw-r--r-- src/mongo/db/pipeline/window_function/window_function_expression.h 23
-rw-r--r-- src/mongo/db/pipeline/window_function/window_function_min_max.h 124
-rw-r--r-- src/mongo/db/pipeline/window_function/window_function_n_test.cpp 224
10 files changed, 567 insertions, 48 deletions
diff --git a/jstests/aggregation/extras/window_function_helpers.js b/jstests/aggregation/extras/window_function_helpers.js
index ed44e800691..49e2af29c4a 100644
--- a/jstests/aggregation/extras/window_function_helpers.js
+++ b/jstests/aggregation/extras/window_function_helpers.js
@@ -65,14 +65,23 @@ function forEachDocumentBoundsCombo(callback) {
*
* The skip/limit values are calculated from the given bounds and the current index.
*
+ * 'accumSpec' is the specification of accumulator being tested and is an object of the form
+ * {accumulatorName: <accumulator arguments>}.
+ *
* 'defaultValue' is used in cases when the skip/limit combination result in $group not getting any
* documents. The most likely scenario is that the window has gone off the side of the partition.
*
* Note that this function assumes that the data in 'coll' has been seeded with the documents from
* the seedWithTickerData() method above.
*/
-function computeAsGroup(
- {coll, partitionKey, accum, bounds, indexInPartition, defaultValue = null}) {
+function computeAsGroup({
+ coll,
+ partitionKey,
+ accumSpec,
+ bounds,
+ indexInPartition,
+ defaultValue = null,
+}) {
const skip = calculateSkip(bounds[0], indexInPartition);
const limit = calculateLimit(bounds[0], bounds[1], indexInPartition);
if (skip < 0 || limit <= 0)
@@ -85,8 +94,7 @@ function computeAsGroup(
prefixPipe = prefixPipe.concat([{$limit: limit}]);
const result =
- coll.aggregate(prefixPipe.concat([{$group: {_id: null, res: {[accum]: "$price"}}}]))
- .toArray();
+ coll.aggregate(prefixPipe.concat([{$group: {_id: null, res: accumSpec}}])).toArray();
// If the window is completely off the edge of the right side of the partition, return null.
if (result.length == 0) {
return defaultValue;
@@ -201,19 +209,19 @@ function assertExplainResult(explainResult) {
* Note that this function assumes that the documents in 'coll' were initialized using the
* seedWithTickerData() method above.
*/
-function testAccumAgainstGroup(coll, accum, onNoResults = null) {
+function testAccumAgainstGroup(coll, accum, onNoResults = null, accumArgs = "$price") {
+ const accumSpec = {[accum]: accumArgs};
forEachPartitionCase(function(partition) {
documentBounds.forEach(function(bounds, index) {
- jsTestLog("Testing accumulator " + accum + " against " + tojson(partition) +
+ jsTestLog("Testing accumulator " + tojson(accumSpec) + " against " + tojson(partition) +
" partition and [" + bounds + "] bounds");
+ let outputSpec = {window: {documents: bounds}};
+ Object.assign(outputSpec, accumSpec);
const pipeline = [
{
- $setWindowFields: {
- partitionBy: partition,
- sortBy: {_id: 1},
- output: {res: {[accum]: "$price", window: {documents: bounds}}}
- },
+ $setWindowFields:
+ {partitionBy: partition, sortBy: {_id: 1}, output: {res: outputSpec}},
},
];
const wfResults = coll.aggregate(pipeline, {allowDiskUse: true}).toArray();
@@ -226,19 +234,19 @@ function testAccumAgainstGroup(coll, accum, onNoResults = null) {
groupRes = computeAsGroup({
coll: coll,
partitionKey: {},
- accum: accum,
+ accumSpec: accumSpec,
bounds: bounds,
indexInPartition: indexInPartition,
- defaultValue: onNoResults
+ defaultValue: onNoResults,
});
} else {
groupRes = computeAsGroup({
coll: coll,
partitionKey: {ticker: wfRes.ticker},
- accum: accum,
+ accumSpec: accumSpec,
bounds: bounds,
indexInPartition: indexInPartition,
- defaultValue: onNoResults
+ defaultValue: onNoResults,
});
}
@@ -257,7 +265,7 @@ function testAccumAgainstGroup(coll, accum, onNoResults = null) {
// combinations of various window types in the same $setWindowFields stage. This is more of
// a fuzz test so no need to check results.
forEachDocumentBoundsCombo(function(arrayOfBounds) {
- jsTestLog("Testing accumulator " + accum +
+ jsTestLog("Testing accumulator " + tojson(accumSpec) +
" against multiple bounds: " + tojson(arrayOfBounds));
let baseSpec = {
partitionBy: partition,
@@ -265,7 +273,9 @@ function testAccumAgainstGroup(coll, accum, onNoResults = null) {
};
let outputFields = {};
arrayOfBounds.forEach(function(bounds, index) {
- outputFields["res" + index] = {[accum]: "$price", window: {documents: bounds}};
+ let outputSpec = {window: {documents: bounds}};
+ Object.assign(outputSpec, accumSpec);
+ outputFields["res" + index] = outputSpec;
});
let specWithOutput = Object.merge(baseSpec, {output: outputFields});
const wfResults =
diff --git a/jstests/aggregation/sources/setWindowFields/avg.js b/jstests/aggregation/sources/setWindowFields/avg.js
index bc2d0a7a6c5..da5ee8925ce 100644
--- a/jstests/aggregation/sources/setWindowFields/avg.js
+++ b/jstests/aggregation/sources/setWindowFields/avg.js
@@ -36,7 +36,7 @@ for (let index = 0; index < results.length; index++) {
let groupRes = computeAsGroup({
coll: coll,
partitionKey: {ticker: results[index].ticker},
- accum: "$avg",
+ accumSpec: {"$avg": "$price"},
bounds: ["unbounded", 0],
indexInPartition: results[index].partIndex,
defaultValue: null
@@ -47,7 +47,7 @@ for (let index = 0; index < results.length; index++) {
groupRes = computeAsGroup({
coll: coll,
partitionKey: {ticker: results[index].ticker},
- accum: "$avg",
+ accumSpec: {"$avg": "$price"},
bounds: ["unbounded", 3],
indexInPartition: results[index].partIndex,
defaultValue: null
diff --git a/jstests/aggregation/sources/setWindowFields/n_accumulators.js b/jstests/aggregation/sources/setWindowFields/n_accumulators.js
new file mode 100644
index 00000000000..81fe0f3c988
--- /dev/null
+++ b/jstests/aggregation/sources/setWindowFields/n_accumulators.js
@@ -0,0 +1,79 @@
+/**
+ * Test that the 'n' family of accumulators work as window functions.
+ */
+(function() {
+"use strict";
+
+load("jstests/aggregation/extras/window_function_helpers.js");
+
+const coll = db[jsTestName()];
+coll.drop();
+
+const isExactTopNEnabled = db.adminCommand({getParameter: 1, featureFlagExactTopNAccumulator: 1})
+ .featureFlagExactTopNAccumulator.value;
+
+if (!isExactTopNEnabled) {
+ // Verify that $minN/$maxN cannot be used if the feature flag is set to false and ignore the
+ // rest of the test.
+ assert.commandFailedWithCode(coll.runCommand("aggregate", {
+ pipeline: [{
+ $setWindowFields: {
+ sortBy: {ts: 1},
+ output: {outputField: {$minN: {n: 3, output: "$foo"}}},
+ }
+ }],
+ cursor: {}
+ }),
+ 5788502);
+ return;
+}
+
+// Create a collection of tickers and prices.
+const nDocsPerTicker = 10;
+seedWithTickerData(coll, nDocsPerTicker);
+
+// TODO SERVER-57884: Add test cases for $firstN/$lastN window functions.
+// TODO SERVER-57886: Add test cases for $top/$bottom/$topN/$bottomN window functions.
+for (const acc of ["$minN", "$maxN"]) {
+ for (const nValue of [4, 7, 12]) {
+ jsTestLog("Testing accumulator " + tojson(acc) + " with 'n' set to " + tojson(nValue));
+ testAccumAgainstGroup(coll, acc, [], {output: "$price", n: nValue});
+ }
+
+ // Verify that the accumulator will not throw if the 'n' expression evaluates to a constant.
+ const pipeline = [
+ {
+ $setWindowFields: {
+ partitionBy: "$ticker",
+ sortBy: {_id: 1},
+ output: {res: {[acc]: {n: {$add: [1, 2]}, output: "$price"}}}
+ },
+ },
+ ];
+
+ assert.doesNotThrow(() => coll.aggregate(pipeline).toArray());
+
+ // Error cases.
+ function testError(accSpec, expectedCode) {
+ assert.throwsWithCode(() => coll.aggregate([{
+ $setWindowFields: {
+ sortBy: {ts: 1},
+ output: {outputField: accSpec},
+ }
+ }]),
+ expectedCode);
+ }
+ // Invalid/missing accumulator specification.
+ testError({[acc]: "non object"}, 5787900);
+ testError({window: {documents: [-1, 1]}}, ErrorCodes.FailedToParse);
+ testError({[acc]: {n: 2}, window: {documents: [-1, 1]}}, 5787907);
+ testError({[acc]: {output: "$foo"}, window: {documents: [-1, 1]}}, 5787906);
+ testError({[acc]: {output: "$foo", n: 2.1}, window: {documents: [-1, 1]}}, 5787903);
+
+ // Invalid window specification.
+ testError({[acc]: {output: "$foo", n: 2.0}, window: [-1, 1]}, ErrorCodes.FailedToParse);
+
+ // Non constant argument for 'n'.
+ testError({[acc]: {output: "$foo", n: "$a"}, window: {documents: [-1, 1]}}, 5787902);
+}
+})();
\ No newline at end of file
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index 463b05d0f35..d8ddcf85fc9 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -514,6 +514,7 @@ env.CppUnitTest(
'window_function/window_function_exec_removable_test.cpp',
'window_function/window_function_integral_test.cpp',
'window_function/window_function_min_max_test.cpp',
+ 'window_function/window_function_n_test.cpp',
'window_function/window_function_push_test.cpp',
'window_function/window_function_std_dev_test.cpp',
'window_function/window_function_exec_first_last_test.cpp',
diff --git a/src/mongo/db/pipeline/accumulator_multi.cpp b/src/mongo/db/pipeline/accumulator_multi.cpp
index 714d7ffaeb3..ecef3202f8c 100644
--- a/src/mongo/db/pipeline/accumulator_multi.cpp
+++ b/src/mongo/db/pipeline/accumulator_multi.cpp
@@ -60,25 +60,27 @@ REGISTER_ACCUMULATOR_WITH_MIN_VERSION(
AccumulatorFirstLastN::parseFirstLastN<FirstLastSense::kLast>,
ServerGlobalParams::FeatureCompatibility::Version::kVersion51);
// TODO SERVER-57881 Add $firstN/$lastN as expressions.
-// TODO SERVER-57885 Add $minN/$maxN as window functions.
// TODO SERVER-57884 Add $firstN/$lastN as window functions.
AccumulatorN::AccumulatorN(ExpressionContext* const expCtx)
: AccumulatorState(expCtx), _maxMemUsageBytes(internalQueryMaxNAccumulatorBytes.load()) {}
-void AccumulatorN::startNewGroup(const Value& input) {
+long long AccumulatorN::validateN(const Value& input) {
// Obtain the value for 'n' and error if it's not a positive integral.
uassert(5787902,
str::stream() << "Value for 'n' must be of integral type, but found "
<< input.toString(),
- isNumericBSONType(input.getType()));
+ input.numeric());
auto n = input.coerceToLong();
uassert(5787903,
str::stream() << "Value for 'n' must be of integral type, but found "
<< input.toString(),
n == input.coerceToDouble());
uassert(5787908, str::stream() << "'n' must be greater than 0, found " << n, n > 0);
- _n = n;
+ return n;
+}
+void AccumulatorN::startNewGroup(const Value& input) {
+ _n = validateN(input);
}
void AccumulatorN::processInternal(const Value& input, bool merging) {
diff --git a/src/mongo/db/pipeline/accumulator_multi.h b/src/mongo/db/pipeline/accumulator_multi.h
index e14d3cabc60..f15063fdcb5 100644
--- a/src/mongo/db/pipeline/accumulator_multi.h
+++ b/src/mongo/db/pipeline/accumulator_multi.h
@@ -46,9 +46,16 @@ class AccumulatorN : public AccumulatorState {
public:
AccumulatorN(ExpressionContext* expCtx);
+ /**
+ * Verifies that 'input' is a positive integer.
+ */
+ static long long validateN(const Value& input);
+
void processInternal(const Value& input, bool merging) final;
- // Initialize 'n' with 'input'. In particular, verifies that 'input' is a positive integer.
+ /**
+ * Initialize 'n' with 'input'.
+ */
void startNewGroup(const Value& input) final;
/**
diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.cpp b/src/mongo/db/pipeline/window_function/window_function_expression.cpp
index 9146b5d5b79..0a37571d02e 100644
--- a/src/mongo/db/pipeline/window_function/window_function_expression.cpp
+++ b/src/mongo/db/pipeline/window_function/window_function_expression.cpp
@@ -42,15 +42,19 @@
#include "mongo/db/pipeline/window_function/window_function_exec_derivative.h"
#include "mongo/db/pipeline/window_function/window_function_exec_first_last.h"
#include "mongo/db/pipeline/window_function/window_function_expression.h"
+#include "mongo/db/pipeline/window_function/window_function_min_max.h"
using boost::intrusive_ptr;
using boost::optional;
namespace mongo::window_function {
using namespace std::string_literals;
+using MinMaxSense = AccumulatorMinMax::Sense;
REGISTER_WINDOW_FUNCTION(derivative, ExpressionDerivative::parse);
REGISTER_WINDOW_FUNCTION(first, ExpressionFirst::parse);
REGISTER_WINDOW_FUNCTION(last, ExpressionLast::parse);
+REGISTER_WINDOW_FUNCTION(minN, ExpressionMinMaxN<MinMaxSense::kMin>::parse);
+REGISTER_WINDOW_FUNCTION(maxN, ExpressionMinMaxN<MinMaxSense::kMax>::parse);
StringMap<Expression::Parser> Expression::parserMap;
@@ -215,6 +219,99 @@ boost::intrusive_ptr<Expression> ExpressionFirstLast::parse(
}
}
+template <AccumulatorMinMax::Sense S>
+boost::intrusive_ptr<Expression> ExpressionMinMaxN<S>::parse(
+ BSONObj obj, const boost::optional<SortPattern>& sortBy, ExpressionContext* expCtx) {
+ auto name = [] {
+ if constexpr (S == MinMaxSense::kMin) {
+ return AccumulatorMinN::getName();
+ } else {
+ return AccumulatorMaxN::getName();
+ }
+ }();
+ uassert(5788502,
+ str::stream() << "Cannot create " << name
+ << " accumulator in $setWindowFields"
+ " if feature flag is disabled",
+ feature_flags::gFeatureFlagExactTopNAccumulator.isEnabledAndIgnoreFCV());
+
+ boost::intrusive_ptr<::mongo::Expression> nExpr;
+ boost::intrusive_ptr<::mongo::Expression> outputExpr;
+ boost::optional<WindowBounds> bounds;
+ for (auto&& elem : obj) {
+ auto fieldName = elem.fieldNameStringData();
+ if (fieldName == name) {
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << "saw multiple specifications for '" << name << "' expression",
+ !(nExpr || outputExpr));
+ auto accExpr =
+ AccumulatorMinMaxN::parseMinMaxN<S>(expCtx, elem, expCtx->variablesParseState);
+ nExpr = accExpr.initializer;
+ outputExpr = accExpr.argument;
+ } else if (fieldName == kWindowArg) {
+ uassert(ErrorCodes::FailedToParse,
+ "'window' field must be an object",
+ obj[kWindowArg].type() == BSONType::Object);
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << "saw multiple 'window' fields in '" << name << "' expression",
+ bounds == boost::none);
+ bounds = WindowBounds::parse(elem.embeddedObject(), sortBy, expCtx);
+ } else {
+ uasserted(ErrorCodes::FailedToParse,
+ str::stream() << name << " got unexpected argument: " << fieldName);
+ }
+ }
+
+ // The default window bounds are [unbounded, unbounded].
+ if (!bounds) {
+ bounds = WindowBounds::defaultBounds();
+ }
+ tassert(5788500,
+ str::stream() << "missing accumulator specification for " << name,
+ nExpr && outputExpr);
+ return make_intrusive<ExpressionMinMaxN<S>>(
+ expCtx, std::move(outputExpr), name, *bounds, std::move(nExpr));
+}
+
+template <AccumulatorMinMax::Sense S>
+boost::intrusive_ptr<AccumulatorState> ExpressionMinMaxN<S>::buildAccumulatorOnly() const {
+ boost::intrusive_ptr<AccumulatorState> acc;
+ if constexpr (S == AccumulatorMinMax::Sense::kMin) {
+ acc = AccumulatorMinN::create(_expCtx);
+ } else {
+ acc = AccumulatorMaxN::create(_expCtx);
+ }
+
+ // Initialize 'n' for our accumulator. Note that 'n' must be a constant.
+ auto nVal = _nExpr->evaluate({}, &_expCtx->variables);
+ uassert(5788501,
+ str::stream() << "Expression for 'n' " << _nExpr->serialize(false).toString()
+ << " must evaluate to a numeric constant when used in $setWindowFields",
+ nVal.numeric());
+ acc->startNewGroup(nVal);
+ return acc;
+}
+
+template <AccumulatorMinMax::Sense S>
+std::unique_ptr<WindowFunctionState> ExpressionMinMaxN<S>::buildRemovable() const {
+ return WindowFunctionMinMaxN<S>::create(
+ _expCtx, AccumulatorN::validateN(_nExpr->evaluate({}, &_expCtx->variables)));
+}
+
+template <AccumulatorMinMax::Sense S>
+Value ExpressionMinMaxN<S>::serialize(boost::optional<ExplainOptions::Verbosity> explain) const {
+ MutableDocument result;
+
+ MutableDocument exprSpec;
+ AccumulatorN::serializeHelper(_nExpr, _input, static_cast<bool>(explain), exprSpec);
+ result[_accumulatorName] = exprSpec.freezeToValue();
+
+ MutableDocument windowField;
+ _bounds.serialize(windowField);
+ result[kWindowArg] = windowField.freezeToValue();
+ return result.freezeToValue();
+}
+
MONGO_INITIALIZER_GROUP(BeginWindowFunctionRegistration,
("default"),
("EndWindowFunctionRegistration"))
diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.h b/src/mongo/db/pipeline/window_function/window_function_expression.h
index e1b2d17a338..4cbbd88914e 100644
--- a/src/mongo/db/pipeline/window_function/window_function_expression.h
+++ b/src/mongo/db/pipeline/window_function/window_function_expression.h
@@ -720,4 +720,27 @@ public:
}
};
+template <AccumulatorMinMax::Sense S>
+class ExpressionMinMaxN : public Expression {
+public:
+ ExpressionMinMaxN(ExpressionContext* expCtx,
+ boost::intrusive_ptr<::mongo::Expression> input,
+ std::string name,
+ WindowBounds bounds,
+ boost::intrusive_ptr<::mongo::Expression> nExpr)
+ : Expression(expCtx, std::move(name), std::move(input), std::move(bounds)),
+ _nExpr(std::move(nExpr)) {}
+ static boost::intrusive_ptr<Expression> parse(BSONObj obj,
+ const boost::optional<SortPattern>& sortBy,
+ ExpressionContext* expCtx);
+
+ boost::intrusive_ptr<AccumulatorState> buildAccumulatorOnly() const final;
+
+ std::unique_ptr<WindowFunctionState> buildRemovable() const final;
+
+ Value serialize(boost::optional<ExplainOptions::Verbosity> explain) const final;
+
+private:
+ boost::intrusive_ptr<::mongo::Expression> _nExpr;
+};
} // namespace mongo::window_function
diff --git a/src/mongo/db/pipeline/window_function/window_function_min_max.h b/src/mongo/db/pipeline/window_function/window_function_min_max.h
index 0ff212484c0..8c5d2822696 100644
--- a/src/mongo/db/pipeline/window_function/window_function_min_max.h
+++ b/src/mongo/db/pipeline/window_function/window_function_min_max.h
@@ -35,30 +35,18 @@
namespace mongo {
template <AccumulatorMinMax::Sense sense>
-class WindowFunctionMinMax : public WindowFunctionState {
+class WindowFunctionMinMaxCommon : public WindowFunctionState {
public:
- static inline const Value kDefault = Value{BSONNULL};
-
- static std::unique_ptr<WindowFunctionState> create(ExpressionContext* const expCtx) {
- return std::make_unique<WindowFunctionMinMax<sense>>(expCtx);
- }
-
- explicit WindowFunctionMinMax(ExpressionContext* const expCtx)
- : WindowFunctionState(expCtx),
- _values(_expCtx->getValueComparator().makeOrderedValueMultiset()) {
- _memUsageBytes = sizeof(*this);
- }
-
- void add(Value value) final {
+ void add(Value value) override {
_memUsageBytes += value.getApproximateSize();
_values.insert(std::move(value));
}
- void remove(Value value) final {
+ void remove(Value value) override {
// std::multiset::insert is guaranteed to put the element after any equal elements
// already in the container. So find() / erase() will remove the oldest equal element,
// which is what we want, to satisfy "remove() undoes add() when called in FIFO order".
- auto iter = _values.find(std::move(value));
+ auto iter = _values.find(value);
tassert(5371400, "Can't remove from an empty WindowFunctionMinMax", iter != _values.end());
_memUsageBytes -= iter->getApproximateSize();
_values.erase(iter);
@@ -69,23 +57,111 @@ public:
_memUsageBytes = sizeof(*this);
}
+protected:
+ // Constructor hidden so that only instances of the derived types can be created.
+ explicit WindowFunctionMinMaxCommon(ExpressionContext* const expCtx)
+ : WindowFunctionState(expCtx),
+ _values(_expCtx->getValueComparator().makeOrderedValueMultiset()) {}
+
+ // Holds all the values in the window, in order, with constant-time access to both ends.
+ ValueMultiset _values;
+};
+
+template <AccumulatorMinMax::Sense sense>
+class WindowFunctionMinMax : public WindowFunctionMinMaxCommon<sense> {
+public:
+ using WindowFunctionMinMaxCommon<sense>::_values;
+ using WindowFunctionMinMaxCommon<sense>::_memUsageBytes;
+
+ static inline const Value kDefault = Value{BSONNULL};
+
+ static std::unique_ptr<WindowFunctionState> create(ExpressionContext* const expCtx) {
+ return std::make_unique<WindowFunctionMinMax<sense>>(expCtx);
+ }
+
+ explicit WindowFunctionMinMax(ExpressionContext* const expCtx)
+ : WindowFunctionMinMaxCommon<sense>(expCtx) {
+ _memUsageBytes = sizeof(*this);
+ }
+
Value getValue() const final {
if (_values.empty())
return kDefault;
- switch (sense) {
- case AccumulatorMinMax::Sense::kMin:
- return *_values.begin();
- case AccumulatorMinMax::Sense::kMax:
- return *_values.rbegin();
+ if constexpr (sense == AccumulatorMinMax::Sense::kMin) {
+ return *_values.begin();
+ } else {
+ return *_values.rbegin();
}
MONGO_UNREACHABLE_TASSERT(5371401);
}
+};
-protected:
- // Holds all the values in the window, in order, with constant-time access to both ends.
- ValueMultiset _values;
+template <AccumulatorMinMax::Sense sense>
+class WindowFunctionMinMaxN : public WindowFunctionMinMaxCommon<sense> {
+public:
+ using WindowFunctionMinMaxCommon<sense>::_values;
+ using WindowFunctionMinMaxCommon<sense>::_memUsageBytes;
+
+ static std::unique_ptr<WindowFunctionState> create(ExpressionContext* const expCtx,
+ long long n) {
+ return std::make_unique<WindowFunctionMinMaxN<sense>>(expCtx, n);
+ }
+ explicit WindowFunctionMinMaxN(ExpressionContext* const expCtx, long long n)
+ : WindowFunctionMinMaxCommon<sense>(expCtx), _n(n) {
+ _memUsageBytes = sizeof(*this);
+ }
+
+ void add(Value value) final {
+ // Ignore nullish values.
+ if (value.nullish())
+ return;
+ WindowFunctionMinMaxCommon<sense>::add(std::move(value));
+ }
+
+ void remove(Value value) final {
+ // Ignore nullish values.
+ if (value.nullish())
+ return;
+ WindowFunctionMinMaxCommon<sense>::remove(std::move(value));
+ }
+
+ Value getValue() const final {
+ if (_values.empty()) {
+ return Value(std::vector<Value>());
+ }
+
+ auto processVal = [&](auto begin, auto end, size_t size) -> Value {
+ auto n = static_cast<size_t>(_n);
+
+ // If 'n' is greater than the size of the current window, then return all the values.
+ if (n >= size) {
+ return Value(std::vector(begin, end));
+ } else {
+ std::vector<Value> result;
+ result.reserve(n);
+ auto it = begin;
+ for (size_t i = 0; i < n; ++i, ++it) {
+ result.push_back(*it);
+ }
+ return Value(std::move(result));
+ }
+ };
+
+ auto size = _values.size();
+ if constexpr (sense == AccumulatorMinMax::Sense::kMin) {
+ return processVal(_values.begin(), _values.end(), size);
+ } else {
+ return processVal(_values.rbegin(), _values.rend(), size);
+ }
+ }
+
+
+private:
+ long long _n;
};
using WindowFunctionMin = WindowFunctionMinMax<AccumulatorMinMax::Sense::kMin>;
using WindowFunctionMax = WindowFunctionMinMax<AccumulatorMinMax::Sense::kMax>;
+using WindowFunctionMinN = WindowFunctionMinMaxN<AccumulatorMinMax::Sense::kMin>;
+using WindowFunctionMaxN = WindowFunctionMinMaxN<AccumulatorMinMax::Sense::kMax>;
} // namespace mongo
diff --git a/src/mongo/db/pipeline/window_function/window_function_n_test.cpp b/src/mongo/db/pipeline/window_function/window_function_n_test.cpp
new file mode 100644
index 00000000000..246fb42d77d
--- /dev/null
+++ b/src/mongo/db/pipeline/window_function/window_function_n_test.cpp
@@ -0,0 +1,224 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/db/exec/document_value/document_value_test_util.h"
+#include "mongo/db/pipeline/aggregation_context_fixture.h"
+#include "mongo/db/pipeline/window_function/window_function_min_max.h"
+#include "mongo/db/query/collation/collator_interface_mock.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+namespace {
+// TODO SERVER-57884: Add test cases for $firstN/$lastN window functions.
+// TODO SERVER-57886: Add test cases for $top/$bottom/$topN/$bottomN window functions.
+class WindowFunctionMinMaxNTest : public AggregationContextFixture {
+public:
+ static constexpr auto kNarg = 3LL;
+ WindowFunctionMinMaxNTest()
+ : expCtx(getExpCtx()), minThree(expCtx.get(), kNarg), maxThree(expCtx.get(), kNarg) {
+ auto collator = std::make_unique<CollatorInterfaceMock>(
+ CollatorInterfaceMock::MockType::kToLowerString);
+ expCtx->setCollator(std::move(collator));
+ }
+
+ boost::intrusive_ptr<ExpressionContext> expCtx;
+ WindowFunctionMinN minThree;
+ WindowFunctionMaxN maxThree;
+};
+
+TEST_F(WindowFunctionMinMaxNTest, EmptyWindow) {
+ auto test = [](auto windowFunction) {
+ ASSERT_VALUE_EQ(windowFunction.getValue(), Value{BSONArray()});
+
+ // No matter how many nullish values we insert, we should still get back the empty array.
+ windowFunction.add(Value());
+ windowFunction.add(Value(BSONNULL));
+ windowFunction.add(Value());
+ windowFunction.add(Value(BSONNULL));
+ ASSERT_VALUE_EQ(windowFunction.getValue(), Value{BSONArray()});
+
+ // Add a value and show that removing nullish has no effect.
+ windowFunction.add(Value{3});
+ ASSERT_VALUE_EQ(windowFunction.getValue(), Value{std::vector{Value(3)}});
+
+ windowFunction.remove(Value());
+ windowFunction.remove(Value(BSONNULL));
+ windowFunction.remove(Value());
+ windowFunction.remove(Value(BSONNULL));
+ ASSERT_VALUE_EQ(windowFunction.getValue(), Value{std::vector{Value(3)}});
+ };
+ test(minThree);
+ test(maxThree);
+}
+
+TEST_F(WindowFunctionMinMaxNTest, WindowSmallerThanN) {
+ minThree.add(Value{5});
+ minThree.add(Value{7});
+
+ ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(5), Value(7)}));
+
+ maxThree.add(Value{5});
+ maxThree.add(Value{7});
+ ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(7), Value(5)}));
+}
+
+TEST_F(WindowFunctionMinMaxNTest, WindowContainsDuplicates) {
+ minThree.add(Value{5});
+ minThree.add(Value{7});
+ minThree.add(Value{7});
+ minThree.add(Value{7});
+ minThree.add(Value{7});
+
+ ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(5), Value(7), Value(7)}));
+
+ maxThree.add(Value{5});
+ maxThree.add(Value{5});
+ maxThree.add(Value{5});
+ maxThree.add(Value{5});
+ maxThree.add(Value{5});
+ maxThree.add(Value{7});
+ ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(7), Value(5), Value(5)}));
+}
+
+TEST_F(WindowFunctionMinMaxNTest, BasicCorrectnessTest) {
+ minThree.add(Value{5});
+ minThree.add(Value{10});
+ minThree.add(Value{6});
+ minThree.add(Value{12});
+ minThree.add(Value{7});
+ minThree.add(Value{3});
+
+ ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(3), Value(5), Value(6)}));
+
+ minThree.remove(Value{5});
+ minThree.remove(Value{10});
+
+ ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(3), Value(6), Value(7)}));
+ minThree.remove(Value{6});
+ ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(3), Value(7), Value(12)}));
+
+ minThree.remove(Value{12});
+ minThree.remove(Value{7});
+ minThree.remove(Value{3});
+ ASSERT_VALUE_EQ(minThree.getValue(), Value{BSONArray()});
+
+ maxThree.add(Value{5});
+ maxThree.add(Value{9});
+ maxThree.add(Value{12});
+ maxThree.add(Value{11});
+ maxThree.add(Value{3});
+ maxThree.add(Value{7});
+ ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(12), Value(11), Value(9)}));
+
+ maxThree.remove(Value{5});
+ maxThree.remove(Value{9});
+ maxThree.remove(Value{12});
+ ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(11), Value(7), Value(3)}));
+
+ maxThree.remove(Value{11});
+ maxThree.remove(Value{3});
+ maxThree.remove(Value{7});
+ ASSERT_VALUE_EQ(maxThree.getValue(), Value{BSONArray()});
+}
+
+TEST_F(WindowFunctionMinMaxNTest, MixNullsAndNonNulls) {
+ // Add four values, half of which are null/missing. We should only return the two non-nulls.
+ minThree.add(Value{4});
+ minThree.add(Value());
+ minThree.add(Value(BSONNULL));
+ minThree.add(Value{1});
+ ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(1), Value(4)}));
+
+ // Add a couple more values. We should still get no nulls/missing.
+ minThree.add(Value{3});
+ minThree.add(Value());
+ minThree.add(Value(BSONNULL));
+ minThree.add(Value{2});
+ ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(1), Value(2), Value(3)}));
+
+ // Add four values, half of which are null/missing. We should only return the two non-nulls.
+ maxThree.add(Value{4});
+ maxThree.add(Value());
+ maxThree.add(Value(BSONNULL));
+ maxThree.add(Value{1});
+ ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(4), Value(1)}));
+
+ // Add a couple more values. We should still get no nulls/missing.
+ maxThree.add(Value{3});
+ maxThree.add(Value());
+ maxThree.add(Value(BSONNULL));
+ maxThree.add(Value{2});
+ ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(4), Value(3), Value(2)}));
+}
+
+TEST_F(WindowFunctionMinMaxNTest, Ties) {
+ // When two elements tie (compare equal), remove() can't pick an arbitrary one,
+ // because that would break the invariant that 'add(x); add(y); remove(x)' is equivalent to
+ // 'add(y)'.
+
+ auto x = Value{"foo"_sd};
+ auto y = Value{"FOO"_sd};
+ // x and y are distinguishable,
+ ASSERT_VALUE_NE(x, y);
+ // but they compare equal according to the ordering.
+ ASSERT(expCtx->getValueComparator().evaluate(x == y));
+
+ minThree.add(x);
+ minThree.add(y);
+ minThree.remove(x);
+ ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{y}));
+
+ minThree.add(x);
+ minThree.add(y);
+ minThree.remove(x);
+
+ // Here, we expect ["foo","FOO"] because we remove the first added entry that compares equal
+ // to 'x', which is the first instance of 'y'.
+ ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{x, y}));
+}
+
+TEST_F(WindowFunctionMinMaxNTest, TracksMemoryUsageOnAddAndRemove) {
+ size_t trackingSize = sizeof(WindowFunctionMinN);
+ ASSERT_EQ(minThree.getApproximateSize(), trackingSize);
+
+ auto largeStr = Value{"$minN/maxN are great window functions"_sd};
+ minThree.add(largeStr);
+ trackingSize += largeStr.getApproximateSize();
+ ASSERT_EQ(minThree.getApproximateSize(), trackingSize);
+
+ minThree.add(largeStr);
+ trackingSize += largeStr.getApproximateSize();
+ ASSERT_EQ(minThree.getApproximateSize(), trackingSize);
+
+ minThree.remove(largeStr);
+ trackingSize -= largeStr.getApproximateSize();
+ ASSERT_EQ(minThree.getApproximateSize(), trackingSize);
+}
+} // namespace
+} // namespace mongo