diff options
author | Mihai Andrei <mihai.andrei@10gen.com> | 2021-08-23 19:40:31 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-08-23 20:17:21 +0000 |
commit | ddbc84d44d03216cc3268f4354acda389a2c8c32 (patch) | |
tree | f1181d0ff079c04d6148d85912dc21cdf561e08d | |
parent | 0f3cd5391f1b67f0f2ed23c77a8bd217b407d8a2 (diff) | |
download | mongo-ddbc84d44d03216cc3268f4354acda389a2c8c32.tar.gz |
SERVER-57885 Implement $minN and $maxN as window functions
10 files changed, 567 insertions, 48 deletions
diff --git a/jstests/aggregation/extras/window_function_helpers.js b/jstests/aggregation/extras/window_function_helpers.js index ed44e800691..49e2af29c4a 100644 --- a/jstests/aggregation/extras/window_function_helpers.js +++ b/jstests/aggregation/extras/window_function_helpers.js @@ -65,14 +65,23 @@ function forEachDocumentBoundsCombo(callback) { * * The skip/limit values are calculated from the given bounds and the current index. * + * 'accumSpec' is the specification of accumulator being tested and is an object of the form + * {accumulatorName: <accumulator arguments>}. + * * 'defaultValue' is used in cases when the skip/limit combination result in $group not getting any * documents. The most likely scenario is that the window has gone off the side of the partition. * * Note that this function assumes that the data in 'coll' has been seeded with the documents from * the seedWithTickerData() method above. */ -function computeAsGroup( - {coll, partitionKey, accum, bounds, indexInPartition, defaultValue = null}) { +function computeAsGroup({ + coll, + partitionKey, + accumSpec, + bounds, + indexInPartition, + defaultValue = null, +}) { const skip = calculateSkip(bounds[0], indexInPartition); const limit = calculateLimit(bounds[0], bounds[1], indexInPartition); if (skip < 0 || limit <= 0) @@ -85,8 +94,7 @@ function computeAsGroup( prefixPipe = prefixPipe.concat([{$limit: limit}]); const result = - coll.aggregate(prefixPipe.concat([{$group: {_id: null, res: {[accum]: "$price"}}}])) - .toArray(); + coll.aggregate(prefixPipe.concat([{$group: {_id: null, res: accumSpec}}])).toArray(); // If the window is completely off the edge of the right side of the partition, return null. if (result.length == 0) { return defaultValue; @@ -201,19 +209,19 @@ function assertExplainResult(explainResult) { * Note that this function assumes that the documents in 'coll' were initialized using the * seedWithTickerData() method above. */ -function testAccumAgainstGroup(coll, accum, onNoResults = null) { +function testAccumAgainstGroup(coll, accum, onNoResults = null, accumArgs = "$price") { + const accumSpec = {[accum]: accumArgs}; forEachPartitionCase(function(partition) { documentBounds.forEach(function(bounds, index) { - jsTestLog("Testing accumulator " + accum + " against " + tojson(partition) + + jsTestLog("Testing accumulator " + tojson(accumSpec) + " against " + tojson(partition) + " partition and [" + bounds + "] bounds"); + let outputSpec = {window: {documents: bounds}}; + Object.assign(outputSpec, accumSpec); const pipeline = [ { - $setWindowFields: { - partitionBy: partition, - sortBy: {_id: 1}, - output: {res: {[accum]: "$price", window: {documents: bounds}}} - }, + $setWindowFields: + {partitionBy: partition, sortBy: {_id: 1}, output: {res: outputSpec}}, }, ]; const wfResults = coll.aggregate(pipeline, {allowDiskUse: true}).toArray(); @@ -226,19 +234,19 @@ function testAccumAgainstGroup(coll, accum, onNoResults = null) { groupRes = computeAsGroup({ coll: coll, partitionKey: {}, - accum: accum, + accumSpec: accumSpec, bounds: bounds, indexInPartition: indexInPartition, - defaultValue: onNoResults + defaultValue: onNoResults, }); } else { groupRes = computeAsGroup({ coll: coll, partitionKey: {ticker: wfRes.ticker}, - accum: accum, + accumSpec: accumSpec, bounds: bounds, indexInPartition: indexInPartition, - defaultValue: onNoResults + defaultValue: onNoResults, }); } @@ -257,7 +265,7 @@ function testAccumAgainstGroup(coll, accum, onNoResults = null) { // combinations of various window types in the same $setWindowFields stage. This is more of // a fuzz test so no need to check results. forEachDocumentBoundsCombo(function(arrayOfBounds) { - jsTestLog("Testing accumulator " + accum + + jsTestLog("Testing accumulator " + tojson(accumSpec) + " against multiple bounds: " + tojson(arrayOfBounds)); let baseSpec = { partitionBy: partition, @@ -265,7 +273,9 @@ function testAccumAgainstGroup(coll, accum, onNoResults = null) { }; let outputFields = {}; arrayOfBounds.forEach(function(bounds, index) { - outputFields["res" + index] = {[accum]: "$price", window: {documents: bounds}}; + let outputSpec = {window: {documents: bounds}}; + Object.assign(outputSpec, accumSpec); + outputFields["res" + index] = outputSpec; }); let specWithOutput = Object.merge(baseSpec, {output: outputFields}); const wfResults = diff --git a/jstests/aggregation/sources/setWindowFields/avg.js b/jstests/aggregation/sources/setWindowFields/avg.js index bc2d0a7a6c5..da5ee8925ce 100644 --- a/jstests/aggregation/sources/setWindowFields/avg.js +++ b/jstests/aggregation/sources/setWindowFields/avg.js @@ -36,7 +36,7 @@ for (let index = 0; index < results.length; index++) { let groupRes = computeAsGroup({ coll: coll, partitionKey: {ticker: results[index].ticker}, - accum: "$avg", + accumSpec: {"$avg": "$price"}, bounds: ["unbounded", 0], indexInPartition: results[index].partIndex, defaultValue: null @@ -47,7 +47,7 @@ for (let index = 0; index < results.length; index++) { groupRes = computeAsGroup({ coll: coll, partitionKey: {ticker: results[index].ticker}, - accum: "$avg", + accumSpec: {"$avg": "$price"}, bounds: ["unbounded", 3], indexInPartition: results[index].partIndex, defaultValue: null diff --git a/jstests/aggregation/sources/setWindowFields/n_accumulators.js b/jstests/aggregation/sources/setWindowFields/n_accumulators.js new file mode 100644 index 00000000000..81fe0f3c988 --- /dev/null +++ b/jstests/aggregation/sources/setWindowFields/n_accumulators.js @@ -0,0 +1,79 @@ +/** + * Test that the 'n' family of accumulators work as window functions. + */ +(function() { +"use strict"; + +load("jstests/aggregation/extras/window_function_helpers.js"); + +const coll = db[jsTestName()]; +coll.drop(); + +const isExactTopNEnabled = db.adminCommand({getParameter: 1, featureFlagExactTopNAccumulator: 1}) + .featureFlagExactTopNAccumulator.value; + +if (!isExactTopNEnabled) { + // Verify that $minN/$maxN cannot be used if the feature flag is set to false and ignore the + // rest of the test. + assert.commandFailedWithCode(coll.runCommand("aggregate", { + pipeline: [{ + $setWindowFields: { + sortBy: {ts: 1}, + output: {outputField: {$minN: {n: 3, output: "$foo"}}}, + } + }], + cursor: {} + }), + 5788502); + return; +} + +// Create a collection of tickers and prices. +const nDocsPerTicker = 10; +seedWithTickerData(coll, nDocsPerTicker); + +// TODO SERVER-57884: Add test cases for $firstN/$lastN window functions. +// TODO SERVER-57886: Add test cases for $top/$bottom/$topN/$bottomN window functions. +for (const acc of ["$minN", "$maxN"]) { + for (const nValue of [4, 7, 12]) { + jsTestLog("Testing accumulator " + tojson(acc) + " with 'n' set to " + tojson(nValue)); + testAccumAgainstGroup(coll, acc, [], {output: "$price", n: nValue}); + } + + // Verify that the accumulator will not throw if the 'n' expression evaluates to a constant. + const pipeline = [ + { + $setWindowFields: { + partitionBy: "$ticker", + sortBy: {_id: 1}, + output: {res: {[acc]: {n: {$add: [1, 2]}, output: "$price"}}} + }, + }, + ]; + + assert.doesNotThrow(() => coll.aggregate(pipeline).toArray()); + + // Error cases. + function testError(accSpec, expectedCode) { + assert.throwsWithCode(() => coll.aggregate([{ + $setWindowFields: { + sortBy: {ts: 1}, + output: {outputField: accSpec}, + } + }]), + expectedCode); + } + // Invalid/missing accumulator specification. + testError({[acc]: "non object"}, 5787900); + testError({window: {documents: [-1, 1]}}, ErrorCodes.FailedToParse); + testError({[acc]: {n: 2}, window: {documents: [-1, 1]}}, 5787907); + testError({[acc]: {output: "$foo"}, window: {documents: [-1, 1]}}, 5787906); + testError({[acc]: {output: "$foo", n: 2.1}, window: {documents: [-1, 1]}}, 5787903); + + // Invalid window specification. + testError({[acc]: {output: "$foo", n: 2.0}, window: [-1, 1]}, ErrorCodes.FailedToParse); + + // Non constant argument for 'n'. + testError({[acc]: {output: "$foo", n: "$a"}, window: {documents: [-1, 1]}}, 5787902); +} +})();
\ No newline at end of file diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index 463b05d0f35..d8ddcf85fc9 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -514,6 +514,7 @@ env.CppUnitTest( 'window_function/window_function_exec_removable_test.cpp', 'window_function/window_function_integral_test.cpp', 'window_function/window_function_min_max_test.cpp', + 'window_function/window_function_n_test.cpp', 'window_function/window_function_push_test.cpp', 'window_function/window_function_std_dev_test.cpp', 'window_function/window_function_exec_first_last_test.cpp', diff --git a/src/mongo/db/pipeline/accumulator_multi.cpp b/src/mongo/db/pipeline/accumulator_multi.cpp index 714d7ffaeb3..ecef3202f8c 100644 --- a/src/mongo/db/pipeline/accumulator_multi.cpp +++ b/src/mongo/db/pipeline/accumulator_multi.cpp @@ -60,25 +60,27 @@ REGISTER_ACCUMULATOR_WITH_MIN_VERSION( AccumulatorFirstLastN::parseFirstLastN<FirstLastSense::kLast>, ServerGlobalParams::FeatureCompatibility::Version::kVersion51); // TODO SERVER-57881 Add $firstN/$lastN as expressions. -// TODO SERVER-57885 Add $minN/$maxN as window functions. // TODO SERVER-57884 Add $firstN/$lastN as window functions. AccumulatorN::AccumulatorN(ExpressionContext* const expCtx) : AccumulatorState(expCtx), _maxMemUsageBytes(internalQueryMaxNAccumulatorBytes.load()) {} -void AccumulatorN::startNewGroup(const Value& input) { +long long AccumulatorN::validateN(const Value& input) { // Obtain the value for 'n' and error if it's not a positive integral. uassert(5787902, str::stream() << "Value for 'n' must be of integral type, but found " << input.toString(), - isNumericBSONType(input.getType())); + input.numeric()); auto n = input.coerceToLong(); uassert(5787903, str::stream() << "Value for 'n' must be of integral type, but found " << input.toString(), n == input.coerceToDouble()); uassert(5787908, str::stream() << "'n' must be greater than 0, found " << n, n > 0); - _n = n; + return n; +} +void AccumulatorN::startNewGroup(const Value& input) { + _n = validateN(input); } void AccumulatorN::processInternal(const Value& input, bool merging) { diff --git a/src/mongo/db/pipeline/accumulator_multi.h b/src/mongo/db/pipeline/accumulator_multi.h index e14d3cabc60..f15063fdcb5 100644 --- a/src/mongo/db/pipeline/accumulator_multi.h +++ b/src/mongo/db/pipeline/accumulator_multi.h @@ -46,9 +46,16 @@ class AccumulatorN : public AccumulatorState { public: AccumulatorN(ExpressionContext* expCtx); + /** + * Verifies that 'input' is a positive integer. + */ + static long long validateN(const Value& input); + void processInternal(const Value& input, bool merging) final; - // Initialize 'n' with 'input'. In particular, verifies that 'input' is a positive integer. + /** + * Initialize 'n' with 'input'. + */ void startNewGroup(const Value& input) final; /** diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.cpp b/src/mongo/db/pipeline/window_function/window_function_expression.cpp index 9146b5d5b79..0a37571d02e 100644 --- a/src/mongo/db/pipeline/window_function/window_function_expression.cpp +++ b/src/mongo/db/pipeline/window_function/window_function_expression.cpp @@ -42,15 +42,19 @@ #include "mongo/db/pipeline/window_function/window_function_exec_derivative.h" #include "mongo/db/pipeline/window_function/window_function_exec_first_last.h" #include "mongo/db/pipeline/window_function/window_function_expression.h" +#include "mongo/db/pipeline/window_function/window_function_min_max.h" using boost::intrusive_ptr; using boost::optional; namespace mongo::window_function { using namespace std::string_literals; +using MinMaxSense = AccumulatorMinMax::Sense; REGISTER_WINDOW_FUNCTION(derivative, ExpressionDerivative::parse); REGISTER_WINDOW_FUNCTION(first, ExpressionFirst::parse); REGISTER_WINDOW_FUNCTION(last, ExpressionLast::parse); +REGISTER_WINDOW_FUNCTION(minN, ExpressionMinMaxN<MinMaxSense::kMin>::parse); +REGISTER_WINDOW_FUNCTION(maxN, ExpressionMinMaxN<MinMaxSense::kMax>::parse); StringMap<Expression::Parser> Expression::parserMap; @@ -215,6 +219,99 @@ boost::intrusive_ptr<Expression> ExpressionFirstLast::parse( } } +template <AccumulatorMinMax::Sense S> +boost::intrusive_ptr<Expression> ExpressionMinMaxN<S>::parse( + BSONObj obj, const boost::optional<SortPattern>& sortBy, ExpressionContext* expCtx) { + auto name = [] { + if constexpr (S == MinMaxSense::kMin) { + return AccumulatorMinN::getName(); + } else { + return AccumulatorMaxN::getName(); + } + }(); + uassert(5788502, + str::stream() << "Cannot create " << name + << " accumulator in $setWindowFields" + " if feature flag is disabled", + feature_flags::gFeatureFlagExactTopNAccumulator.isEnabledAndIgnoreFCV()); + + boost::intrusive_ptr<::mongo::Expression> nExpr; + boost::intrusive_ptr<::mongo::Expression> outputExpr; + boost::optional<WindowBounds> bounds; + for (auto&& elem : obj) { + auto fieldName = elem.fieldNameStringData(); + if (fieldName == name) { + uassert(ErrorCodes::FailedToParse, + str::stream() << "saw multiple specifications for '" << name << "' expression", + !(nExpr || outputExpr)); + auto accExpr = + AccumulatorMinMaxN::parseMinMaxN<S>(expCtx, elem, expCtx->variablesParseState); + nExpr = accExpr.initializer; + outputExpr = accExpr.argument; + } else if (fieldName == kWindowArg) { + uassert(ErrorCodes::FailedToParse, + "'window' field must be an object", + obj[kWindowArg].type() == BSONType::Object); + uassert(ErrorCodes::FailedToParse, + str::stream() << "saw multiple 'window' fields in '" << name << "' expression", + bounds == boost::none); + bounds = WindowBounds::parse(elem.embeddedObject(), sortBy, expCtx); + } else { + uasserted(ErrorCodes::FailedToParse, + str::stream() << name << " got unexpected argument: " << fieldName); + } + } + + // The default window bounds are [unbounded, unbounded]. + if (!bounds) { + bounds = WindowBounds::defaultBounds(); + } + tassert(5788500, + str::stream() << "missing accumulator specification for " << name, + nExpr && outputExpr); + return make_intrusive<ExpressionMinMaxN<S>>( + expCtx, std::move(outputExpr), name, *bounds, std::move(nExpr)); +} + +template <AccumulatorMinMax::Sense S> +boost::intrusive_ptr<AccumulatorState> ExpressionMinMaxN<S>::buildAccumulatorOnly() const { + boost::intrusive_ptr<AccumulatorState> acc; + if constexpr (S == AccumulatorMinMax::Sense::kMin) { + acc = AccumulatorMinN::create(_expCtx); + } else { + acc = AccumulatorMaxN::create(_expCtx); + } + + // Initialize 'n' for our accumulator. Note that 'n' must be a constant. + auto nVal = _nExpr->evaluate({}, &_expCtx->variables); + uassert(5788501, + str::stream() << "Expression for 'n' " << _nExpr->serialize(false).toString() + << " must evaluate to a numeric constant when used in $setWindowFields", + nVal.numeric()); + acc->startNewGroup(nVal); + return acc; +} + +template <AccumulatorMinMax::Sense S> +std::unique_ptr<WindowFunctionState> ExpressionMinMaxN<S>::buildRemovable() const { + return WindowFunctionMinMaxN<S>::create( + _expCtx, AccumulatorN::validateN(_nExpr->evaluate({}, &_expCtx->variables))); +} + +template <AccumulatorMinMax::Sense S> +Value ExpressionMinMaxN<S>::serialize(boost::optional<ExplainOptions::Verbosity> explain) const { + MutableDocument result; + + MutableDocument exprSpec; + AccumulatorN::serializeHelper(_nExpr, _input, static_cast<bool>(explain), exprSpec); + result[_accumulatorName] = exprSpec.freezeToValue(); + + MutableDocument windowField; + _bounds.serialize(windowField); + result[kWindowArg] = windowField.freezeToValue(); + return result.freezeToValue(); +} + MONGO_INITIALIZER_GROUP(BeginWindowFunctionRegistration, ("default"), ("EndWindowFunctionRegistration")) diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.h b/src/mongo/db/pipeline/window_function/window_function_expression.h index e1b2d17a338..4cbbd88914e 100644 --- a/src/mongo/db/pipeline/window_function/window_function_expression.h +++ b/src/mongo/db/pipeline/window_function/window_function_expression.h @@ -720,4 +720,27 @@ public: } }; +template <AccumulatorMinMax::Sense S> +class ExpressionMinMaxN : public Expression { +public: + ExpressionMinMaxN(ExpressionContext* expCtx, + boost::intrusive_ptr<::mongo::Expression> input, + std::string name, + WindowBounds bounds, + boost::intrusive_ptr<::mongo::Expression> nExpr) + : Expression(expCtx, std::move(name), std::move(input), std::move(bounds)), + _nExpr(std::move(nExpr)) {} + static boost::intrusive_ptr<Expression> parse(BSONObj obj, + const boost::optional<SortPattern>& sortBy, + ExpressionContext* expCtx); + + boost::intrusive_ptr<AccumulatorState> buildAccumulatorOnly() const final; + + std::unique_ptr<WindowFunctionState> buildRemovable() const final; + + Value serialize(boost::optional<ExplainOptions::Verbosity> explain) const final; + +private: + boost::intrusive_ptr<::mongo::Expression> _nExpr; +}; } // namespace mongo::window_function diff --git a/src/mongo/db/pipeline/window_function/window_function_min_max.h b/src/mongo/db/pipeline/window_function/window_function_min_max.h index 0ff212484c0..8c5d2822696 100644 --- a/src/mongo/db/pipeline/window_function/window_function_min_max.h +++ b/src/mongo/db/pipeline/window_function/window_function_min_max.h @@ -35,30 +35,18 @@ namespace mongo { template <AccumulatorMinMax::Sense sense> -class WindowFunctionMinMax : public WindowFunctionState { +class WindowFunctionMinMaxCommon : public WindowFunctionState { public: - static inline const Value kDefault = Value{BSONNULL}; - - static std::unique_ptr<WindowFunctionState> create(ExpressionContext* const expCtx) { - return std::make_unique<WindowFunctionMinMax<sense>>(expCtx); - } - - explicit WindowFunctionMinMax(ExpressionContext* const expCtx) - : WindowFunctionState(expCtx), - _values(_expCtx->getValueComparator().makeOrderedValueMultiset()) { - _memUsageBytes = sizeof(*this); - } - - void add(Value value) final { + void add(Value value) override { _memUsageBytes += value.getApproximateSize(); _values.insert(std::move(value)); } - void remove(Value value) final { + void remove(Value value) override { // std::multiset::insert is guaranteed to put the element after any equal elements // already in the container. So find() / erase() will remove the oldest equal element, // which is what we want, to satisfy "remove() undoes add() when called in FIFO order". - auto iter = _values.find(std::move(value)); + auto iter = _values.find(value); tassert(5371400, "Can't remove from an empty WindowFunctionMinMax", iter != _values.end()); _memUsageBytes -= iter->getApproximateSize(); _values.erase(iter); @@ -69,23 +57,111 @@ public: _memUsageBytes = sizeof(*this); } +protected: + // Constructor hidden so that only instances of the derived types can be created. + explicit WindowFunctionMinMaxCommon(ExpressionContext* const expCtx) + : WindowFunctionState(expCtx), + _values(_expCtx->getValueComparator().makeOrderedValueMultiset()) {} + + // Holds all the values in the window, in order, with constant-time access to both ends. + ValueMultiset _values; +}; + +template <AccumulatorMinMax::Sense sense> +class WindowFunctionMinMax : public WindowFunctionMinMaxCommon<sense> { +public: + using WindowFunctionMinMaxCommon<sense>::_values; + using WindowFunctionMinMaxCommon<sense>::_memUsageBytes; + + static inline const Value kDefault = Value{BSONNULL}; + + static std::unique_ptr<WindowFunctionState> create(ExpressionContext* const expCtx) { + return std::make_unique<WindowFunctionMinMax<sense>>(expCtx); + } + + explicit WindowFunctionMinMax(ExpressionContext* const expCtx) + : WindowFunctionMinMaxCommon<sense>(expCtx) { + _memUsageBytes = sizeof(*this); + } + Value getValue() const final { if (_values.empty()) return kDefault; - switch (sense) { - case AccumulatorMinMax::Sense::kMin: - return *_values.begin(); - case AccumulatorMinMax::Sense::kMax: - return *_values.rbegin(); + if constexpr (sense == AccumulatorMinMax::Sense::kMin) { + return *_values.begin(); + } else { + return *_values.rbegin(); } MONGO_UNREACHABLE_TASSERT(5371401); } +}; -protected: - // Holds all the values in the window, in order, with constant-time access to both ends. - ValueMultiset _values; +template <AccumulatorMinMax::Sense sense> +class WindowFunctionMinMaxN : public WindowFunctionMinMaxCommon<sense> { +public: + using WindowFunctionMinMaxCommon<sense>::_values; + using WindowFunctionMinMaxCommon<sense>::_memUsageBytes; + + static std::unique_ptr<WindowFunctionState> create(ExpressionContext* const expCtx, + long long n) { + return std::make_unique<WindowFunctionMinMaxN<sense>>(expCtx, n); + } + explicit WindowFunctionMinMaxN(ExpressionContext* const expCtx, long long n) + : WindowFunctionMinMaxCommon<sense>(expCtx), _n(n) { + _memUsageBytes = sizeof(*this); + } + + void add(Value value) final { + // Ignore nullish values. + if (value.nullish()) + return; + WindowFunctionMinMaxCommon<sense>::add(std::move(value)); + } + + void remove(Value value) final { + // Ignore nullish values. + if (value.nullish()) + return; + WindowFunctionMinMaxCommon<sense>::remove(std::move(value)); + } + + Value getValue() const final { + if (_values.empty()) { + return Value(std::vector<Value>()); + } + + auto processVal = [&](auto begin, auto end, size_t size) -> Value { + auto n = static_cast<size_t>(_n); + + // If 'n' is greater than the size of the current window, then return all the values. + if (n >= size) { + return Value(std::vector(begin, end)); + } else { + std::vector<Value> result; + result.reserve(n); + auto it = begin; + for (size_t i = 0; i < n; ++i, ++it) { + result.push_back(*it); + } + return Value(std::move(result)); + } + }; + + auto size = _values.size(); + if constexpr (sense == AccumulatorMinMax::Sense::kMin) { + return processVal(_values.begin(), _values.end(), size); + } else { + return processVal(_values.rbegin(), _values.rend(), size); + } + } + + +private: + long long _n; }; using WindowFunctionMin = WindowFunctionMinMax<AccumulatorMinMax::Sense::kMin>; using WindowFunctionMax = WindowFunctionMinMax<AccumulatorMinMax::Sense::kMax>; +using WindowFunctionMinN = WindowFunctionMinMaxN<AccumulatorMinMax::Sense::kMin>; +using WindowFunctionMaxN = WindowFunctionMinMaxN<AccumulatorMinMax::Sense::kMax>; } // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/window_function_n_test.cpp b/src/mongo/db/pipeline/window_function/window_function_n_test.cpp new file mode 100644 index 00000000000..246fb42d77d --- /dev/null +++ b/src/mongo/db/pipeline/window_function/window_function_n_test.cpp @@ -0,0 +1,224 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/exec/document_value/document_value_test_util.h" +#include "mongo/db/pipeline/aggregation_context_fixture.h" +#include "mongo/db/pipeline/window_function/window_function_min_max.h" +#include "mongo/db/query/collation/collator_interface_mock.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { +namespace { +// TODO SERVER-57884: Add test cases for $firstN/$lastN window functions. +// TODO SERVER-57886: Add test cases for $top/$bottom/$topN/$bottomN window functions. +class WindowFunctionMinMaxNTest : public AggregationContextFixture { +public: + static constexpr auto kNarg = 3LL; + WindowFunctionMinMaxNTest() + : expCtx(getExpCtx()), minThree(expCtx.get(), kNarg), maxThree(expCtx.get(), kNarg) { + auto collator = std::make_unique<CollatorInterfaceMock>( + CollatorInterfaceMock::MockType::kToLowerString); + expCtx->setCollator(std::move(collator)); + } + + boost::intrusive_ptr<ExpressionContext> expCtx; + WindowFunctionMinN minThree; + WindowFunctionMaxN maxThree; +}; + +TEST_F(WindowFunctionMinMaxNTest, EmptyWindow) { + auto test = [](auto windowFunction) { + ASSERT_VALUE_EQ(windowFunction.getValue(), Value{BSONArray()}); + + // No matter how many nullish values we insert, we should still get back the empty array. + windowFunction.add(Value()); + windowFunction.add(Value(BSONNULL)); + windowFunction.add(Value()); + windowFunction.add(Value(BSONNULL)); + ASSERT_VALUE_EQ(windowFunction.getValue(), Value{BSONArray()}); + + // Add a value and show that removing nullish has no effect. + windowFunction.add(Value{3}); + ASSERT_VALUE_EQ(windowFunction.getValue(), Value{std::vector{Value(3)}}); + + windowFunction.remove(Value()); + windowFunction.remove(Value(BSONNULL)); + windowFunction.remove(Value()); + windowFunction.remove(Value(BSONNULL)); + ASSERT_VALUE_EQ(windowFunction.getValue(), Value{std::vector{Value(3)}}); + }; + test(minThree); + test(maxThree); +} + +TEST_F(WindowFunctionMinMaxNTest, WindowSmallerThanN) { + minThree.add(Value{5}); + minThree.add(Value{7}); + + ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(5), Value(7)})); + + maxThree.add(Value{5}); + maxThree.add(Value{7}); + ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(7), Value(5)})); +} + +TEST_F(WindowFunctionMinMaxNTest, WindowContainsDuplicates) { + minThree.add(Value{5}); + minThree.add(Value{7}); + minThree.add(Value{7}); + minThree.add(Value{7}); + minThree.add(Value{7}); + + ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(5), Value(7), Value(7)})); + + maxThree.add(Value{5}); + maxThree.add(Value{5}); + maxThree.add(Value{5}); + maxThree.add(Value{5}); + maxThree.add(Value{5}); + maxThree.add(Value{7}); + ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(7), Value(5), Value(5)})); +} + +TEST_F(WindowFunctionMinMaxNTest, BasicCorrectnessTest) { + minThree.add(Value{5}); + minThree.add(Value{10}); + minThree.add(Value{6}); + minThree.add(Value{12}); + minThree.add(Value{7}); + minThree.add(Value{3}); + + ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(3), Value(5), Value(6)})); + + minThree.remove(Value{5}); + minThree.remove(Value{10}); + + ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(3), Value(6), Value(7)})); + minThree.remove(Value{6}); + ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(3), Value(7), Value(12)})); + + minThree.remove(Value{12}); + minThree.remove(Value{7}); + minThree.remove(Value{3}); + ASSERT_VALUE_EQ(minThree.getValue(), Value{BSONArray()}); + + maxThree.add(Value{5}); + maxThree.add(Value{9}); + maxThree.add(Value{12}); + maxThree.add(Value{11}); + maxThree.add(Value{3}); + maxThree.add(Value{7}); + ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(12), Value(11), Value(9)})); + + maxThree.remove(Value{5}); + maxThree.remove(Value{9}); + maxThree.remove(Value{12}); + ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(11), Value(7), Value(3)})); + + maxThree.remove(Value{11}); + maxThree.remove(Value{3}); + maxThree.remove(Value{7}); + ASSERT_VALUE_EQ(maxThree.getValue(), Value{BSONArray()}); +} + +TEST_F(WindowFunctionMinMaxNTest, MixNullsAndNonNulls) { + // Add four values, half of which are null/missing. We should only return the two non-nulls. + minThree.add(Value{4}); + minThree.add(Value()); + minThree.add(Value(BSONNULL)); + minThree.add(Value{1}); + ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(1), Value(4)})); + + // Add a couple more values. We should still get no nulls/missing. + minThree.add(Value{3}); + minThree.add(Value()); + minThree.add(Value(BSONNULL)); + minThree.add(Value{2}); + ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{Value(1), Value(2), Value(3)})); + + // Add four values, half of which are null/missing. We should only return the two non-nulls. + maxThree.add(Value{4}); + maxThree.add(Value()); + maxThree.add(Value(BSONNULL)); + maxThree.add(Value{1}); + ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(4), Value(1)})); + + // Add a couple more values. We should still get no nulls/missing. + maxThree.add(Value{3}); + maxThree.add(Value()); + maxThree.add(Value(BSONNULL)); + maxThree.add(Value{2}); + ASSERT_VALUE_EQ(maxThree.getValue(), Value(std::vector{Value(4), Value(3), Value(2)})); +} + +TEST_F(WindowFunctionMinMaxNTest, Ties) { + // When two elements tie (compare equal), remove() can't pick an arbitrary one, + // because that would break the invariant that 'add(x); add(y); remove(x)' is equivalent to + // 'add(y)'. + + auto x = Value{"foo"_sd}; + auto y = Value{"FOO"_sd}; + // x and y are distinguishable, + ASSERT_VALUE_NE(x, y); + // but they compare equal according to the ordering. + ASSERT(expCtx->getValueComparator().evaluate(x == y)); + + minThree.add(x); + minThree.add(y); + minThree.remove(x); + ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{y})); + + minThree.add(x); + minThree.add(y); + minThree.remove(x); + + // Here, we expect ["foo","FOO"] because we remove the first added entry that compares equal + // to 'x', which is the first instance of 'y'. + ASSERT_VALUE_EQ(minThree.getValue(), Value(std::vector{x, y})); +} + +TEST_F(WindowFunctionMinMaxNTest, TracksMemoryUsageOnAddAndRemove) { + size_t trackingSize = sizeof(WindowFunctionMinN); + ASSERT_EQ(minThree.getApproximateSize(), trackingSize); + + auto largeStr = Value{"$minN/maxN are great window functions"_sd}; + minThree.add(largeStr); + trackingSize += largeStr.getApproximateSize(); + ASSERT_EQ(minThree.getApproximateSize(), trackingSize); + + minThree.add(largeStr); + trackingSize += largeStr.getApproximateSize(); + ASSERT_EQ(minThree.getApproximateSize(), trackingSize); + + minThree.remove(largeStr); + trackingSize -= largeStr.getApproximateSize(); + ASSERT_EQ(minThree.getApproximateSize(), trackingSize); +} +} // namespace +} // namespace mongo |