summaryrefslogtreecommitdiff
path: root/src/mongo
diff options
context:
space:
mode:
author: Rui Liu <lriuui0x0@gmail.com> 2023-04-29 17:25:29 +0000
committer: Evergreen Agent <no-reply@evergreen.mongodb.com> 2023-04-29 18:20:43 +0000
commit: df78c930a46ebc670e156387b9afb41b7782aa88 (patch)
tree: 74f6d8066390e7f25c397d1460d4112a6646da96 /src/mongo
parent: 7119eeb3c88cd787c686b8fc201a720f1c9e91e4 (diff)
download: mongo-df78c930a46ebc670e156387b9afb41b7782aa88.tar.gz
SERVER-58070 Implement $topN / $bottomN accumulator
Diffstat (limited to 'src/mongo')
-rw-r--r-- src/mongo/db/exec/sbe/expressions/expression.cpp 9
-rw-r--r-- src/mongo/db/exec/sbe/values/value.h 7
-rw-r--r-- src/mongo/db/exec/sbe/vm/vm.cpp 222
-rw-r--r-- src/mongo/db/exec/sbe/vm/vm.h 71
-rw-r--r-- src/mongo/db/pipeline/accumulator_multi.cpp 18
-rw-r--r-- src/mongo/db/query/sbe_stage_builder.cpp 174
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_accumulator.cpp 276
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_accumulator.h 40
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp 267
9 files changed, 1020 insertions, 64 deletions
diff --git a/src/mongo/db/exec/sbe/expressions/expression.cpp b/src/mongo/db/exec/sbe/expressions/expression.cpp
index 010d7d38136..d94494f5097 100644
--- a/src/mongo/db/exec/sbe/expressions/expression.cpp
+++ b/src/mongo/db/exec/sbe/expressions/expression.cpp
@@ -787,6 +787,15 @@ static stdx::unordered_map<std::string, BuiltinFn> kBuiltinFunctions = {
BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::aggFirstNMerge, true}},
{"aggFirstNFinalize",
BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::aggFirstNFinalize, false}},
+ {"aggTopN", BuiltinFn{[](size_t n) { return n == 3; }, vm::Builtin::aggTopN, true}},
+ {"aggTopNMerge", BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::aggTopNMerge, true}},
+ {"aggTopNFinalize",
+ BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::aggTopNFinalize, false}},
+ {"aggBottomN", BuiltinFn{[](size_t n) { return n == 3; }, vm::Builtin::aggBottomN, true}},
+ {"aggBottomNMerge",
+ BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::aggBottomNMerge, true}},
+ {"aggBottomNFinalize",
+ BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::aggBottomNFinalize, false}},
};
/**
diff --git a/src/mongo/db/exec/sbe/values/value.h b/src/mongo/db/exec/sbe/values/value.h
index 5e01ace708b..787609bb63b 100644
--- a/src/mongo/db/exec/sbe/values/value.h
+++ b/src/mongo/db/exec/sbe/values/value.h
@@ -845,6 +845,13 @@ public:
}
}
+    /**
+     * Drops the last element of the array, releasing the value it holds. Calling this on an empty
+     * array is a no-op.
+     */
+    void pop_back() {
+        if (_vals.empty()) {
+            return;
+        }
+
+        const auto& last = _vals.back();
+        releaseValue(last.first, last.second);
+        _vals.pop_back();
+    }
+
auto size() const noexcept {
return _vals.size();
}
diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp
index 97276df8c71..611f1ceb2d1 100644
--- a/src/mongo/db/exec/sbe/vm/vm.cpp
+++ b/src/mongo/db/exec/sbe/vm/vm.cpp
@@ -6039,8 +6039,8 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinArrayToObject(Ar
return {true, objTag, objVal};
}
-std::tuple<value::Array*, size_t, int32_t, int32_t> multiAccState(value::TypeTags accTag,
- value::Value accVal) {
+std::tuple<value::Array*, value::Array*, size_t, int32_t, int32_t> multiAccState(
+ value::TypeTags accTag, value::Value accVal) {
uassert(7548600, "The accumulator state should be an array", accTag == value::TypeTags::Array);
auto acc = value::getArrayView(accVal);
@@ -6069,7 +6069,18 @@ std::tuple<value::Array*, size_t, int32_t, int32_t> multiAccState(value::TypeTag
"MemLimit component should be a 32-bit integer",
memLimitTag == value::TypeTags::NumberInt32);
- return {array, maxSize, memUsage, memLimit};
+ return {acc, array, maxSize, memUsage, memLimit};
+}
+
+/**
+ * Verifies that 'memUsage' is strictly below 'memLimit', failing with ExceededMemoryLimit
+ * otherwise, and then stores 'memUsage' in the kMemUsage slot of the accumulator state array
+ * 'accArray'.
+ */
+void checkAndUpdateMemUsage(value::Array* accArray, int32_t memUsage, int32_t memLimit) {
+    const bool withinLimit = memUsage < memLimit;
+    uassert(ErrorCodes::ExceededMemoryLimit,
+            str::stream()
+                << "Accumulator used too much memory and spilling to disk cannot reduce memory "
+                   "consumption any further. Memory limit: "
+                << memLimit << " bytes",
+            withinLimit);
+
+    // Only reached when the check above passed, so the stored usage never exceeds the limit.
+    const auto memUsageIdx = static_cast<size_t>(AggMultiElems::kMemUsage);
+    accArray->setAt(memUsageIdx, value::TypeTags::NumberInt32, memUsage);
}
FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggFirstN(ArityType arity) {
@@ -6078,26 +6089,15 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggFirstN(ArityT
value::ValueGuard accGuard{accTag, accVal};
value::ValueGuard fieldGuard{fieldTag, fieldVal};
- auto [accArr, accSize, memUsage, memLimit] = multiAccState(accTag, accVal);
+ auto [acc, array, maxSize, memUsage, memLimit] = multiAccState(accTag, accVal);
- if (accArr->size() < accSize) {
- // update memusage
+ if (array->size() < maxSize) {
memUsage += value::getApproximateSize(fieldTag, fieldVal);
- auto maxMemAllowed = memLimit;
- uassert(ErrorCodes::ExceededMemoryLimit,
- str::stream()
- << "$firstN used too much memory and spilling to disk cannot reduce memory "
- "consumption any further. Memory limit: "
- << maxMemAllowed << " bytes",
- memUsage < maxMemAllowed);
+ checkAndUpdateMemUsage(acc, memUsage, memLimit);
// add to array
fieldGuard.reset();
- accArr->push_back(fieldTag, fieldVal);
-
- // update the memUsageBytes
- value::getArrayView(accVal)->setAt(
- static_cast<size_t>(AggMultiElems::kMemUsage), value::TypeTags::NumberInt32, memUsage);
+ array->push_back(fieldTag, fieldVal);
}
accGuard.reset();
return {true, accTag, accVal};
@@ -6109,9 +6109,9 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggFirstNMerge(A
value::ValueGuard mergeAccGuard{mergeAccTag, mergeAccVal};
value::ValueGuard accGuard{accTag, accVal};
- auto [mergeArr, mergeMaxSize, mergeMemUsage, mergeMemLimit] =
+ auto [mergeAcc, mergeArr, mergeMaxSize, mergeMemUsage, mergeMemLimit] =
multiAccState(mergeAccTag, mergeAccVal);
- auto [arr, maxSize, memUsage, memLimit] = multiAccState(accTag, accVal);
+ auto [acc, arr, maxSize, memUsage, memLimit] = multiAccState(accTag, accVal);
uassert(7548604,
"Two arrays to merge should have the same MaxSize component",
maxSize == mergeMaxSize);
@@ -6123,21 +6123,11 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggFirstNMerge(A
value::ValueGuard valueGuard{tag, val};
mergeMemUsage += value::getApproximateSize(tag, val);
- auto maxMemAllowed = mergeMemLimit;
- uassert(ErrorCodes::ExceededMemoryLimit,
- str::stream()
- << "$firstN used too much memory and spilling to disk cannot reduce memory "
- "consumption any further. Memory limit: "
- << maxMemAllowed << " bytes",
- mergeMemUsage < maxMemAllowed);
+ checkAndUpdateMemUsage(mergeAcc, mergeMemUsage, mergeMemLimit);
valueGuard.reset();
mergeArr->push_back(tag, val);
}
- value::getArrayView(mergeAccVal)
- ->setAt(static_cast<size_t>(AggMultiElems::kMemUsage),
- value::TypeTags::NumberInt32,
- mergeMemUsage);
}
mergeAccGuard.reset();
@@ -6156,6 +6146,152 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggFirstNFinaliz
return {true, outputTag, outputVal};
}
+/**
+ * Offers one (key, output) pair to the bounded heap stored in 'array', which holds at most
+ * 'maxSize' two-element [key, output] arrays and is kept in heap order by 'Less' applied to the
+ * keys.
+ *
+ * While the heap has room the pair is always inserted. Once the heap is full, the pair replaces
+ * the element at the front of the heap (the one ordering last under 'Less') only when its key
+ * orders before that element's key; otherwise the pair is discarded.
+ *
+ * Ownership of 'key' and 'output' is taken: they are either moved into the heap or released
+ * before returning. 'acc' is the enclosing accumulator state array; its kMemUsage slot is
+ * refreshed through checkAndUpdateMemUsage(), which fails with ExceededMemoryLimit when the new
+ * usage is not below 'memLimit'.
+ *
+ * Returns the updated memory usage. When an element is replaced, the size of the evicted pair is
+ * subtracted so that the tracked usage reflects what the heap actually holds rather than growing
+ * with every replacement.
+ */
+template <typename Less>
+int32_t aggTopBottomNAdd(value::Array* acc,
+                         value::Array* array,
+                         size_t maxSize,
+                         int32_t memUsage,
+                         int32_t memLimit,
+                         const value::SortSpec* sortSpec,
+                         std::pair<value::TypeTags, value::Value> key,
+                         std::pair<value::TypeTags, value::Value> output) {
+    // Approximate number of bytes that a (key, output) pair contributes to the heap.
+    auto pairMemUsage = [](std::pair<value::TypeTags, value::Value> k,
+                           std::pair<value::TypeTags, value::Value> o) {
+        return value::getApproximateSize(k.first, k.second) +
+            value::getApproximateSize(o.first, o.second);
+    };
+
+    // Until they are handed over to the heap, 'key' and 'output' are released on any exit path.
+    value::ValueGuard keyGuard{key.first, key.second};
+    value::ValueGuard outputGuard{output.first, output.second};
+    auto less = Less(sortSpec);
+    auto keyLess = PairKeyComp(less);
+    auto& heap = array->values();
+
+    if (array->size() < maxSize) {
+        // There is still room: wrap the pair in a new two-element array and push it.
+        auto [pairTag, pairVal] = value::makeNewArray();
+        value::ValueGuard pairGuard{pairTag, pairVal};
+        auto pair = value::getArrayView(pairVal);
+        pair->reserve(2);
+        keyGuard.reset();
+        pair->push_back(key.first, key.second);
+        outputGuard.reset();
+        pair->push_back(output.first, output.second);
+
+        memUsage += pairMemUsage(key, output);
+        checkAndUpdateMemUsage(acc, memUsage, memLimit);
+
+        pairGuard.reset();
+        array->push_back(pairTag, pairVal);
+        std::push_heap(heap.begin(), heap.end(), keyLess);
+    } else {
+        tassert(5807005,
+                "Heap should contain same number of elements as MaxSize",
+                array->size() == maxSize);
+
+        auto worst = value::getArrayView(heap.front().second);
+        auto worstKey = worst->getAt(0);
+        if (less(key, worstKey)) {
+            // The evicted pair stops occupying memory, so charge only the net change.
+            memUsage -= pairMemUsage(worstKey, worst->getAt(1));
+            memUsage += pairMemUsage(key, output);
+            checkAndUpdateMemUsage(acc, memUsage, memLimit);
+
+            // pop_heap moves the worst pair to the back of the heap; its two slots are then
+            // overwritten in place and push_heap restores the heap order.
+            std::pop_heap(heap.begin(), heap.end(), keyLess);
+            keyGuard.reset();
+            worst->setAt(0, key.first, key.second);
+            outputGuard.reset();
+            worst->setAt(1, output.first, output.second);
+            std::push_heap(heap.begin(), heap.end(), keyLess);
+        }
+    }
+
+    return memUsage;
+}
+
+/**
+ * Accumulation step shared by the aggTopN and aggBottomN builtins; 'Less' selects the ordering.
+ *
+ * Stack layout: slot 0 is the accumulator state, slot 1 the sort key, slot 2 the output value and
+ * slot 3 the sort specification. Ownership of slots 0-2 is moved off the stack; the key and the
+ * output are handed to aggTopBottomNAdd() and the (mutated) accumulator state is returned as an
+ * owned value.
+ */
+template <typename Less>
+FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggTopBottomN(ArityType arity) {
+    auto [specOwned, specTag, specVal] = getFromStack(3);
+    tassert(5807024, "Argument must be of sortSpec type", specTag == value::TypeTags::sortSpec);
+    auto spec = value::getSortSpecView(specVal);
+
+    auto [stateTag, stateVal] = moveOwnedFromStack(0);
+    value::ValueGuard stateGuard{stateTag, stateVal};
+    auto [state, heap, maxSize, memUsage, memLimit] = multiAccState(stateTag, stateVal);
+
+    auto sortKey = moveOwnedFromStack(1);
+    auto outputValue = moveOwnedFromStack(2);
+    aggTopBottomNAdd<Less>(state, heap, maxSize, memUsage, memLimit, spec, sortKey, outputValue);
+
+    // Success: the state is handed back to the caller instead of being released by the guard.
+    stateGuard.reset();
+    return {true, stateTag, stateVal};
+}
+
+/**
+ * Merge step shared by the aggTopNMerge and aggBottomNMerge builtins; 'Less' selects the ordering.
+ *
+ * Stack layout: slot 0 is the accumulator state being merged into, slot 1 the accumulator state
+ * being merged from and slot 2 the sort specification. Every [key, output] pair of the source
+ * state is moved out (leaving Null in its place) and offered to the destination heap through
+ * aggTopBottomNAdd(). The destination state is returned as an owned value; the source state is
+ * released on exit.
+ */
+template <typename Less>
+FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggTopBottomNMerge(
+    ArityType arity) {
+    auto [specOwned, specTag, specVal] = getFromStack(2);
+    tassert(5807025, "Argument must be of sortSpec type", specTag == value::TypeTags::sortSpec);
+    auto spec = value::getSortSpecView(specVal);
+
+    auto [srcTag, srcVal] = moveOwnedFromStack(1);
+    value::ValueGuard srcGuard{srcTag, srcVal};
+    auto [dstTag, dstVal] = moveOwnedFromStack(0);
+    value::ValueGuard dstGuard{dstTag, dstVal};
+
+    auto [dstState, dstHeap, dstMaxSize, dstMemUsage, dstMemLimit] =
+        multiAccState(dstTag, dstVal);
+    auto [srcState, srcHeap, srcMaxSize, srcMemUsage, srcMemLimit] =
+        multiAccState(srcTag, srcVal);
+    tassert(5807008,
+            "Two arrays to merge should have the same MaxSize component",
+            srcMaxSize == dstMaxSize);
+
+    for (const auto& entry : srcHeap->values()) {
+        auto srcPair = value::getArrayView(entry.second);
+        // Steal the key and the output from the source pair; Null placeholders stay behind.
+        auto key = srcPair->swapAt(0, value::TypeTags::Null, 0);
+        auto output = srcPair->swapAt(1, value::TypeTags::Null, 0);
+        dstMemUsage = aggTopBottomNAdd<Less>(
+            dstState, dstHeap, dstMaxSize, dstMemUsage, dstMemLimit, spec, key, output);
+    }
+
+    dstGuard.reset();
+    return {true, dstTag, dstVal};
+}
+
+/**
+ * Finalize step used by both the aggTopNFinalize and aggBottomNFinalize builtins.
+ *
+ * Stack layout: slot 0 is the accumulator state and slot 1 the sort specification. The
+ * [key, output] pairs held by the state are sorted by key with SortPatternLess, and their outputs
+ * are moved, in that order, into a freshly created array that is returned as an owned value.
+ */
+FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggTopBottomNFinalize(
+    ArityType arity) {
+    auto [specOwned, specTag, specVal] = getFromStack(1);
+    tassert(5807026, "Argument must be of sortSpec type", specTag == value::TypeTags::sortSpec);
+    auto spec = value::getSortSpecView(specVal);
+
+    auto [stateTag, stateVal] = moveOwnedFromStack(0);
+    value::ValueGuard stateGuard{stateTag, stateVal};
+    auto [state, heap, maxSize, memUsage, memLimit] = multiAccState(stateTag, stateVal);
+
+    auto [resultTag, resultVal] = value::makeNewArray();
+    value::ValueGuard resultGuard{resultTag, resultVal};
+    auto result = value::getArrayView(resultVal);
+    result->reserve(heap->size());
+
+    // According to MQL semantics the result is always emitted in sort-pattern order, which is why
+    // the same comparator is used for both the top and the bottom variants.
+    auto less = SortPatternLess(spec);
+    auto keyLess = PairKeyComp(less);
+    auto& pairs = heap->values();
+    std::sort(pairs.begin(), pairs.end(), keyLess);
+
+    for (const auto& entry : pairs) {
+        auto pair = value::getArrayView(entry.second);
+        // Move the output out of the pair, leaving Null behind.
+        auto [outTag, outVal] = pair->swapAt(1, value::TypeTags::Null, 0);
+        result->push_back(outTag, outVal);
+    }
+
+    resultGuard.reset();
+    return {true, resultTag, resultVal};
+}
+
FastTuple<bool, value::TypeTags, value::Value> ByteCode::dispatchBuiltin(Builtin f,
ArityType arity) {
switch (f) {
@@ -6428,6 +6564,18 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::dispatchBuiltin(Builtin
return builtinAggFirstNMerge(arity);
case Builtin::aggFirstNFinalize:
return builtinAggFirstNFinalize(arity);
+ case Builtin::aggTopN:
+ return builtinAggTopBottomN<SortPatternLess>(arity);
+ case Builtin::aggTopNMerge:
+ return builtinAggTopBottomNMerge<SortPatternLess>(arity);
+ case Builtin::aggTopNFinalize:
+ return builtinAggTopBottomNFinalize(arity);
+ case Builtin::aggBottomN:
+ return builtinAggTopBottomN<SortPatternGreater>(arity);
+ case Builtin::aggBottomNMerge:
+ return builtinAggTopBottomNMerge<SortPatternGreater>(arity);
+ case Builtin::aggBottomNFinalize:
+ return builtinAggTopBottomNFinalize(arity);
}
MONGO_UNREACHABLE;
@@ -6706,6 +6854,18 @@ std::string builtinToString(Builtin b) {
return "aggFirstNMerge";
case Builtin::aggFirstNFinalize:
return "aggFirstNFinalize";
+ case Builtin::aggTopN:
+ return "aggTopN";
+ case Builtin::aggTopNMerge:
+ return "aggTopNMerge";
+ case Builtin::aggTopNFinalize:
+ return "aggTopNFinalize";
+ case Builtin::aggBottomN:
+ return "aggBottomN";
+ case Builtin::aggBottomNMerge:
+ return "aggBottomNMerge";
+ case Builtin::aggBottomNFinalize:
+ return "aggBottomNFinalize";
default:
MONGO_UNREACHABLE;
}
diff --git a/src/mongo/db/exec/sbe/vm/vm.h b/src/mongo/db/exec/sbe/vm/vm.h
index 2deb8b0c907..62b9be9154e 100644
--- a/src/mongo/db/exec/sbe/vm/vm.h
+++ b/src/mongo/db/exec/sbe/vm/vm.h
@@ -37,6 +37,7 @@
#include "mongo/config.h"
#include "mongo/db/exec/sbe/makeobj_spec.h"
#include "mongo/db/exec/sbe/values/slot.h"
+#include "mongo/db/exec/sbe/values/sort_spec.h"
#include "mongo/db/exec/sbe/values/value.h"
#include "mongo/db/exec/sbe/vm/datetime.h"
#include "mongo/db/exec/sbe/vm/label.h"
@@ -757,6 +758,12 @@ enum class Builtin : uint8_t {
aggFirstN,
aggFirstNMerge,
aggFirstNFinalize,
+ aggTopN,
+ aggTopNMerge,
+ aggTopNFinalize,
+ aggBottomN,
+ aggBottomNMerge,
+ aggBottomNFinalize,
};
std::string builtinToString(Builtin b);
@@ -773,6 +780,64 @@ std::string builtinToString(Builtin b);
enum class AggMultiElems { kInternalArr, kMaxSize, kMemUsage, kMemLimit, kSizeOfArray };
/**
+ * Less than comparison based on a sort pattern.
+ */
+struct SortPatternLess {
+    SortPatternLess(const value::SortSpec* sortSpec) : _sortSpec(sortSpec) {}
+
+    // Returns true when the SortSpec comparison of 'lhs' against 'rhs' yields a negative number.
+    bool operator()(const std::pair<value::TypeTags, value::Value>& lhs,
+                    const std::pair<value::TypeTags, value::Value>& rhs) const {
+        auto [resultTag, resultVal] =
+            _sortSpec->compare(lhs.first, lhs.second, rhs.first, rhs.second);
+        uassert(5807000, "Invalid comparison result", resultTag == value::TypeTags::NumberInt32);
+
+        const auto ordering = value::bitcastTo<int32_t>(resultVal);
+        return ordering < 0;
+    }
+
+private:
+    // Not owned; the pointer is stored as given.
+    const value::SortSpec* _sortSpec;
+};
+
+/**
+ * Greater than comparison based on a sort pattern.
+ */
+struct SortPatternGreater {
+    SortPatternGreater(const value::SortSpec* sortSpec) : _sortSpec(sortSpec) {}
+
+    // Returns true when the SortSpec comparison of 'lhs' against 'rhs' yields a positive number.
+    bool operator()(const std::pair<value::TypeTags, value::Value>& lhs,
+                    const std::pair<value::TypeTags, value::Value>& rhs) const {
+        auto [resultTag, resultVal] =
+            _sortSpec->compare(lhs.first, lhs.second, rhs.first, rhs.second);
+        uassert(5807001, "Invalid comparison result", resultTag == value::TypeTags::NumberInt32);
+
+        const auto ordering = value::bitcastTo<int32_t>(resultVal);
+        return ordering > 0;
+    }
+
+private:
+    // Not owned; the pointer is stored as given.
+    const value::SortSpec* _sortSpec;
+};
+
+/**
+ * Comparison based on the key of a pair of elements.
+ */
+template <typename Comp>
+struct PairKeyComp {
+    PairKeyComp(const Comp& comp) : _comp(comp) {}
+
+    // Compares two array values by applying the wrapped comparator to element 0 of each array.
+    bool operator()(const std::pair<value::TypeTags, value::Value>& lhs,
+                    const std::pair<value::TypeTags, value::Value>& rhs) const {
+        auto keyOf = [](const std::pair<value::TypeTags, value::Value>& entry) {
+            return value::getArrayView(entry.second)->getAt(0);
+        };
+
+        auto leftKey = keyOf(lhs);
+        auto rightKey = keyOf(rhs);
+        return _comp(leftKey, rightKey);
+    }
+
+private:
+    // Copy of the comparator supplied at construction.
+    const Comp _comp;
+};
+
+/**
* This enum defines indices into an 'Array' that returns the partial sum result when 'needsMerge'
* is requested.
*
@@ -1586,13 +1651,17 @@ private:
FastTuple<bool, value::TypeTags, value::Value> builtinISOWeekYear(ArityType arity);
FastTuple<bool, value::TypeTags, value::Value> builtinISODayOfWeek(ArityType arity);
FastTuple<bool, value::TypeTags, value::Value> builtinISOWeek(ArityType arity);
-
FastTuple<bool, value::TypeTags, value::Value> builtinObjectToArray(ArityType arity);
FastTuple<bool, value::TypeTags, value::Value> builtinArrayToObject(ArityType arity);
FastTuple<bool, value::TypeTags, value::Value> builtinAggFirstN(ArityType arity);
FastTuple<bool, value::TypeTags, value::Value> builtinAggFirstNMerge(ArityType arity);
FastTuple<bool, value::TypeTags, value::Value> builtinAggFirstNFinalize(ArityType arity);
+ template <typename Less>
+ FastTuple<bool, value::TypeTags, value::Value> builtinAggTopBottomN(ArityType arity);
+ template <typename Less>
+ FastTuple<bool, value::TypeTags, value::Value> builtinAggTopBottomNMerge(ArityType arity);
+ FastTuple<bool, value::TypeTags, value::Value> builtinAggTopBottomNFinalize(ArityType arity);
FastTuple<bool, value::TypeTags, value::Value> dispatchBuiltin(Builtin f, ArityType arity);
diff --git a/src/mongo/db/pipeline/accumulator_multi.cpp b/src/mongo/db/pipeline/accumulator_multi.cpp
index b4e82dea93c..633304c4903 100644
--- a/src/mongo/db/pipeline/accumulator_multi.cpp
+++ b/src/mongo/db/pipeline/accumulator_multi.cpp
@@ -489,12 +489,13 @@ Document AccumulatorTopBottomN<sense, single>::serialize(
}
template <TopBottomSense sense>
-std::pair<SortPattern, BSONArray> parseAccumulatorTopBottomNSortBy(ExpressionContext* const expCtx,
- BSONObj sortBy) {
+std::tuple<SortPattern, BSONArray, bool> parseAccumulatorTopBottomNSortBy(
+ ExpressionContext* const expCtx, BSONObj sortBy) {
SortPattern sortPattern(sortBy, expCtx);
BSONArrayBuilder sortFieldsExpBab;
BSONObjIterator sortByBoi(sortBy);
+ bool hasMeta = false;
for (const auto& part : sortPattern) {
const auto fieldName = sortByBoi.next().fieldNameStringData();
if (part.expression) {
@@ -505,21 +506,26 @@ std::pair<SortPattern, BSONArray> parseAccumulatorTopBottomNSortBy(ExpressionCon
// sortFields array contains the data we need for sorting.
const auto serialized = part.expression->serialize(false);
sortFieldsExpBab.append(serialized.getDocument().toBson());
+ hasMeta = true;
} else {
sortFieldsExpBab.append((StringBuilder() << "$" << fieldName).str());
}
}
- return {sortPattern, sortFieldsExpBab.arr()};
+ return {sortPattern, sortFieldsExpBab.arr(), hasMeta};
}
template <TopBottomSense sense, bool single>
AccumulationExpression AccumulatorTopBottomN<sense, single>::parseTopBottomN(
ExpressionContext* const expCtx, BSONElement elem, VariablesParseState vps) {
- expCtx->sbeGroupCompatibility = SbeCompatibility::notCompatible;
auto name = AccumulatorTopBottomN<sense, single>::getName();
const auto [n, output, sortBy] =
accumulatorNParseArgs<single>(expCtx, elem, name.rawData(), true, vps);
- auto [sortPattern, sortFieldsExp] = parseAccumulatorTopBottomNSortBy<sense>(expCtx, *sortBy);
+ auto [sortPattern, sortFieldsExp, hasMeta] =
+ parseAccumulatorTopBottomNSortBy<sense>(expCtx, *sortBy);
+
+ auto sbeCompatibility =
+ hasMeta ? SbeCompatibility::notCompatible : SbeCompatibility::flagGuarded;
+ expCtx->sbeGroupCompatibility = std::min(expCtx->sbeGroupCompatibility, sbeCompatibility);
// Construct argument expression. If given sortBy: {field1: 1, field2: 1} it will be shaped like
// {output: <output expression>, sortFields: ["$field1", "$field2"]}. This projects out only the
@@ -539,7 +545,7 @@ template <TopBottomSense sense, bool single>
boost::intrusive_ptr<AccumulatorState> AccumulatorTopBottomN<sense, single>::create(
ExpressionContext* expCtx, BSONObj sortBy, bool isRemovable) {
return make_intrusive<AccumulatorTopBottomN<sense, single>>(
- expCtx, parseAccumulatorTopBottomNSortBy<sense>(expCtx, sortBy).first, isRemovable);
+ expCtx, std::get<0>(parseAccumulatorTopBottomNSortBy<sense>(expCtx, sortBy)), isRemovable);
}
template <TopBottomSense sense, bool single>
diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp
index e823ef7f428..34c6f858fa3 100644
--- a/src/mongo/db/query/sbe_stage_builder.cpp
+++ b/src/mongo/db/query/sbe_stage_builder.cpp
@@ -55,6 +55,7 @@
#include "mongo/db/matcher/expression_leaf.h"
#include "mongo/db/matcher/match_expression_dependencies.h"
#include "mongo/db/pipeline/abt/field_map_builder.h"
+#include "mongo/db/pipeline/accumulator_multi.h"
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/pipeline/expression_visitor.h"
#include "mongo/db/query/bind_input_params.h"
@@ -2305,6 +2306,45 @@ std::tuple<sbe::value::SlotVector, EvalStage, std::unique_ptr<sbe::EExpression>>
}
}
+/**
+ * Builds a constant expression of type 'sortSpec' from the sort pattern of the given
+ * AccumulatorTopBottomN instance. Trips a tassert when 'acc' is null.
+ */
+template <TopBottomSense sense, bool single>
+std::unique_ptr<sbe::EExpression> getSortSpecFromTopBottomN(
+    const AccumulatorTopBottomN<sense, single>* acc) {
+    tassert(5807013, "Accumulator state must not be null", acc);
+
+    auto patternBson =
+        acc->getSortPattern().serialize(SortPattern::SortKeySerialization::kForExplain).toBson();
+
+    // The raw pointer is released into the constant expression created below.
+    auto* rawSortSpec = std::make_unique<sbe::value::SortSpec>(patternBson).release();
+    return makeConstant(sbe::value::TypeTags::sortSpec,
+                        sbe::value::bitcastFrom<sbe::value::SortSpec*>(rawSortSpec));
+}
+
+/**
+ * Instantiates the accumulator described by 'accStmt' through its factory, downcasts it to the
+ * AccumulatorTopBottomN instantiation whose name matches the statement, and returns the sortSpec
+ * constant expression built from it. The statement must name one of the four
+ * AccumulatorTopBottomN instantiations.
+ */
+std::unique_ptr<sbe::EExpression> getSortSpecFromTopBottomN(const AccumulationStatement& accStmt) {
+    auto accState = accStmt.expr.factory();
+    const auto& accName = accStmt.expr.name;
+
+    if (accName == AccumulatorTopBottomN<kTop, true>::getName()) {
+        return getSortSpecFromTopBottomN(
+            dynamic_cast<AccumulatorTopBottomN<kTop, true>*>(accState.get()));
+    }
+    if (accName == AccumulatorTopBottomN<kBottom, true>::getName()) {
+        return getSortSpecFromTopBottomN(
+            dynamic_cast<AccumulatorTopBottomN<kBottom, true>*>(accState.get()));
+    }
+    if (accName == AccumulatorTopBottomN<kTop, false>::getName()) {
+        return getSortSpecFromTopBottomN(
+            dynamic_cast<AccumulatorTopBottomN<kTop, false>*>(accState.get()));
+    }
+    if (accName == AccumulatorTopBottomN<kBottom, false>::getName()) {
+        return getSortSpecFromTopBottomN(
+            dynamic_cast<AccumulatorTopBottomN<kBottom, false>*>(accState.get()));
+    }
+
+    MONGO_UNREACHABLE;
+}
+
+/**
+ * Returns true when the name of the accumulator in 'accStmt' equals the name of one of the four
+ * AccumulatorTopBottomN instantiations (kTop or kBottom, single or not).
+ */
+bool isTopBottomN(const AccumulationStatement& accStmt) {
+    const auto& accName = accStmt.expr.name;
+
+    if (accName == AccumulatorTopBottomN<kTop, true>::getName()) {
+        return true;
+    }
+    if (accName == AccumulatorTopBottomN<kBottom, true>::getName()) {
+        return true;
+    }
+    if (accName == AccumulatorTopBottomN<kTop, false>::getName()) {
+        return true;
+    }
+    return accName == AccumulatorTopBottomN<kBottom, false>::getName();
+}
+
sbe::value::SlotVector generateAccumulator(
StageBuilderState& state,
const AccumulationStatement& accStmt,
@@ -2313,16 +2353,67 @@ sbe::value::SlotVector generateAccumulator(
sbe::HashAggStage::AggExprVector& aggSlotExprs,
boost::optional<sbe::value::SlotId> initializerRootSlot) {
auto rootSlot = outputs.getIfExists(PlanStageSlots::kResult);
- auto argExpr = generateExpression(state, accStmt.expr.argument.get(), rootSlot, &outputs);
- auto initExpr =
- generateExpression(state, accStmt.expr.initializer.get(), initializerRootSlot, nullptr);
+ auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd);
// One accumulator may be translated to multiple accumulator expressions. For example, The
// $avg will have two accumulators expressions, a sum(..) and a count which is implemented
// as sum(1).
- auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd);
- auto accExprs = stage_builder::buildAccumulator(
- accStmt, argExpr.extractExpr(state), collatorSlot, *state.frameIdGenerator);
+ auto accExprs = [&]() {
+ // $topN/$bottomN accumulators require multiple arguments to the accumulator builder.
+ if (isTopBottomN(accStmt)) {
+ StringDataMap<std::unique_ptr<sbe::EExpression>> accArgs;
+ auto sortSpecExpr = getSortSpecFromTopBottomN(accStmt);
+ accArgs.emplace(AccArgs::kTopBottomNSortSpec, sortSpecExpr->clone());
+
+ // Build the key expression for the accumulator.
+ tassert(5807014,
+ str::stream() << accStmt.expr.name
+ << " accumulator must have the root slot set",
+ rootSlot);
+ auto key = collatorSlot ? makeFunction("generateCheapSortKey",
+ std::move(sortSpecExpr),
+ makeVariable(*rootSlot),
+ makeVariable(*collatorSlot))
+ : makeFunction("generateCheapSortKey",
+ std::move(sortSpecExpr),
+ makeVariable(*rootSlot));
+ accArgs.emplace(AccArgs::kTopBottomNKey,
+ makeFunction("sortKeyComponentVectorToArray", std::move(key)));
+
+ // Build the value expression for the accumulator.
+ auto expObj = dynamic_cast<ExpressionObject*>(accStmt.expr.argument.get());
+ tassert(5807015,
+ str::stream() << accStmt.expr.name
+ << " accumulator must have an object argument",
+ expObj);
+ for (auto& [key, value] : expObj->getChildExpressions()) {
+ if (key == AccumulatorN::kFieldNameOutput) {
+ auto outputExpr = generateExpression(state, value.get(), rootSlot, &outputs);
+ accArgs.emplace(AccArgs::kTopBottomNValue,
+ makeFillEmptyNull(outputExpr.extractExpr(state)));
+ break;
+ }
+ }
+ tassert(5807016,
+ str::stream() << accStmt.expr.name
+ << " accumulator must have an output field in the argument",
+ accArgs.find(AccArgs::kTopBottomNValue) != accArgs.end());
+
+ auto accExprs = stage_builder::buildAccumulator(
+ accStmt, std::move(accArgs), collatorSlot, *state.frameIdGenerator);
+
+ return accExprs;
+ } else {
+ auto argExpr =
+ generateExpression(state, accStmt.expr.argument.get(), rootSlot, &outputs);
+ auto accExprs = stage_builder::buildAccumulator(
+ accStmt, argExpr.extractExpr(state), collatorSlot, *state.frameIdGenerator);
+ return accExprs;
+ }
+ }();
+
+ auto initExpr =
+ generateExpression(state, accStmt.expr.initializer.get(), initializerRootSlot, nullptr);
auto accInitExprs = stage_builder::buildInitialize(
accStmt, initExpr.extractExpr(state), *state.frameIdGenerator);
@@ -2362,8 +2453,18 @@ sbe::SlotExprPairVector generateMergingExpressions(StageBuilderState& state,
auto spillSlots = slotIdGenerator->generateMultiple(numInputSlots);
auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd);
- auto mergingExprs =
- buildCombinePartialAggregates(accStmt, spillSlots, collatorSlot, *frameIdGenerator);
+
+ auto mergingExprs = [&]() {
+ if (isTopBottomN(accStmt)) {
+ StringDataMap<std::unique_ptr<sbe::EExpression>> mergeArgs;
+ mergeArgs.emplace(AccArgs::kTopBottomNSortSpec, getSortSpecFromTopBottomN(accStmt));
+ return buildCombinePartialAggregates(
+ accStmt, spillSlots, std::move(mergeArgs), collatorSlot, *frameIdGenerator);
+ } else {
+ return buildCombinePartialAggregates(
+ accStmt, spillSlots, collatorSlot, *frameIdGenerator);
+ }
+ }();
// Zip the slot vector and expression vector into a vector of pairs.
tassert(7039550,
@@ -2409,13 +2510,33 @@ std::tuple<std::vector<std::string>, sbe::value::SlotVector, EvalStage> generate
}
}();
+ auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd);
auto finalSlots{sbe::value::SlotVector{finalGroupBySlot}};
std::vector<std::string> fieldNames{"_id"};
size_t idxAccFirstSlot = dedupedGroupBySlots.size();
for (size_t idxAcc = 0; idxAcc < accStmts.size(); ++idxAcc) {
// Gathers field names for the output object from accumulator statements.
fieldNames.push_back(accStmts[idxAcc].fieldName);
- auto finalExpr = stage_builder::buildFinalize(state, accStmts[idxAcc], aggSlotsVec[idxAcc]);
+
+ auto finalExpr = [&]() {
+ const auto& accStmt = accStmts[idxAcc];
+ if (isTopBottomN(accStmt)) {
+ StringDataMap<std::unique_ptr<sbe::EExpression>> finalArgs;
+ finalArgs.emplace(AccArgs::kTopBottomNSortSpec, getSortSpecFromTopBottomN(accStmt));
+ return buildFinalize(state,
+ accStmts[idxAcc],
+ aggSlotsVec[idxAcc],
+ std::move(finalArgs),
+ collatorSlot,
+ *state.frameIdGenerator);
+ } else {
+ return buildFinalize(state,
+ accStmts[idxAcc],
+ aggSlotsVec[idxAcc],
+ collatorSlot,
+ *state.frameIdGenerator);
+ }
+ }();
// The final step may not return an expression if it's trivial. For example, $first and
// $last's final steps are trivial.
@@ -2512,6 +2633,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
if (!groupNode->needWholeDocument) {
// Tracks whether we need to request kResult.
bool rootDocIsNeeded = false;
+ bool sortKeyIsNeeded = false;
auto referencesRoot = [&](const ExpressionFieldPath* fieldExpr) {
rootDocIsNeeded = rootDocIsNeeded || fieldExpr->isROOT();
};
@@ -2520,23 +2642,29 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder
walkAndActOnFieldPaths(idExpr.get(), referencesRoot);
for (const auto& accStmt : accStmts) {
walkAndActOnFieldPaths(accStmt.expr.argument.get(), referencesRoot);
+ if (isTopBottomN(accStmt)) {
+ sortKeyIsNeeded = true;
+ }
}
- // If the group node doesn't have any dependency (e.g. $count) or if the dependency can be
- // satisfied by the child node (e.g. covered index scan), we can clear the kResult
- // requirement for the child.
- if (groupNode->requiredFields.empty() || !rootDocIsNeeded) {
- childReqs.clear(kResult);
- } else if (childNode->getType() == StageType::STAGE_PROJECTION_COVERED) {
- auto pn = static_cast<const ProjectionNodeCovered*>(childNode);
- std::set<std::string> providedFieldSet;
- for (auto&& elt : pn->coveredKeyObj) {
- providedFieldSet.emplace(elt.fieldNameStringData());
- }
- if (std::all_of(groupNode->requiredFields.begin(),
- groupNode->requiredFields.end(),
- [&](const std::string& f) { return providedFieldSet.count(f); })) {
+ // If any accumulator requires generating sort key, we cannot clear the kResult.
+ if (!sortKeyIsNeeded) {
+ // If the group node doesn't have any dependency (e.g. $count) or if the dependency can
+ // be satisfied by the child node (e.g. covered index scan), we can clear the kResult
+ // requirement for the child.
+ if (groupNode->requiredFields.empty() || !rootDocIsNeeded) {
childReqs.clear(kResult);
+ } else if (childNode->getType() == StageType::STAGE_PROJECTION_COVERED) {
+ auto pn = static_cast<const ProjectionNodeCovered*>(childNode);
+ std::set<std::string> providedFieldSet;
+ for (auto&& elt : pn->coveredKeyObj) {
+ providedFieldSet.emplace(elt.fieldNameStringData());
+ }
+ if (std::all_of(groupNode->requiredFields.begin(),
+ groupNode->requiredFields.end(),
+ [&](const std::string& f) { return providedFieldSet.count(f); })) {
+ childReqs.clear(kResult);
+ }
}
}
}
diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator.cpp
index 5f68a8dc3e0..21f3457f9e7 100644
--- a/src/mongo/db/query/sbe_stage_builder_accumulator.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_accumulator.cpp
@@ -584,6 +584,7 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggsMergeObjec
std::vector<std::unique_ptr<sbe::EExpression>> buildInitializeAccumulatorMulti(
std::unique_ptr<sbe::EExpression> maxSizeExpr, sbe::value::FrameIdGenerator& frameIdGenerator) {
+ // Create an array of four elements [value holder, max size, memory used, memory limit].
std::vector<std::unique_ptr<sbe::EExpression>> aggs;
auto maxAccumulatorBytes = internalQueryTopNAccumulatorBytes.load();
if (auto* maxSizeConstExpr = maxSizeExpr->as<sbe::EConstant>()) {
@@ -663,12 +664,170 @@ std::unique_ptr<sbe::EExpression> buildFinalizeFirstN(StageBuilderState& state,
return makeFunction("aggFirstNFinalize", makeVariable(inputSlots[0]));
}
+/**
+ * Returns true when 'expr' is named after one of the two kTop instantiations of
+ * AccumulatorTopBottomN (single or multi). Any other name, including the kBottom ones, yields
+ * false.
+ */
+bool isAccumulatorTopN(const AccumulationExpression& expr) {
+    const auto& accName = expr.name;
+    if (accName == AccumulatorTopBottomN<kTop, false /* single */>::getName()) {
+        return true;
+    }
+    return accName == AccumulatorTopBottomN<kTop, true /* single */>::getName();
+}
+
+/**
+ * Builds the accumulate step shared by the top/bottom accumulator family: a single call to the
+ * "aggTopN" or "aggBottomN" builtin taking the key, value and sort spec expressions.
+ *
+ * 'args' must contain entries for AccArgs::kTopBottomNKey, AccArgs::kTopBottomNValue and
+ * AccArgs::kTopBottomNSortSpec; a missing entry trips a tassert. 'collatorSlot' and
+ * 'frameIdGenerator' are not used here.
+ */
+std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorTopBottomN(
+    const AccumulationExpression& expr,
+    StringDataMap<std::unique_ptr<sbe::EExpression>> args,
+    boost::optional<sbe::value::SlotId> collatorSlot,
+    sbe::value::FrameIdGenerator& frameIdGenerator) {
+    // Validate the three required arguments in the order key, value, sort spec.
+    auto keyIt = args.find(AccArgs::kTopBottomNKey);
+    tassert(5807009,
+            str::stream() << "Accumulator " << expr.name << " expects a '"
+                          << AccArgs::kTopBottomNKey << "' argument",
+            keyIt != args.end());
+
+    auto valueIt = args.find(AccArgs::kTopBottomNValue);
+    tassert(5807010,
+            str::stream() << "Accumulator " << expr.name << " expects a '"
+                          << AccArgs::kTopBottomNValue << "' argument",
+            valueIt != args.end());
+
+    auto sortSpecIt = args.find(AccArgs::kTopBottomNSortSpec);
+    tassert(5807021,
+            str::stream() << "Accumulator " << expr.name << " expects a '"
+                          << AccArgs::kTopBottomNSortSpec << "' argument",
+            sortSpecIt != args.end());
+
+    const char* builtinName = isAccumulatorTopN(expr) ? "aggTopN" : "aggBottomN";
+
+    // The argument expressions are moved straight out of the map into the function call.
+    std::vector<std::unique_ptr<sbe::EExpression>> result;
+    result.push_back(makeFunction(builtinName,
+                                  std::move(keyIt->second),
+                                  std::move(valueIt->second),
+                                  std::move(sortSpecIt->second)));
+    return result;
+}
+
+/**
+ * Builds the expression that merges a partial top/bottom aggregate into the running one: a call to
+ * "aggTopNMerge" or "aggBottomNMerge" with the single input slot and the sort spec.
+ *
+ * Trips a tassert unless 'inputSlots' has exactly one element and 'args' contains
+ * AccArgs::kTopBottomNSortSpec. 'collatorSlot' and 'frameIdGenerator' are not used here.
+ */
+std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialTopBottomN(
+    const AccumulationExpression& expr,
+    const sbe::value::SlotVector& inputSlots,
+    StringDataMap<std::unique_ptr<sbe::EExpression>> args,
+    boost::optional<sbe::value::SlotId> collatorSlot,
+    sbe::value::FrameIdGenerator& frameIdGenerator) {
+    tassert(5807011,
+            str::stream() << "Expected one input slot for merging " << expr.name
+                          << ", got: " << inputSlots.size(),
+            inputSlots.size() == 1);
+
+    auto sortSpecIt = args.find(AccArgs::kTopBottomNSortSpec);
+    tassert(5807022,
+            str::stream() << "Accumulator " << expr.name << " expects a '"
+                          << AccArgs::kTopBottomNSortSpec << "' argument",
+            sortSpecIt != args.end());
+
+    const char* mergeBuiltin = isAccumulatorTopN(expr) ? "aggTopNMerge" : "aggBottomNMerge";
+
+    std::vector<std::unique_ptr<sbe::EExpression>> result;
+    result.push_back(makeFunction(
+        mergeBuiltin, makeVariable(inputSlots[0]), std::move(sortSpecIt->second)));
+    return result;
+}
+
+/**
+ * Shared finalization logic for the top/bottom accumulator family.
+ *
+ * - When 'state.needsMerge' is false, the result is a call to "aggTopNFinalize" or
+ *   "aggBottomNFinalize" on the single input slot and the sort spec; if 'single' is true, element
+ *   0 of that result is extracted with "getElement".
+ * - When 'state.needsMerge' is true, the internal array of the accumulator state is traversed and
+ *   every [key, output] pair is turned into an object with the fields
+ *   AccumulatorN::kFieldNameGeneratedSortKey and AccumulatorN::kFieldNameOutput. 'single' is not
+ *   consulted on this path.
+ *
+ * Trips a tassert unless 'inputSlots' has exactly one element and 'args' contains
+ * AccArgs::kTopBottomNSortSpec. 'collatorSlot' is not used here.
+ */
+std::unique_ptr<sbe::EExpression> buildFinalizeTopBottomNImpl(
+    StageBuilderState& state,
+    const AccumulationExpression& expr,
+    const sbe::value::SlotVector& inputSlots,
+    StringDataMap<std::unique_ptr<sbe::EExpression>> args,
+    boost::optional<sbe::value::SlotId> collatorSlot,
+    sbe::value::FrameIdGenerator& frameIdGenerator,
+    bool single) {
+    tassert(5807012,
+            str::stream() << "Expected one input slot for finalization of " << expr.name
+                          << ", got: " << inputSlots.size(),
+            inputSlots.size() == 1);
+    const auto stateSlot = inputSlots[0];
+
+    auto sortSpecIt = args.find(AccArgs::kTopBottomNSortSpec);
+    tassert(5807023,
+            str::stream() << "Accumulator " << expr.name << " expects a '"
+                          << AccArgs::kTopBottomNSortSpec << "' argument",
+            sortSpecIt != args.end());
+
+    if (!state.needsMerge) {
+        // No downstream merge: let the builtin produce the final output array.
+        auto finalized =
+            makeFunction(isAccumulatorTopN(expr) ? "aggTopNFinalize" : "aggBottomNFinalize",
+                         makeVariable(stateSlot),
+                         std::move(sortSpecIt->second));
+        if (!single) {
+            return finalized;
+        }
+        // The single-value flavours return the first (only) element rather than an array.
+        return makeFunction("getElement",
+                            std::move(finalized),
+                            makeConstant(sbe::value::TypeTags::NumberInt32, 0));
+    }
+
+    // When the data will be merged, the heap itself doesn't need to be sorted since the merging
+    // code will handle the sorting.
+    auto heapExpr =
+        makeFunction("getElement",
+                     makeVariable(stateSlot),
+                     makeConstant(sbe::value::TypeTags::NumberInt32,
+                                  static_cast<int>(sbe::vm::AggMultiElems::kInternalArr)));
+
+    const auto lambdaFrameId = frameIdGenerator.generate();
+    // Produces 'getElement(<lambda parameter>, idx)' for the pair currently being visited.
+    auto pairElement = [&](int idx) {
+        return makeFunction("getElement",
+                            makeVariable(lambdaFrameId, 0),
+                            makeConstant(sbe::value::TypeTags::NumberInt32, idx));
+    };
+
+    // Convert the array pair representation [key, output] to an object format that the merging
+    // code expects.
+    auto pairToObject = sbe::makeE<sbe::ELocalLambda>(
+        lambdaFrameId,
+        makeNewObjFunction(FieldPair{AccumulatorN::kFieldNameGeneratedSortKey, pairElement(0)},
+                           FieldPair{AccumulatorN::kFieldNameOutput, pairElement(1)}));
+
+    return makeFunction("traverseP",
+                        std::move(heapExpr),
+                        std::move(pairToObject),
+                        makeConstant(sbe::value::TypeTags::NumberInt32, 1));
+}
+
+/**
+ * Finalizer for the array-returning flavours: forwards every argument to
+ * buildFinalizeTopBottomNImpl() with 'single' set to false.
+ */
+std::unique_ptr<sbe::EExpression> buildFinalizeTopBottomN(
+    StageBuilderState& state,
+    const AccumulationExpression& expr,
+    const sbe::value::SlotVector& inputSlots,
+    StringDataMap<std::unique_ptr<sbe::EExpression>> args,
+    boost::optional<sbe::value::SlotId> collatorSlot,
+    sbe::value::FrameIdGenerator& frameIdGenerator) {
+    constexpr bool kSingle = false;
+    return buildFinalizeTopBottomNImpl(
+        state, expr, inputSlots, std::move(args), collatorSlot, frameIdGenerator, kSingle);
+}
+
+/**
+ * Finalizer for the single-value flavours: forwards every argument to
+ * buildFinalizeTopBottomNImpl() with 'single' set to true.
+ */
+std::unique_ptr<sbe::EExpression> buildFinalizeTopBottom(
+    StageBuilderState& state,
+    const AccumulationExpression& expr,
+    const sbe::value::SlotVector& inputSlots,
+    StringDataMap<std::unique_ptr<sbe::EExpression>> args,
+    boost::optional<sbe::value::SlotId> collatorSlot,
+    sbe::value::FrameIdGenerator& frameIdGenerator) {
+    constexpr bool kSingle = true;
+    return buildFinalizeTopBottomNImpl(
+        state, expr, inputSlots, std::move(args), collatorSlot, frameIdGenerator, kSingle);
+}
+
// Initializer for accumulators that need no initialization expression: returns a vector of N
// null expression pointers. Both parameters are ignored.
template <int N>
std::vector<std::unique_ptr<sbe::EExpression>> emptyInitializer(
    std::unique_ptr<sbe::EExpression> maxSizeExpr, sbe::value::FrameIdGenerator& frameIdGenerator) {
    std::vector<std::unique_ptr<sbe::EExpression>> noInitExprs(N);
    return noInitExprs;
}
-}; // namespace
+} // namespace
std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator(
const AccumulationStatement& acc,
@@ -708,6 +867,37 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator(
frameIdGenerator);
}
+// Multi-argument overload: dispatches on the accumulator name to the matching builder and hands
+// it the named argument expressions. Only the four AccumulatorTopBottomN instantiations are
+// registered; any other name fails the uassert below.
+std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator(
+    const AccumulationStatement& acc,
+    StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs,
+    boost::optional<sbe::value::SlotId> collatorSlot,
+    sbe::value::FrameIdGenerator& frameIdGenerator) {
+    using BuildAccumulatorFn = std::function<std::vector<std::unique_ptr<sbe::EExpression>>(
+        const AccumulationExpression&,
+        StringDataMap<std::unique_ptr<sbe::EExpression>>,
+        boost::optional<sbe::value::SlotId>,
+        sbe::value::FrameIdGenerator&)>;
+
+    // Every flavour shares one builder; it picks the builtin from the accumulator name.
+    static const StringDataMap<BuildAccumulatorFn> kAccumulatorBuilders = {
+        {AccumulatorTopBottomN<kTop, true /* single */>::getName(), &buildAccumulatorTopBottomN},
+        {AccumulatorTopBottomN<kBottom, true /* single */>::getName(), &buildAccumulatorTopBottomN},
+        {AccumulatorTopBottomN<kTop, false /* single */>::getName(), &buildAccumulatorTopBottomN},
+        {AccumulatorTopBottomN<kBottom, false /* single */>::getName(),
+         &buildAccumulatorTopBottomN},
+    };
+
+    const auto& accExprName = acc.expr.name;
+    const auto builderIt = kAccumulatorBuilders.find(accExprName);
+    uassert(5807017,
+            str::stream() << "Unsupported Accumulator in SBE accumulator builder: " << accExprName,
+            builderIt != kAccumulatorBuilders.end());
+
+    return builderIt->second(acc.expr, std::move(argExprs), collatorSlot, frameIdGenerator);
+}
+
std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates(
const AccumulationStatement& acc,
const sbe::value::SlotVector& inputSlots,
@@ -743,9 +933,47 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates(
kAggCombinerBuilders.at(accExprName), acc.expr, inputSlots, collatorSlot, frameIdGenerator);
}
+// Multi-argument overload: dispatches on the accumulator name to the matching partial-aggregate
+// combiner and hands it the named argument expressions. Only the four AccumulatorTopBottomN
+// instantiations are registered; any other name fails the uassert below.
+std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates(
+    const AccumulationStatement& acc,
+    const sbe::value::SlotVector& inputSlots,
+    StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs,
+    boost::optional<sbe::value::SlotId> collatorSlot,
+    sbe::value::FrameIdGenerator& frameIdGenerator) {
+    using BuildAggCombinerFn = std::function<std::vector<std::unique_ptr<sbe::EExpression>>(
+        const AccumulationExpression&,
+        const sbe::value::SlotVector&,
+        StringDataMap<std::unique_ptr<sbe::EExpression>>,
+        boost::optional<sbe::value::SlotId>,
+        sbe::value::FrameIdGenerator&)>;
+
+    // Every flavour shares one combiner; it picks the merge builtin from the accumulator name.
+    static const StringDataMap<BuildAggCombinerFn> kAggCombinerBuilders = {
+        {AccumulatorTopBottomN<kTop, true /* single */>::getName(), &buildCombinePartialTopBottomN},
+        {AccumulatorTopBottomN<kBottom, true /* single */>::getName(),
+         &buildCombinePartialTopBottomN},
+        {AccumulatorTopBottomN<kTop, false /* single */>::getName(),
+         &buildCombinePartialTopBottomN},
+        {AccumulatorTopBottomN<kBottom, false /* single */>::getName(),
+         &buildCombinePartialTopBottomN},
+    };
+
+    const auto& accExprName = acc.expr.name;
+    const auto combinerIt = kAggCombinerBuilders.find(accExprName);
+    uassert(5807019,
+            str::stream() << "Unsupported Accumulator in SBE accumulator builder: " << accExprName,
+            combinerIt != kAggCombinerBuilders.end());
+
+    return combinerIt->second(
+        acc.expr, inputSlots, std::move(argExprs), collatorSlot, frameIdGenerator);
+}
+
std::unique_ptr<sbe::EExpression> buildFinalize(StageBuilderState& state,
const AccumulationStatement& acc,
- const sbe::value::SlotVector& aggSlots) {
+ const sbe::value::SlotVector& aggSlots,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator) {
using BuildFinalizeFn = std::function<std::unique_ptr<sbe::EExpression>(
StageBuilderState&, const AccumulationExpression&, sbe::value::SlotVector)>;
@@ -777,6 +1005,42 @@ std::unique_ptr<sbe::EExpression> buildFinalize(StageBuilderState& state,
}
}
+// Multi-argument overload: dispatches on the accumulator name to the matching finalizer and hands
+// it the named argument expressions. The single-value flavours map to buildFinalizeTopBottom and
+// the array-returning flavours to buildFinalizeTopBottomN; any other name fails the uassert
+// below.
+std::unique_ptr<sbe::EExpression> buildFinalize(
+    StageBuilderState& state,
+    const AccumulationStatement& acc,
+    const sbe::value::SlotVector& aggSlots,
+    StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs,
+    boost::optional<sbe::value::SlotId> collatorSlot,
+    sbe::value::FrameIdGenerator& frameIdGenerator) {
+    // The slot vector is passed by const reference, matching the finalizers' own signatures, so
+    // that invoking through std::function does not copy it on every call.
+    using BuildFinalizeFn = std::function<std::unique_ptr<sbe::EExpression>(
+        StageBuilderState&,
+        const AccumulationExpression&,
+        const sbe::value::SlotVector&,
+        StringDataMap<std::unique_ptr<sbe::EExpression>>,
+        boost::optional<sbe::value::SlotId>,
+        sbe::value::FrameIdGenerator&)>;
+
+    static const StringDataMap<BuildFinalizeFn> kAccumulatorFinalizers = {
+        {AccumulatorTopBottomN<kTop, true /* single */>::getName(), &buildFinalizeTopBottom},
+        {AccumulatorTopBottomN<kBottom, true /* single */>::getName(), &buildFinalizeTopBottom},
+        {AccumulatorTopBottomN<kTop, false /* single */>::getName(), &buildFinalizeTopBottomN},
+        {AccumulatorTopBottomN<kBottom, false /* single */>::getName(), &buildFinalizeTopBottomN},
+    };
+
+    const auto& accExprName = acc.expr.name;
+    // Look the finalizer up once and reuse the iterator for both the check and the call.
+    const auto finalizerIt = kAccumulatorFinalizers.find(accExprName);
+    uassert(5807020,
+            str::stream() << "Unsupported Accumulator in SBE accumulator builder: " << accExprName,
+            finalizerIt != kAccumulatorFinalizers.end());
+
+    return std::invoke(finalizerIt->second,
+                       state,
+                       acc.expr,
+                       aggSlots,
+                       std::move(argExprs),
+                       collatorSlot,
+                       frameIdGenerator);
+}
+
std::vector<std::unique_ptr<sbe::EExpression>> buildInitialize(
const AccumulationStatement& acc,
std::unique_ptr<sbe::EExpression> initExpr,
@@ -797,6 +1061,14 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildInitialize(
{AccumulatorStdDevPop::kName, &emptyInitializer<1>},
{AccumulatorStdDevSamp::kName, &emptyInitializer<1>},
{AccumulatorFirstN::kName, &buildInitializeAccumulatorMulti},
+ {AccumulatorTopBottomN<kTop, true /* single */>::getName(),
+ &buildInitializeAccumulatorMulti},
+ {AccumulatorTopBottomN<kBottom, true /* single */>::getName(),
+ &buildInitializeAccumulatorMulti},
+ {AccumulatorTopBottomN<kTop, false /* single */>::getName(),
+ &buildInitializeAccumulatorMulti},
+ {AccumulatorTopBottomN<kBottom, false /* single */>::getName(),
+ &buildInitializeAccumulatorMulti},
};
auto accExprName = acc.expr.name;
diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator.h b/src/mongo/db/query/sbe_stage_builder_accumulator.h
index 58daf6cd866..33d9008302b 100644
--- a/src/mongo/db/query/sbe_stage_builder_accumulator.h
+++ b/src/mongo/db/query/sbe_stage_builder_accumulator.h
@@ -39,6 +39,12 @@
namespace mongo::stage_builder {
class PlanStageSlots;
+/**
+ * Names of the entries in the 'argExprs' map accepted by the multi-argument overloads of
+ * buildAccumulator(), buildCombinePartialAggregates() and buildFinalize() declared in this header.
+ *
+ * Declared 'inline constexpr' so that the header defines one constant-initialized object per name
+ * instead of a separate internal-linkage copy in every translation unit that includes it.
+ */
+namespace AccArgs {
+inline constexpr StringData kTopBottomNSortSpec = "sortSpec"_sd;
+inline constexpr StringData kTopBottomNKey = "key"_sd;
+inline constexpr StringData kTopBottomNValue = "value"_sd;
+}  // namespace AccArgs
+
/**
* Translates an input AccumulationStatement into an SBE EExpression for accumulation expressions.
*/
@@ -49,6 +55,15 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator(
sbe::value::FrameIdGenerator&);
/**
+ * Similar to above but takes multiple arguments. 'argExprs' maps each argument name (see the
+ * 'AccArgs' namespace) to the expression that computes that argument.
+ */
+std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator(
+ const AccumulationStatement& acc,
+ StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator&);
+
+/**
* When SBE hash aggregation spills to disk, it spills partial aggregates which need to be combined
* later. This function returns the expressions that can be used to combine partial aggregates for
* the given accumulator 'acc'. The aggregate-of-aggregates will be stored in a slots owned by the
@@ -62,13 +77,36 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates(
sbe::value::FrameIdGenerator&);
/**
+ * Similar to above but takes multiple arguments. 'argExprs' maps each argument name (see the
+ * 'AccArgs' namespace) to the expression that computes that argument.
+ */
+std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates(
+ const AccumulationStatement& acc,
+ const sbe::value::SlotVector& inputSlots,
+ StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator&);
+
+/**
* Translates an input AccumulationStatement into an SBE EExpression that represents an
* AccumulationStatement's finalization step. The 'aggSlots' parameter lists the slots that hold
* the accumulator's state to be finalized.
*/
std::unique_ptr<sbe::EExpression> buildFinalize(StageBuilderState& state,
const AccumulationStatement& acc,
- const sbe::value::SlotVector& aggSlots);
+ const sbe::value::SlotVector& aggSlots,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator);
+
+/**
+ * Similar to above but takes multiple arguments. 'argExprs' maps each argument name (see the
+ * 'AccArgs' namespace) to the expression that computes that argument.
+ */
+std::unique_ptr<sbe::EExpression> buildFinalize(
+ StageBuilderState& state,
+ const AccumulationStatement& acc,
+ const sbe::value::SlotVector& aggSlots,
+ StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs,
+ boost::optional<sbe::value::SlotId> collatorSlot,
+ sbe::value::FrameIdGenerator& frameIdGenerator);
/**
* Translates an input AccumulationStatement into an SBE EExpression for the initialization of the
diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp
index d5b52d3489e..b4ca88e919a 100644
--- a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp
@@ -33,6 +33,7 @@
#include "mongo/bson/bsonobj.h"
#include "mongo/db/exec/sbe/expression_test_base.h"
+#include "mongo/db/exec/sbe/vm/vm.h"
#include "mongo/db/pipeline/document_source_group.h"
#include "mongo/db/pipeline/expression_context_for_test.h"
#include "mongo/db/query/collation/collator_interface_mock.h"
@@ -1678,6 +1679,179 @@ TEST_F(SbeStageBuilderGroupTest, FirstNAccumulatorInvalidDynamicN) {
static_cast<ErrorCodes::Error>(7548607));
}
+// All four accumulators over a single group (_id: null) of four documents sorted by 's'.
+TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorSingleGroup) {
+    auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 44 << "s" << 4)),
+                                       BSON_ARRAY(BSON("a" << 33 << "s" << 3)),
+                                       BSON_ARRAY(BSON("a" << 22 << "s" << 2)),
+                                       BSON_ARRAY(BSON("a" << 11 << "s" << 1))};
+
+    // Each entry pairs a $group specification with the documents it is expected to produce.
+    const std::vector<std::pair<const char*, BSONArray>> cases{
+        {"{_id: null, x: {$top: {output: '$a', sortBy: {s: 1}}}}",
+         BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 11))},
+        {"{_id: null, x: {$bottom: {output: '$a', sortBy: {s: 1}}}}",
+         BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 44))},
+        {"{_id: null, x: {$topN: {output: '$a', sortBy: {s: 1}, n: 3}}}",
+         BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(11 << 22 << 33)))},
+        {"{_id: null, x: {$bottomN: {output: '$a', sortBy: {s: 1}, n: 3}}}",
+         BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(22 << 33 << 44)))},
+    };
+
+    for (const auto& [groupSpec, expected] : cases) {
+        runGroupAggregationTest(groupSpec, docs, expected);
+    }
+}
+
+// Sort pattern with two fields: ascending on 's1', then descending on 's2'.
+TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorCompoundSort) {
+    auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 11 << "s1" << 1 << "s2" << 1)),
+                                       BSON_ARRAY(BSON("a" << 12 << "s1" << 1 << "s2" << 2)),
+                                       BSON_ARRAY(BSON("a" << 21 << "s1" << 2 << "s2" << 1)),
+                                       BSON_ARRAY(BSON("a" << 22 << "s1" << 2 << "s2" << 2))};
+
+    const auto expectedTop = BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 12));
+    const auto expectedBottom = BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 21));
+    const auto expectedTopN =
+        BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(12 << 11 << 22)));
+    const auto expectedBottomN =
+        BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(11 << 22 << 21)));
+
+    runGroupAggregationTest(
+        "{_id: null, x: {$top: {output: '$a', sortBy: {s1: 1, s2: -1}}}}", docs, expectedTop);
+    runGroupAggregationTest("{_id: null, x: {$bottom: {output: '$a', sortBy: {s1: 1, s2: -1}}}}",
+                            docs,
+                            expectedBottom);
+    runGroupAggregationTest(
+        "{_id: null, x: {$topN: {output: '$a', sortBy: {s1: 1, s2: -1}, n: 3}}}",
+        docs,
+        expectedTopN);
+    runGroupAggregationTest(
+        "{_id: null, x: {$bottomN: {output: '$a', sortBy: {s1: 1, s2: -1}, n: 3}}}",
+        docs,
+        expectedBottomN);
+}
+
+// String sort keys compared under the kReverseString mock collator.
+TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorCollation) {
+    auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 41 << "s"
+                                                           << "41")),
+                                       BSON_ARRAY(BSON("a" << 32 << "s"
+                                                           << "32")),
+                                       BSON_ARRAY(BSON("a" << 23 << "s"
+                                                           << "23")),
+                                       BSON_ARRAY(BSON("a" << 14 << "s"
+                                                           << "14"))};
+
+    // A fresh collator instance is handed to every run.
+    auto makeReverseCollator = [] {
+        return std::make_unique<CollatorInterfaceMock>(
+            CollatorInterfaceMock::MockType::kReverseString);
+    };
+
+    const std::vector<std::pair<const char*, BSONArray>> cases{
+        {"{_id: null, x: {$top: {output: '$a', sortBy: {s: 1}}}}",
+         BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 41))},
+        {"{_id: null, x: {$bottom: {output: '$a', sortBy: {s: 1}}}}",
+         BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 14))},
+        {"{_id: null, x: {$topN: {output: '$a', sortBy: {s: 1}, n: 3}}}",
+         BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(41 << 32 << 23)))},
+        {"{_id: null, x: {$bottomN: {output: '$a', sortBy: {s: 1}, n: 3}}}",
+         BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(32 << 23 << 14)))},
+    };
+
+    for (const auto& [groupSpec, expected] : cases) {
+        runGroupAggregationTest(groupSpec, docs, expected, makeReverseCollator());
+    }
+}
+
+// Fewer input documents (2) than the requested n (3): both accumulators return everything.
+TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorNotEnoughElement) {
+    auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 22 << "s" << 2)),
+                                       BSON_ARRAY(BSON("a" << 11 << "s" << 1))};
+
+    // The expected output is the same for $topN and $bottomN.
+    const auto expectedBoth = BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(11 << 22)));
+
+    runGroupAggregationTest(
+        "{_id: null, x: {$topN: {output: '$a', sortBy: {s: 1}, n: 3}}}", docs, expectedBoth);
+    runGroupAggregationTest(
+        "{_id: null, x: {$bottomN: {output: '$a', sortBy: {s: 1}, n: 3}}}", docs, expectedBoth);
+}
+
+// Two groups keyed by 'n' (1 and 2), four documents each.
+TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorMultiGroup) {
+    auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 44 << "s" << 4 << "n" << 1)),
+                                       BSON_ARRAY(BSON("a" << 33 << "s" << 3 << "n" << 1)),
+                                       BSON_ARRAY(BSON("a" << 22 << "s" << 2 << "n" << 1)),
+                                       BSON_ARRAY(BSON("a" << 11 << "s" << 1 << "n" << 1)),
+                                       BSON_ARRAY(BSON("a" << 88 << "s" << 8 << "n" << 2)),
+                                       BSON_ARRAY(BSON("a" << 77 << "s" << 7 << "n" << 2)),
+                                       BSON_ARRAY(BSON("a" << 66 << "s" << 6 << "n" << 2)),
+                                       BSON_ARRAY(BSON("a" << 55 << "s" << 5 << "n" << 2))};
+
+    const std::vector<std::pair<const char*, BSONArray>> cases{
+        {"{_id: '$n', x: {$top: {output: '$a', sortBy: {s: 1}}}}",
+         BSON_ARRAY(BSON("_id" << 1 << "x" << 11) << BSON("_id" << 2 << "x" << 55))},
+        {"{_id: '$n', x: {$bottom: {output: '$a', sortBy: {s: 1}}}}",
+         BSON_ARRAY(BSON("_id" << 1 << "x" << 44) << BSON("_id" << 2 << "x" << 88))},
+        {"{_id: '$n', x: {$topN: {output: '$a', sortBy: {s: 1}, n: 3}}}",
+         BSON_ARRAY(BSON("_id" << 1 << "x" << BSON_ARRAY(11 << 22 << 33))
+                    << BSON("_id" << 2 << "x" << BSON_ARRAY(55 << 66 << 77)))},
+        {"{_id: '$n', x: {$bottomN: {output: '$a', sortBy: {s: 1}, n: 3}}}",
+         BSON_ARRAY(BSON("_id" << 1 << "x" << BSON_ARRAY(22 << 33 << 44))
+                    << BSON("_id" << 2 << "x" << BSON_ARRAY(66 << 77 << 88)))},
+    };
+
+    for (const auto& [groupSpec, expected] : cases) {
+        runGroupAggregationTest(groupSpec, docs, expected);
+    }
+}
+
+// 'n' is taken from the group key ('$n1'), so the two groups keep 2 and 3 elements respectively.
+TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorDynamicN) {
+    auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 44 << "s" << 4 << "n" << 2)),
+                                       BSON_ARRAY(BSON("a" << 33 << "s" << 3 << "n" << 2)),
+                                       BSON_ARRAY(BSON("a" << 22 << "s" << 2 << "n" << 2)),
+                                       BSON_ARRAY(BSON("a" << 11 << "s" << 1 << "n" << 2)),
+                                       BSON_ARRAY(BSON("a" << 88 << "s" << 8 << "n" << 3)),
+                                       BSON_ARRAY(BSON("a" << 77 << "s" << 7 << "n" << 3)),
+                                       BSON_ARRAY(BSON("a" << 66 << "s" << 6 << "n" << 3)),
+                                       BSON_ARRAY(BSON("a" << 55 << "s" << 5 << "n" << 3))};
+
+    const auto expectedTopN =
+        BSON_ARRAY(BSON("_id" << BSON("n1" << 2) << "x" << BSON_ARRAY(11 << 22))
+                   << BSON("_id" << BSON("n1" << 3) << "x" << BSON_ARRAY(55 << 66 << 77)));
+    const auto expectedBottomN =
+        BSON_ARRAY(BSON("_id" << BSON("n1" << 2) << "x" << BSON_ARRAY(33 << 44))
+                   << BSON("_id" << BSON("n1" << 3) << "x" << BSON_ARRAY(66 << 77 << 88)));
+
+    runGroupAggregationTest(
+        "{_id: {n1: '$n'}, x: {$topN: {output: '$a', sortBy: {s: 1}, n: '$n1'}}}",
+        docs,
+        expectedTopN);
+    runGroupAggregationTest(
+        "{_id: {n1: '$n'}, x: {$bottomN: {output: '$a', sortBy: {s: 1}, n: '$n1'}}}",
+        docs,
+        expectedBottomN);
+}
+
+TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorInvalidConstantN) {
+ const std::vector<std::string> accumulators{"$topN", "$bottomN"};
+ const std::vector<std::string> testCases{"'string'", "4.2", "-1", "0"};
+ auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 11 << "s" << 1))};
+ for (const auto& acc : accumulators) {
+ for (const auto& testCase : testCases) {
+ runGroupAggregationToFail(str::stream() << "{_id: null, x: {" << acc
+ << ": {output: '$a', sortBy: {s: 1}, n: "
+ << testCase << "}}}",
+ docs,
+ static_cast<ErrorCodes::Error>(7548606));
+ }
+ }
+}
+
+TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorInvalidDynamicN) {
+ const std::vector<std::string> accumulators{"$topN", "$bottomN"};
+ const std::vector<BSONObj> testCases{BSON("n"
+ << "string"),
+ BSON("n" << 4.2),
+ BSON("n" << -1),
+ BSON("n" << 0)};
+ for (const auto& acc : accumulators) {
+ for (const auto& testCase : testCases) {
+ auto docs =
+ std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 11 << "s" << 1 << "n1" << testCase))};
+
+ runGroupAggregationToFail(str::stream()
+ << "{_id: null, x: {" << acc
+ << ": {output: '$a', sortBy: {s: 1}, n: '$n'}}}",
+ docs,
+ static_cast<ErrorCodes::Error>(7548607));
+ runGroupAggregationToFail(str::stream()
+ << "{_id: {n: '$n1.n'}, x: {" << acc
+ << ": {output: '$a', sortBy: {s: 1}, n: '$n'}}}",
+ docs,
+ static_cast<ErrorCodes::Error>(7548607));
+ }
+ }
+}
+
class AccumulatorSBEIncompatible final : public AccumulatorState {
public:
static constexpr auto kName = "$incompatible"_sd;
@@ -1947,6 +2121,45 @@ public:
return {resultTag, resultVal};
}
+    /**
+     * Converts 'arr' into a newly created SBE array holding one converted value per BSON
+     * element. The caller owns the returned (tag, value) pair and must release it.
+     */
+    std::pair<sbe::value::TypeTags, sbe::value::Value> bsonArrayToSbe(BSONArray arr) {
+        auto [arrTag, arrVal] = sbe::value::makeNewArray();
+        // Release the partially built array if converting or appending an element throws.
+        sbe::value::ValueGuard arrGuard{arrTag, arrVal};
+        auto arrView = sbe::value::getArrayView(arrVal);
+
+        for (auto elem : arr) {
+            auto [tag, val] = sbe::bson::convertFrom<false>(elem);
+            arrView->push_back(tag, val);
+        }
+
+        arrGuard.reset();
+        return {arrTag, arrVal};
+    }
+
+    /**
+     * Create an accumulator state for $topN/$bottomN, given heap in the format of BSONArray.
+     *
+     * The state is a four-element array: [heap, max size, memory usage, memory limit]. Memory
+     * usage starts at 0 and the limit is INT_MAX. The caller owns the returned (tag, value) pair.
+     * 'sortSpec' is currently not used by this helper.
+     */
+    std::pair<sbe::value::TypeTags, sbe::value::Value> makeTopBottomNAccumulatorState(
+        BSONArray valuesBson, long maxSize, const sbe::value::SortSpec* sortSpec) {
+        auto [stateTag, stateVal] = sbe::value::makeNewArray();
+        sbe::value::ValueGuard stateGuard{stateTag, stateVal};
+        auto state = sbe::value::getArrayView(stateVal);
+
+        // Heap. It is handed to 'state' right away and never guarded separately, so that
+        // 'stateGuard' is the only owner that can release it if a later step throws.
+        // NOTE(review): relies on Array::push_back taking ownership of the pushed value - confirm.
+        auto [valuesTag, valuesVal] = bsonArrayToSbe(valuesBson);
+        state->push_back(valuesTag, valuesVal);
+
+        // Max size
+        state->push_back(sbe::value::TypeTags::NumberInt64,
+                         sbe::value::bitcastFrom<int64_t>(maxSize));
+
+        // Memory usage
+        state->push_back(sbe::value::TypeTags::NumberInt32, sbe::value::bitcastFrom<int32_t>(0));
+
+        // Memory limit
+        state->push_back(sbe::value::TypeTags::NumberInt32,
+                         sbe::value::bitcastFrom<int32_t>(INT_MAX));
+
+        stateGuard.reset();
+        return {stateTag, stateVal};
+    }
+
/**
* Given the name of an SBE agg function ('aggFuncName') and an array of values expressed as a
* BSON array, aggregates the values inside the array and returns the resulting SBE value.
@@ -2621,4 +2834,58 @@ TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsFirstNInputArrayEm
ASSERT_EQ(compareVal, 0);
sbe::value::releaseValue(resultTag, resultVal);
}
+
+// Merges one partial $topN/$bottomN state into another with the merge builtins, finalizes the
+// merged state and checks the resulting output array.
+TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsTopBottomN) {
+    auto sortPattern = BSON("x" << 1);
+    // NOTE(review): the SortSpec is presumably owned by the constant expression created from it
+    // below; nothing in this test deletes it explicitly - confirm.
+    auto sortSpec = new sbe::value::SortSpec(sortPattern);
+    auto sortSpecConstant = stage_builder::makeConstant(
+        sbe::value::TypeTags::sortSpec, sbe::value::bitcastFrom<sbe::value::SortSpec*>(sortSpec));
+
+    auto topNExpr = stage_builder::makeFunction(
+        "aggTopNMerge", stage_builder::makeVariable(_inputSlotId), sortSpecConstant->clone());
+    auto bottomNExpr = stage_builder::makeFunction(
+        "aggBottomNMerge", stage_builder::makeVariable(_inputSlotId), sortSpecConstant->clone());
+
+    auto aggSlot = bindAccessor(&_aggAccessor);
+    auto topNFinalExpr = stage_builder::makeFunction(
+        "aggTopNFinalize", stage_builder::makeVariable(aggSlot), sortSpecConstant->clone());
+    auto bottomNFinalExpr = stage_builder::makeFunction(
+        "aggBottomNFinalize", stage_builder::makeVariable(aggSlot), sortSpecConstant->clone());
+
+    // Tuple layout: merge expression, finalize expression, heap already accumulated, incoming
+    // heap to merge, expected finalized output. Heap entries are [sort key, output] pairs.
+    std::vector<std::tuple<sbe::EExpression*, sbe::EExpression*, BSONArray, BSONArray, BSONArray>>
+        testCases{{topNExpr.get(),
+                   topNFinalExpr.get(),
+                   BSON_ARRAY(BSON_ARRAY(5 << 5) << BSON_ARRAY(3 << 3) << BSON_ARRAY(1 << 1)),
+                   BSON_ARRAY(BSON_ARRAY(6 << 6) << BSON_ARRAY(4 << 4) << BSON_ARRAY(2 << 2)),
+                   BSON_ARRAY(1 << 2 << 3)},
+                  {bottomNExpr.get(),
+                   bottomNFinalExpr.get(),
+                   BSON_ARRAY(BSON_ARRAY(1 << 1) << BSON_ARRAY(3 << 3) << BSON_ARRAY(5 << 5)),
+                   BSON_ARRAY(BSON_ARRAY(2 << 2) << BSON_ARRAY(4 << 4) << BSON_ARRAY(6 << 6)),
+                   BSON_ARRAY(4 << 5 << 6)}};
+
+    for (auto& [expr, finalExpr, heapMerge, heapIncoming, expected] : testCases) {
+        auto [accTag, accVal] = makeTopBottomNAccumulatorState(heapMerge, 3, sortSpec);
+        _aggAccessor.reset(true, accTag, accVal);
+
+        auto [inputTag, inputVal] = makeTopBottomNAccumulatorState(heapIncoming, 3, sortSpec);
+        _inputAccessor.reset(true, inputTag, inputVal);
+
+        auto compiledExpr = compileAggExpression(*expr, &_aggAccessor);
+
+        auto [newAccTag, newAccVal] = runCompiledExpression(compiledExpr.get());
+        _aggAccessor.reset(true, newAccTag, newAccVal);
+
+        auto compiledFinalExpr = compileExpression(*finalExpr);
+
+        // Both the finalized result and the expected array are owned by this test; guard them so
+        // they are released on every iteration, including when an assertion below fails.
+        auto [resultTag, resultVal] = runCompiledExpression(compiledFinalExpr.get());
+        sbe::value::ValueGuard resultGuard{resultTag, resultVal};
+
+        auto [expectedTag, expectedVal] = bsonArrayToSbe(expected);
+        sbe::value::ValueGuard expectedGuard{expectedTag, expectedVal};
+
+        auto [compareTag, compareVal] =
+            sbe::value::compareValue(resultTag, resultVal, expectedTag, expectedVal);
+
+        ASSERT_EQ(compareTag, sbe::value::TypeTags::NumberInt32);
+        ASSERT_EQ(compareVal, 0);
+    }
+}
} // namespace mongo