diff options
author | Rui Liu <lriuui0x0@gmail.com> | 2023-04-29 17:25:29 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2023-04-29 18:20:43 +0000 |
commit | df78c930a46ebc670e156387b9afb41b7782aa88 (patch) | |
tree | 74f6d8066390e7f25c397d1460d4112a6646da96 /src | |
parent | 7119eeb3c88cd787c686b8fc201a720f1c9e91e4 (diff) | |
download | mongo-df78c930a46ebc670e156387b9afb41b7782aa88.tar.gz |
SERVER-58070 Implement $topN / $bottomN accumulator
Diffstat (limited to 'src')
-rw-r--r-- | src/mongo/db/exec/sbe/expressions/expression.cpp | 9 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/values/value.h | 7 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/vm/vm.cpp | 222 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/vm/vm.h | 71 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator_multi.cpp | 18 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder.cpp | 174 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_accumulator.cpp | 276 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_accumulator.h | 40 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp | 267 |
9 files changed, 1020 insertions, 64 deletions
diff --git a/src/mongo/db/exec/sbe/expressions/expression.cpp b/src/mongo/db/exec/sbe/expressions/expression.cpp index 010d7d38136..d94494f5097 100644 --- a/src/mongo/db/exec/sbe/expressions/expression.cpp +++ b/src/mongo/db/exec/sbe/expressions/expression.cpp @@ -787,6 +787,15 @@ static stdx::unordered_map<std::string, BuiltinFn> kBuiltinFunctions = { BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::aggFirstNMerge, true}}, {"aggFirstNFinalize", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::aggFirstNFinalize, false}}, + {"aggTopN", BuiltinFn{[](size_t n) { return n == 3; }, vm::Builtin::aggTopN, true}}, + {"aggTopNMerge", BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::aggTopNMerge, true}}, + {"aggTopNFinalize", + BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::aggTopNFinalize, false}}, + {"aggBottomN", BuiltinFn{[](size_t n) { return n == 3; }, vm::Builtin::aggBottomN, true}}, + {"aggBottomNMerge", + BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::aggBottomNMerge, true}}, + {"aggBottomNFinalize", + BuiltinFn{[](size_t n) { return n == 2; }, vm::Builtin::aggBottomNFinalize, false}}, }; /** diff --git a/src/mongo/db/exec/sbe/values/value.h b/src/mongo/db/exec/sbe/values/value.h index 5e01ace708b..787609bb63b 100644 --- a/src/mongo/db/exec/sbe/values/value.h +++ b/src/mongo/db/exec/sbe/values/value.h @@ -845,6 +845,13 @@ public: } } + void pop_back() { + if (_vals.size() > 0) { + releaseValue(_vals.back().first, _vals.back().second); + _vals.pop_back(); + } + } + auto size() const noexcept { return _vals.size(); } diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp index 97276df8c71..611f1ceb2d1 100644 --- a/src/mongo/db/exec/sbe/vm/vm.cpp +++ b/src/mongo/db/exec/sbe/vm/vm.cpp @@ -6039,8 +6039,8 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinArrayToObject(Ar return {true, objTag, objVal}; } -std::tuple<value::Array*, size_t, int32_t, int32_t> multiAccState(value::TypeTags accTag, - value::Value accVal) { +std::tuple<value::Array*, value::Array*, size_t, int32_t, int32_t> multiAccState( + value::TypeTags accTag, value::Value accVal) { uassert(7548600, "The accumulator state should be an array", accTag == value::TypeTags::Array); auto acc = value::getArrayView(accVal); @@ -6069,7 +6069,18 @@ std::tuple<value::Array*, size_t, int32_t, int32_t> multiAccState(value::TypeTag "MemLimit component should be a 32-bit integer", memLimitTag == value::TypeTags::NumberInt32); - return {array, maxSize, memUsage, memLimit}; + return {acc, array, maxSize, memUsage, memLimit}; +} + +void checkAndUpdateMemUsage(value::Array* accArray, int32_t memUsage, int32_t memLimit) { + uassert(ErrorCodes::ExceededMemoryLimit, + str::stream() + << "Accumulator used too much memory and spilling to disk cannot reduce memory " + "consumption any further. Memory limit: " + << memLimit << " bytes", + memUsage < memLimit); + accArray->setAt( + static_cast<size_t>(AggMultiElems::kMemUsage), value::TypeTags::NumberInt32, memUsage); } FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggFirstN(ArityType arity) { @@ -6078,26 +6089,15 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggFirstN(ArityT value::ValueGuard accGuard{accTag, accVal}; value::ValueGuard fieldGuard{fieldTag, fieldVal}; - auto [accArr, accSize, memUsage, memLimit] = multiAccState(accTag, accVal); + auto [acc, array, maxSize, memUsage, memLimit] = multiAccState(accTag, accVal); - if (accArr->size() < accSize) { - // update memusage + if (array->size() < maxSize) { memUsage += value::getApproximateSize(fieldTag, fieldVal); - auto maxMemAllowed = memLimit; - uassert(ErrorCodes::ExceededMemoryLimit, - str::stream() - << "$firstN used too much memory and spilling to disk cannot reduce memory " - "consumption any further. Memory limit: " - << maxMemAllowed << " bytes", - memUsage < maxMemAllowed); + checkAndUpdateMemUsage(acc, memUsage, memLimit); // add to array fieldGuard.reset(); - accArr->push_back(fieldTag, fieldVal); - - // update the memUsageBytes - value::getArrayView(accVal)->setAt( - static_cast<size_t>(AggMultiElems::kMemUsage), value::TypeTags::NumberInt32, memUsage); + array->push_back(fieldTag, fieldVal); } accGuard.reset(); return {true, accTag, accVal}; @@ -6109,9 +6109,9 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggFirstNMerge(A value::ValueGuard mergeAccGuard{mergeAccTag, mergeAccVal}; value::ValueGuard accGuard{accTag, accVal}; - auto [mergeArr, mergeMaxSize, mergeMemUsage, mergeMemLimit] = + auto [mergeAcc, mergeArr, mergeMaxSize, mergeMemUsage, mergeMemLimit] = multiAccState(mergeAccTag, mergeAccVal); - auto [arr, maxSize, memUsage, memLimit] = multiAccState(accTag, accVal); + auto [acc, arr, maxSize, memUsage, memLimit] = multiAccState(accTag, accVal); uassert(7548604, "Two arrays to merge should have the same MaxSize component", maxSize == mergeMaxSize); @@ -6123,21 +6123,11 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggFirstNMerge(A value::ValueGuard valueGuard{tag, val}; mergeMemUsage += value::getApproximateSize(tag, val); - auto maxMemAllowed = mergeMemLimit; - uassert(ErrorCodes::ExceededMemoryLimit, - str::stream() - << "$firstN used too much memory and spilling to disk cannot reduce memory " - "consumption any further. Memory limit: " - << maxMemAllowed << " bytes", - mergeMemUsage < maxMemAllowed); + checkAndUpdateMemUsage(mergeAcc, mergeMemUsage, mergeMemLimit); valueGuard.reset(); mergeArr->push_back(tag, val); } - value::getArrayView(mergeAccVal) - ->setAt(static_cast<size_t>(AggMultiElems::kMemUsage), - value::TypeTags::NumberInt32, - mergeMemUsage); } mergeAccGuard.reset(); @@ -6156,6 +6146,152 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggFirstNFinaliz return {true, outputTag, outputVal}; } +template <typename Less> +int32_t aggTopBottomNAdd(value::Array* acc, + value::Array* array, + size_t maxSize, + int32_t memUsage, + int32_t memLimit, + const value::SortSpec* sortSpec, + std::pair<value::TypeTags, value::Value> key, + std::pair<value::TypeTags, value::Value> output) { + auto newMemUsage = [](int32_t memUsage, + std::pair<value::TypeTags, value::Value> key, + std::pair<value::TypeTags, value::Value> output) { + memUsage += value::getApproximateSize(key.first, key.second); + memUsage += value::getApproximateSize(output.first, output.second); + return memUsage; + }; + + value::ValueGuard keyGuard{key.first, key.second}; + value::ValueGuard outputGuard{output.first, output.second}; + auto less = Less(sortSpec); + auto keyLess = PairKeyComp(less); + auto& heap = array->values(); + + if (array->size() < maxSize) { + auto [pairTag, pairVal] = value::makeNewArray(); + value::ValueGuard pairGuard{pairTag, pairVal}; + auto pair = value::getArrayView(pairVal); + pair->reserve(2); + keyGuard.reset(); + pair->push_back(key.first, key.second); + outputGuard.reset(); + pair->push_back(output.first, output.second); + + memUsage = newMemUsage(memUsage, key, output); + checkAndUpdateMemUsage(acc, memUsage, memLimit); + + pairGuard.reset(); + array->push_back(pairTag, pairVal); + std::push_heap(heap.begin(), heap.end(), keyLess); + } else { + tassert(5807005, + "Heap should contain same number of elements as MaxSize", + array->size() == maxSize); + + auto [worstTag, worstVal] = heap.front(); + auto worst = value::getArrayView(worstVal); + auto worstKey = worst->getAt(0); + if (less(key, worstKey)) { + memUsage = newMemUsage(memUsage, key, output); + checkAndUpdateMemUsage(acc, memUsage, memLimit); + + std::pop_heap(heap.begin(), heap.end(), keyLess); + keyGuard.reset(); + worst->setAt(0, key.first, key.second); + outputGuard.reset(); + worst->setAt(1, output.first, output.second); + std::push_heap(heap.begin(), heap.end(), keyLess); + } + } + + return memUsage; +} + +template <typename Less> +FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggTopBottomN(ArityType arity) { + auto [sortSpecOwned, sortSpecTag, sortSpecVal] = getFromStack(3); + tassert(5807024, "Argument must be of sortSpec type", sortSpecTag == value::TypeTags::sortSpec); + auto sortSpec = value::getSortSpecView(sortSpecVal); + + auto [accTag, accVal] = moveOwnedFromStack(0); + value::ValueGuard accGuard{accTag, accVal}; + auto [acc, array, maxSize, memUsage, memLimit] = multiAccState(accTag, accVal); + auto key = moveOwnedFromStack(1); + auto output = moveOwnedFromStack(2); + + aggTopBottomNAdd<Less>(acc, array, maxSize, memUsage, memLimit, sortSpec, key, output); + + accGuard.reset(); + return {true, accTag, accVal}; +} + +template <typename Less> +FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggTopBottomNMerge( + ArityType arity) { + auto [sortSpecOwned, sortSpecTag, sortSpecVal] = getFromStack(2); + tassert(5807025, "Argument must be of sortSpec type", sortSpecTag == value::TypeTags::sortSpec); + auto sortSpec = value::getSortSpecView(sortSpecVal); + + auto [accTag, accVal] = moveOwnedFromStack(1); + value::ValueGuard accGuard{accTag, accVal}; + auto [mergeAccTag, mergeAccVal] = moveOwnedFromStack(0); + value::ValueGuard mergeAccGuard{mergeAccTag, mergeAccVal}; + auto [mergeAcc, mergeArray, mergeMaxSize, mergeMemUsage, mergeMemLimit] = + multiAccState(mergeAccTag, mergeAccVal); + auto [acc, array, maxSize, memUsage, memLimit] = multiAccState(accTag, accVal); + tassert(5807008, + "Two arrays to merge should have the same MaxSize component", + maxSize == mergeMaxSize); + + for (auto [pairTag, pairVal] : array->values()) { + auto pair = value::getArrayView(pairVal); + auto key = pair->swapAt(0, value::TypeTags::Null, 0); + auto output = pair->swapAt(1, value::TypeTags::Null, 0); + mergeMemUsage = aggTopBottomNAdd<Less>(mergeAcc, + mergeArray, + mergeMaxSize, + mergeMemUsage, + mergeMemLimit, + sortSpec, + key, + output); + } + + mergeAccGuard.reset(); + return {true, mergeAccTag, mergeAccVal}; +} + +FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinAggTopBottomNFinalize( + ArityType arity) { + auto [sortSpecOwned, sortSpecTag, sortSpecVal] = getFromStack(1); + tassert(5807026, "Argument must be of sortSpec type", sortSpecTag == value::TypeTags::sortSpec); + auto sortSpec = value::getSortSpecView(sortSpecVal); + + auto [accTag, accVal] = moveOwnedFromStack(0); + value::ValueGuard accGuard{accTag, accVal}; + auto [acc, array, maxSize, memUsage, memLimit] = multiAccState(accTag, accVal); + + auto [outputArrayTag, outputArrayVal] = value::makeNewArray(); + value::ValueGuard outputArrayGuard{outputArrayTag, outputArrayVal}; + auto outputArray = value::getArrayView(outputArrayVal); + outputArray->reserve(array->size()); + + // We always output result in the order of sort pattern in according to MQL semantics. + auto less = SortPatternLess(sortSpec); + auto keyLess = PairKeyComp(less); + std::sort(array->values().begin(), array->values().end(), keyLess); + for (size_t i = 0; i < array->size(); ++i) { + auto pair = value::getArrayView(array->getAt(i).second); + auto [outputTag, outputVal] = pair->swapAt(1, value::TypeTags::Null, 0); + outputArray->push_back(outputTag, outputVal); + } + + outputArrayGuard.reset(); + return {true, outputArrayTag, outputArrayVal}; +} + FastTuple<bool, value::TypeTags, value::Value> ByteCode::dispatchBuiltin(Builtin f, ArityType arity) { switch (f) { @@ -6428,6 +6564,18 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::dispatchBuiltin(Builtin return builtinAggFirstNMerge(arity); case Builtin::aggFirstNFinalize: return builtinAggFirstNFinalize(arity); + case Builtin::aggTopN: + return builtinAggTopBottomN<SortPatternLess>(arity); + case Builtin::aggTopNMerge: + return builtinAggTopBottomNMerge<SortPatternLess>(arity); + case Builtin::aggTopNFinalize: + return builtinAggTopBottomNFinalize(arity); + case Builtin::aggBottomN: + return builtinAggTopBottomN<SortPatternGreater>(arity); + case Builtin::aggBottomNMerge: + return builtinAggTopBottomNMerge<SortPatternGreater>(arity); + case Builtin::aggBottomNFinalize: + return builtinAggTopBottomNFinalize(arity); } MONGO_UNREACHABLE; @@ -6706,6 +6854,18 @@ std::string builtinToString(Builtin b) { return "aggFirstNMerge"; case Builtin::aggFirstNFinalize: return "aggFirstNFinalize"; + case Builtin::aggTopN: + return "aggTopN"; + case Builtin::aggTopNMerge: + return "aggTopNMerge"; + case Builtin::aggTopNFinalize: + return "aggTopNFinalize"; + case Builtin::aggBottomN: + return "aggBottomN"; + case Builtin::aggBottomNMerge: + return "aggBottomNMerge"; + case Builtin::aggBottomNFinalize: + return "aggBottomNFinalize"; default: MONGO_UNREACHABLE; } diff --git a/src/mongo/db/exec/sbe/vm/vm.h b/src/mongo/db/exec/sbe/vm/vm.h index 2deb8b0c907..62b9be9154e 100644 --- a/src/mongo/db/exec/sbe/vm/vm.h +++ b/src/mongo/db/exec/sbe/vm/vm.h @@ -37,6 +37,7 @@ #include "mongo/config.h" #include "mongo/db/exec/sbe/makeobj_spec.h" #include "mongo/db/exec/sbe/values/slot.h" +#include "mongo/db/exec/sbe/values/sort_spec.h" #include "mongo/db/exec/sbe/values/value.h" #include "mongo/db/exec/sbe/vm/datetime.h" #include "mongo/db/exec/sbe/vm/label.h" @@ -757,6 +758,12 @@ enum class Builtin : uint8_t { aggFirstN, aggFirstNMerge, aggFirstNFinalize, + aggTopN, + aggTopNMerge, + aggTopNFinalize, + aggBottomN, + aggBottomNMerge, + aggBottomNFinalize, }; std::string builtinToString(Builtin b); @@ -773,6 +780,64 @@ std::string builtinToString(Builtin b); enum class AggMultiElems { kInternalArr, kMaxSize, kMemUsage, kMemLimit, kSizeOfArray }; /** + * Less than comparison based on a sort pattern. + */ +struct SortPatternLess { + SortPatternLess(const value::SortSpec* sortSpec) : _sortSpec(sortSpec) {} + + bool operator()(const std::pair<value::TypeTags, value::Value>& lhs, + const std::pair<value::TypeTags, value::Value>& rhs) const { + auto [cmpTag, cmpVal] = _sortSpec->compare(lhs.first, lhs.second, rhs.first, rhs.second); + uassert(5807000, "Invalid comparison result", cmpTag == value::TypeTags::NumberInt32); + return value::bitcastTo<int32_t>(cmpVal) < 0; + } + +private: + const value::SortSpec* _sortSpec; +}; + +/** + * Greater than comparison based on a sort pattern. + */ +struct SortPatternGreater { + SortPatternGreater(const value::SortSpec* sortSpec) : _sortSpec(sortSpec) {} + + bool operator()(const std::pair<value::TypeTags, value::Value>& lhs, + const std::pair<value::TypeTags, value::Value>& rhs) const { + auto [cmpTag, cmpVal] = _sortSpec->compare(lhs.first, lhs.second, rhs.first, rhs.second); + uassert(5807001, "Invalid comparison result", cmpTag == value::TypeTags::NumberInt32); + return value::bitcastTo<int32_t>(cmpVal) > 0; + } + +private: + const value::SortSpec* _sortSpec; +}; + +/** + * Comparison based on the key of a pair of elements. + */ +template <typename Comp> +struct PairKeyComp { + PairKeyComp(const Comp& comp) : _comp(comp) {} + + bool operator()(const std::pair<value::TypeTags, value::Value>& lhs, + const std::pair<value::TypeTags, value::Value>& rhs) const { + auto [lPairTag, lPairVal] = lhs; + auto lPair = value::getArrayView(lPairVal); + auto lKey = lPair->getAt(0); + + auto [rPairTag, rPairVal] = rhs; + auto rPair = value::getArrayView(rPairVal); + auto rKey = rPair->getAt(0); + + return _comp(lKey, rKey); + } + +private: + const Comp _comp; +}; + +/** * This enum defines indices into an 'Array' that returns the partial sum result when 'needsMerge' * is requested. * @@ -1586,13 +1651,17 @@ private: FastTuple<bool, value::TypeTags, value::Value> builtinISOWeekYear(ArityType arity); FastTuple<bool, value::TypeTags, value::Value> builtinISODayOfWeek(ArityType arity); FastTuple<bool, value::TypeTags, value::Value> builtinISOWeek(ArityType arity); - FastTuple<bool, value::TypeTags, value::Value> builtinObjectToArray(ArityType arity); FastTuple<bool, value::TypeTags, value::Value> builtinArrayToObject(ArityType arity); FastTuple<bool, value::TypeTags, value::Value> builtinAggFirstN(ArityType arity); FastTuple<bool, value::TypeTags, value::Value> builtinAggFirstNMerge(ArityType arity); FastTuple<bool, value::TypeTags, value::Value> builtinAggFirstNFinalize(ArityType arity); + template <typename Less> + FastTuple<bool, value::TypeTags, value::Value> builtinAggTopBottomN(ArityType arity); + template <typename Less> + FastTuple<bool, value::TypeTags, value::Value> builtinAggTopBottomNMerge(ArityType arity); + FastTuple<bool, value::TypeTags, value::Value> builtinAggTopBottomNFinalize(ArityType arity); FastTuple<bool, value::TypeTags, value::Value> dispatchBuiltin(Builtin f, ArityType arity); diff --git a/src/mongo/db/pipeline/accumulator_multi.cpp b/src/mongo/db/pipeline/accumulator_multi.cpp index b4e82dea93c..633304c4903 100644 --- a/src/mongo/db/pipeline/accumulator_multi.cpp +++ b/src/mongo/db/pipeline/accumulator_multi.cpp @@ -489,12 +489,13 @@ Document AccumulatorTopBottomN<sense, single>::serialize( } template <TopBottomSense sense> -std::pair<SortPattern, BSONArray> parseAccumulatorTopBottomNSortBy(ExpressionContext* const expCtx, - BSONObj sortBy) { +std::tuple<SortPattern, BSONArray, bool> parseAccumulatorTopBottomNSortBy( + ExpressionContext* const expCtx, BSONObj sortBy) { SortPattern sortPattern(sortBy, expCtx); BSONArrayBuilder sortFieldsExpBab; BSONObjIterator sortByBoi(sortBy); + bool hasMeta = false; for (const auto& part : sortPattern) { const auto fieldName = sortByBoi.next().fieldNameStringData(); if (part.expression) { @@ -505,21 +506,26 @@ std::pair<SortPattern, BSONArray> parseAccumulatorTopBottomNSortBy(ExpressionCon // sortFields array contains the data we need for sorting. const auto serialized = part.expression->serialize(false); sortFieldsExpBab.append(serialized.getDocument().toBson()); + hasMeta = true; } else { sortFieldsExpBab.append((StringBuilder() << "$" << fieldName).str()); } } - return {sortPattern, sortFieldsExpBab.arr()}; + return {sortPattern, sortFieldsExpBab.arr(), hasMeta}; } template <TopBottomSense sense, bool single> AccumulationExpression AccumulatorTopBottomN<sense, single>::parseTopBottomN( ExpressionContext* const expCtx, BSONElement elem, VariablesParseState vps) { - expCtx->sbeGroupCompatibility = SbeCompatibility::notCompatible; auto name = AccumulatorTopBottomN<sense, single>::getName(); const auto [n, output, sortBy] = accumulatorNParseArgs<single>(expCtx, elem, name.rawData(), true, vps); - auto [sortPattern, sortFieldsExp] = parseAccumulatorTopBottomNSortBy<sense>(expCtx, *sortBy); + auto [sortPattern, sortFieldsExp, hasMeta] = + parseAccumulatorTopBottomNSortBy<sense>(expCtx, *sortBy); + + auto sbeCompatibility = + hasMeta ? SbeCompatibility::notCompatible : SbeCompatibility::flagGuarded; + expCtx->sbeGroupCompatibility = std::min(expCtx->sbeGroupCompatibility, sbeCompatibility); // Construct argument expression. If given sortBy: {field1: 1, field2: 1} it will be shaped like // {output: <output expression>, sortFields: ["$field1", "$field2"]}. This projects out only the @@ -539,7 +545,7 @@ template <TopBottomSense sense, bool single> boost::intrusive_ptr<AccumulatorState> AccumulatorTopBottomN<sense, single>::create( ExpressionContext* expCtx, BSONObj sortBy, bool isRemovable) { return make_intrusive<AccumulatorTopBottomN<sense, single>>( - expCtx, parseAccumulatorTopBottomNSortBy<sense>(expCtx, sortBy).first, isRemovable); + expCtx, std::get<0>(parseAccumulatorTopBottomNSortBy<sense>(expCtx, sortBy)), isRemovable); } template <TopBottomSense sense, bool single> diff --git a/src/mongo/db/query/sbe_stage_builder.cpp b/src/mongo/db/query/sbe_stage_builder.cpp index e823ef7f428..34c6f858fa3 100644 --- a/src/mongo/db/query/sbe_stage_builder.cpp +++ b/src/mongo/db/query/sbe_stage_builder.cpp @@ -55,6 +55,7 @@ #include "mongo/db/matcher/expression_leaf.h" #include "mongo/db/matcher/match_expression_dependencies.h" #include "mongo/db/pipeline/abt/field_map_builder.h" +#include "mongo/db/pipeline/accumulator_multi.h" #include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/expression_visitor.h" #include "mongo/db/query/bind_input_params.h" @@ -2305,6 +2306,45 @@ std::tuple<sbe::value::SlotVector, EvalStage, std::unique_ptr<sbe::EExpression>> } } +template <TopBottomSense sense, bool single> +std::unique_ptr<sbe::EExpression> getSortSpecFromTopBottomN( + const AccumulatorTopBottomN<sense, single>* acc) { + tassert(5807013, "Accumulator state must not be null", acc); + auto sortPattern = + acc->getSortPattern().serialize(SortPattern::SortKeySerialization::kForExplain).toBson(); + auto sortSpec = std::make_unique<sbe::value::SortSpec>(sortPattern); + auto sortSpecExpr = + makeConstant(sbe::value::TypeTags::sortSpec, + sbe::value::bitcastFrom<sbe::value::SortSpec*>(sortSpec.release())); + return sortSpecExpr; +} + +std::unique_ptr<sbe::EExpression> getSortSpecFromTopBottomN(const AccumulationStatement& accStmt) { + auto acc = accStmt.expr.factory(); + if (accStmt.expr.name == AccumulatorTopBottomN<kTop, true>::getName()) { + return getSortSpecFromTopBottomN( + dynamic_cast<AccumulatorTopBottomN<kTop, true>*>(acc.get())); + } else if (accStmt.expr.name == AccumulatorTopBottomN<kBottom, true>::getName()) { + return getSortSpecFromTopBottomN( + dynamic_cast<AccumulatorTopBottomN<kBottom, true>*>(acc.get())); + } else if (accStmt.expr.name == AccumulatorTopBottomN<kTop, false>::getName()) { + return getSortSpecFromTopBottomN( + dynamic_cast<AccumulatorTopBottomN<kTop, false>*>(acc.get())); + } else if (accStmt.expr.name == AccumulatorTopBottomN<kBottom, false>::getName()) { + return getSortSpecFromTopBottomN( + dynamic_cast<AccumulatorTopBottomN<kBottom, false>*>(acc.get())); + } else { + MONGO_UNREACHABLE; + } +} + +bool isTopBottomN(const AccumulationStatement& accStmt) { + return accStmt.expr.name == AccumulatorTopBottomN<kTop, true>::getName() || + accStmt.expr.name == AccumulatorTopBottomN<kBottom, true>::getName() || + accStmt.expr.name == AccumulatorTopBottomN<kTop, false>::getName() || + accStmt.expr.name == AccumulatorTopBottomN<kBottom, false>::getName(); +} + sbe::value::SlotVector generateAccumulator( StageBuilderState& state, const AccumulationStatement& accStmt, @@ -2313,16 +2353,67 @@ sbe::value::SlotVector generateAccumulator( sbe::HashAggStage::AggExprVector& aggSlotExprs, boost::optional<sbe::value::SlotId> initializerRootSlot) { auto rootSlot = outputs.getIfExists(PlanStageSlots::kResult); - auto argExpr = generateExpression(state, accStmt.expr.argument.get(), rootSlot, &outputs); - auto initExpr = - generateExpression(state, accStmt.expr.initializer.get(), initializerRootSlot, nullptr); + auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd); // One accumulator may be translated to multiple accumulator expressions. For example, The // $avg will have two accumulators expressions, a sum(..) and a count which is implemented // as sum(1). - auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd); - auto accExprs = stage_builder::buildAccumulator( - accStmt, argExpr.extractExpr(state), collatorSlot, *state.frameIdGenerator); + auto accExprs = [&]() { + // $topN/$bottomN accumulators require multiple arguments to the accumulator builder. + if (isTopBottomN(accStmt)) { + StringDataMap<std::unique_ptr<sbe::EExpression>> accArgs; + auto sortSpecExpr = getSortSpecFromTopBottomN(accStmt); + accArgs.emplace(AccArgs::kTopBottomNSortSpec, sortSpecExpr->clone()); + + // Build the key expression for the accumulator. + tassert(5807014, + str::stream() << accStmt.expr.name + << " accumulator must have the root slot set", + rootSlot); + auto key = collatorSlot ? makeFunction("generateCheapSortKey", + std::move(sortSpecExpr), + makeVariable(*rootSlot), + makeVariable(*collatorSlot)) + : makeFunction("generateCheapSortKey", + std::move(sortSpecExpr), + makeVariable(*rootSlot)); + accArgs.emplace(AccArgs::kTopBottomNKey, + makeFunction("sortKeyComponentVectorToArray", std::move(key))); + + // Build the value expression for the accumulator. + auto expObj = dynamic_cast<ExpressionObject*>(accStmt.expr.argument.get()); + tassert(5807015, + str::stream() << accStmt.expr.name + << " accumulator must have an object argument", + expObj); + for (auto& [key, value] : expObj->getChildExpressions()) { + if (key == AccumulatorN::kFieldNameOutput) { + auto outputExpr = generateExpression(state, value.get(), rootSlot, &outputs); + accArgs.emplace(AccArgs::kTopBottomNValue, + makeFillEmptyNull(outputExpr.extractExpr(state))); + break; + } + } + tassert(5807016, + str::stream() << accStmt.expr.name + << " accumulator must have an output field in the argument", + accArgs.find(AccArgs::kTopBottomNValue) != accArgs.end()); + + auto accExprs = stage_builder::buildAccumulator( + accStmt, std::move(accArgs), collatorSlot, *state.frameIdGenerator); + + return accExprs; + } else { + auto argExpr = + generateExpression(state, accStmt.expr.argument.get(), rootSlot, &outputs); + auto accExprs = stage_builder::buildAccumulator( + accStmt, argExpr.extractExpr(state), collatorSlot, *state.frameIdGenerator); + return accExprs; + } + }(); + + auto initExpr = + generateExpression(state, accStmt.expr.initializer.get(), initializerRootSlot, nullptr); auto accInitExprs = stage_builder::buildInitialize( accStmt, initExpr.extractExpr(state), *state.frameIdGenerator); @@ -2362,8 +2453,18 @@ sbe::SlotExprPairVector generateMergingExpressions(StageBuilderState& state, auto spillSlots = slotIdGenerator->generateMultiple(numInputSlots); auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd); - auto mergingExprs = - buildCombinePartialAggregates(accStmt, spillSlots, collatorSlot, *frameIdGenerator); + + auto mergingExprs = [&]() { + if (isTopBottomN(accStmt)) { + StringDataMap<std::unique_ptr<sbe::EExpression>> mergeArgs; + mergeArgs.emplace(AccArgs::kTopBottomNSortSpec, getSortSpecFromTopBottomN(accStmt)); + return buildCombinePartialAggregates( + accStmt, spillSlots, std::move(mergeArgs), collatorSlot, *frameIdGenerator); + } else { + return buildCombinePartialAggregates( + accStmt, spillSlots, collatorSlot, *frameIdGenerator); + } + }(); // Zip the slot vector and expression vector into a vector of pairs. tassert(7039550, @@ -2409,13 +2510,33 @@ std::tuple<std::vector<std::string>, sbe::value::SlotVector, EvalStage> generate } }(); + auto collatorSlot = state.data->env->getSlotIfExists("collator"_sd); auto finalSlots{sbe::value::SlotVector{finalGroupBySlot}}; std::vector<std::string> fieldNames{"_id"}; size_t idxAccFirstSlot = dedupedGroupBySlots.size(); for (size_t idxAcc = 0; idxAcc < accStmts.size(); ++idxAcc) { // Gathers field names for the output object from accumulator statements. fieldNames.push_back(accStmts[idxAcc].fieldName); - auto finalExpr = stage_builder::buildFinalize(state, accStmts[idxAcc], aggSlotsVec[idxAcc]); + + auto finalExpr = [&]() { + const auto& accStmt = accStmts[idxAcc]; + if (isTopBottomN(accStmt)) { + StringDataMap<std::unique_ptr<sbe::EExpression>> finalArgs; + finalArgs.emplace(AccArgs::kTopBottomNSortSpec, getSortSpecFromTopBottomN(accStmt)); + return buildFinalize(state, + accStmts[idxAcc], + aggSlotsVec[idxAcc], + std::move(finalArgs), + collatorSlot, + *state.frameIdGenerator); + } else { + return buildFinalize(state, + accStmts[idxAcc], + aggSlotsVec[idxAcc], + collatorSlot, + *state.frameIdGenerator); + } + }(); // The final step may not return an expression if it's trivial. For example, $first and // $last's final steps are trivial. @@ -2512,6 +2633,7 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder if (!groupNode->needWholeDocument) { // Tracks whether we need to request kResult. bool rootDocIsNeeded = false; + bool sortKeyIsNeeded = false; auto referencesRoot = [&](const ExpressionFieldPath* fieldExpr) { rootDocIsNeeded = rootDocIsNeeded || fieldExpr->isROOT(); }; @@ -2520,23 +2642,29 @@ std::pair<std::unique_ptr<sbe::PlanStage>, PlanStageSlots> SlotBasedStageBuilder walkAndActOnFieldPaths(idExpr.get(), referencesRoot); for (const auto& accStmt : accStmts) { walkAndActOnFieldPaths(accStmt.expr.argument.get(), referencesRoot); + if (isTopBottomN(accStmt)) { + sortKeyIsNeeded = true; + } } - // If the group node doesn't have any dependency (e.g. $count) or if the dependency can be - // satisfied by the child node (e.g. covered index scan), we can clear the kResult - // requirement for the child. - if (groupNode->requiredFields.empty() || !rootDocIsNeeded) { - childReqs.clear(kResult); - } else if (childNode->getType() == StageType::STAGE_PROJECTION_COVERED) { - auto pn = static_cast<const ProjectionNodeCovered*>(childNode); - std::set<std::string> providedFieldSet; - for (auto&& elt : pn->coveredKeyObj) { - providedFieldSet.emplace(elt.fieldNameStringData()); - } - if (std::all_of(groupNode->requiredFields.begin(), - groupNode->requiredFields.end(), - [&](const std::string& f) { return providedFieldSet.count(f); })) { + // If any accumulator requires generating sort key, we cannot clear the kResult. + if (!sortKeyIsNeeded) { + // If the group node doesn't have any dependency (e.g. $count) or if the dependency can + // be satisfied by the child node (e.g. covered index scan), we can clear the kResult + // requirement for the child. + if (groupNode->requiredFields.empty() || !rootDocIsNeeded) { childReqs.clear(kResult); + } else if (childNode->getType() == StageType::STAGE_PROJECTION_COVERED) { + auto pn = static_cast<const ProjectionNodeCovered*>(childNode); + std::set<std::string> providedFieldSet; + for (auto&& elt : pn->coveredKeyObj) { + providedFieldSet.emplace(elt.fieldNameStringData()); + } + if (std::all_of(groupNode->requiredFields.begin(), + groupNode->requiredFields.end(), + [&](const std::string& f) { return providedFieldSet.count(f); })) { + childReqs.clear(kResult); + } } } } diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator.cpp index 5f68a8dc3e0..21f3457f9e7 100644 --- a/src/mongo/db/query/sbe_stage_builder_accumulator.cpp +++ b/src/mongo/db/query/sbe_stage_builder_accumulator.cpp @@ -584,6 +584,7 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggsMergeObjec std::vector<std::unique_ptr<sbe::EExpression>> buildInitializeAccumulatorMulti( std::unique_ptr<sbe::EExpression> maxSizeExpr, sbe::value::FrameIdGenerator& frameIdGenerator) { + // Create an array of four elements [value holder, max size, memory used, memory limit]. std::vector<std::unique_ptr<sbe::EExpression>> aggs; auto maxAccumulatorBytes = internalQueryTopNAccumulatorBytes.load(); if (auto* maxSizeConstExpr = maxSizeExpr->as<sbe::EConstant>()) { @@ -663,12 +664,170 @@ std::unique_ptr<sbe::EExpression> buildFinalizeFirstN(StageBuilderState& state, return makeFunction("aggFirstNFinalize", makeVariable(inputSlots[0])); } +bool isAccumulatorTopN(const AccumulationExpression& expr) { + return expr.name == AccumulatorTopBottomN<kTop, false /* single */>::getName() || + expr.name == AccumulatorTopBottomN<kTop, true /* single */>::getName(); +} + +std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulatorTopBottomN( + const AccumulationExpression& expr, + StringDataMap<std::unique_ptr<sbe::EExpression>> args, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + auto it = args.find(AccArgs::kTopBottomNKey); + tassert(5807009, + str::stream() << "Accumulator " << expr.name << " expects a '" + << AccArgs::kTopBottomNKey << "' argument", + it != args.end()); + auto key = std::move(it->second); + + it = args.find(AccArgs::kTopBottomNValue); + tassert(5807010, + str::stream() << "Accumulator " << expr.name << " expects a '" + << AccArgs::kTopBottomNValue << "' argument", + it != args.end()); + auto value = std::move(it->second); + + it = args.find(AccArgs::kTopBottomNSortSpec); + tassert(5807021, + str::stream() << "Accumulator " << expr.name << " expects a '" + << AccArgs::kTopBottomNSortSpec << "' argument", + it != args.end()); + auto sortSpec = std::move(it->second); + + std::vector<std::unique_ptr<sbe::EExpression>> aggs; + aggs.push_back(makeFunction(isAccumulatorTopN(expr) ? "aggTopN" : "aggBottomN", + std::move(key), + std::move(value), + std::move(sortSpec))); + return aggs; +} + +std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialTopBottomN( + const AccumulationExpression& expr, + const sbe::value::SlotVector& inputSlots, + StringDataMap<std::unique_ptr<sbe::EExpression>> args, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + tassert(5807011, + str::stream() << "Expected one input slot for merging " << expr.name + << ", got: " << inputSlots.size(), + inputSlots.size() == 1); + + auto it = args.find(AccArgs::kTopBottomNSortSpec); + tassert(5807022, + str::stream() << "Accumulator " << expr.name << " expects a '" + << AccArgs::kTopBottomNSortSpec << "' argument", + it != args.end()); + auto sortSpec = std::move(it->second); + + std::vector<std::unique_ptr<sbe::EExpression>> aggs; + aggs.push_back(makeFunction(isAccumulatorTopN(expr) ? "aggTopNMerge" : "aggBottomNMerge", + makeVariable(inputSlots[0]), + std::move(sortSpec))); + return aggs; +} + +std::unique_ptr<sbe::EExpression> buildFinalizeTopBottomNImpl( + StageBuilderState& state, + const AccumulationExpression& expr, + const sbe::value::SlotVector& inputSlots, + StringDataMap<std::unique_ptr<sbe::EExpression>> args, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator, + bool single) { + tassert(5807012, + str::stream() << "Expected one input slot for finalization of " << expr.name + << ", got: " << inputSlots.size(), + inputSlots.size() == 1); + auto inputVar = makeVariable(inputSlots[0]); + + auto it = args.find(AccArgs::kTopBottomNSortSpec); + tassert(5807023, + str::stream() << "Accumulator " << expr.name << " expects a '" + << AccArgs::kTopBottomNSortSpec << "' argument", + it != args.end()); + auto sortSpec = std::move(it->second); + + if (state.needsMerge) { + // When the data will be merged, the heap itself doesn't need to be sorted since the merging + // code will handle the sorting. + auto heapExpr = + makeFunction("getElement", + inputVar->clone(), + makeConstant(sbe::value::TypeTags::NumberInt32, + static_cast<int>(sbe::vm::AggMultiElems::kInternalArr))); + auto lambdaFrameId = frameIdGenerator.generate(); + auto pairVar = makeVariable(lambdaFrameId, 0); + auto lambdaExpr = sbe::makeE<sbe::ELocalLambda>( + lambdaFrameId, + makeNewObjFunction( + FieldPair{AccumulatorN::kFieldNameGeneratedSortKey, + makeFunction("getElement", + pairVar->clone(), + makeConstant(sbe::value::TypeTags::NumberInt32, 0))}, + FieldPair{AccumulatorN::kFieldNameOutput, + makeFunction("getElement", + pairVar->clone(), + makeConstant(sbe::value::TypeTags::NumberInt32, 1))})); + // Convert the array pair representation [key, output] to an object format that the merging + // code expects. + return makeFunction("traverseP", + std::move(heapExpr), + std::move(lambdaExpr), + makeConstant(sbe::value::TypeTags::NumberInt32, 1)); + } else { + auto finalExpr = + makeFunction(isAccumulatorTopN(expr) ? "aggTopNFinalize" : "aggBottomNFinalize", + inputVar->clone(), + std::move(sortSpec)); + if (single) { + finalExpr = makeFunction("getElement", + std::move(finalExpr), + makeConstant(sbe::value::TypeTags::NumberInt32, 0)); + } + return finalExpr; + } +} + +std::unique_ptr<sbe::EExpression> buildFinalizeTopBottomN( + StageBuilderState& state, + const AccumulationExpression& expr, + const sbe::value::SlotVector& inputSlots, + StringDataMap<std::unique_ptr<sbe::EExpression>> args, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + return buildFinalizeTopBottomNImpl(state, + expr, + inputSlots, + std::move(args), + collatorSlot, + frameIdGenerator, + false /* single */); +} + +std::unique_ptr<sbe::EExpression> buildFinalizeTopBottom( + StageBuilderState& state, + const AccumulationExpression& expr, + const sbe::value::SlotVector& inputSlots, + StringDataMap<std::unique_ptr<sbe::EExpression>> args, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + return buildFinalizeTopBottomNImpl(state, + expr, + inputSlots, + std::move(args), + collatorSlot, + frameIdGenerator, + true /* single */); +} + template <int N> std::vector<std::unique_ptr<sbe::EExpression>> emptyInitializer( std::unique_ptr<sbe::EExpression> maxSizeExpr, sbe::value::FrameIdGenerator& frameIdGenerator) { return std::vector<std::unique_ptr<sbe::EExpression>>{N}; } -}; // namespace +} // namespace std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator( const AccumulationStatement& acc, @@ -708,6 +867,37 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator( frameIdGenerator); } +std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator( + const AccumulationStatement& acc, + StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + using BuildAccumulatorFn = std::function<std::vector<std::unique_ptr<sbe::EExpression>>( + const AccumulationExpression&, + StringDataMap<std::unique_ptr<sbe::EExpression>>, + boost::optional<sbe::value::SlotId>, + sbe::value::FrameIdGenerator&)>; + + static const StringDataMap<BuildAccumulatorFn> kAccumulatorBuilders = { + {AccumulatorTopBottomN<kTop, true /* single */>::getName(), &buildAccumulatorTopBottomN}, + {AccumulatorTopBottomN<kBottom, true /* single */>::getName(), &buildAccumulatorTopBottomN}, + {AccumulatorTopBottomN<kTop, false /* single */>::getName(), &buildAccumulatorTopBottomN}, + {AccumulatorTopBottomN<kBottom, false /* single */>::getName(), + &buildAccumulatorTopBottomN}, + }; + + auto accExprName = acc.expr.name; + uassert(5807017, + str::stream() << "Unsupported Accumulator in SBE accumulator builder: " << accExprName, + kAccumulatorBuilders.find(accExprName) != kAccumulatorBuilders.end()); + + return std::invoke(kAccumulatorBuilders.at(accExprName), + acc.expr, + std::move(argExprs), + collatorSlot, + frameIdGenerator); +} + std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates( const AccumulationStatement& acc, const sbe::value::SlotVector& inputSlots, @@ -743,9 +933,47 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates( kAggCombinerBuilders.at(accExprName), acc.expr, inputSlots, collatorSlot, frameIdGenerator); } +std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates( + const AccumulationStatement& acc, + const sbe::value::SlotVector& inputSlots, + StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + using BuildAggCombinerFn = std::function<std::vector<std::unique_ptr<sbe::EExpression>>( + const AccumulationExpression&, + const sbe::value::SlotVector&, + StringDataMap<std::unique_ptr<sbe::EExpression>>, + boost::optional<sbe::value::SlotId>, + sbe::value::FrameIdGenerator&)>; + + static const StringDataMap<BuildAggCombinerFn> kAggCombinerBuilders = { + {AccumulatorTopBottomN<kTop, true /* single */>::getName(), &buildCombinePartialTopBottomN}, + {AccumulatorTopBottomN<kBottom, true /* single */>::getName(), + &buildCombinePartialTopBottomN}, + {AccumulatorTopBottomN<kTop, false /* single */>::getName(), + &buildCombinePartialTopBottomN}, + {AccumulatorTopBottomN<kBottom, false /* single */>::getName(), + &buildCombinePartialTopBottomN}, + }; + + auto accExprName = acc.expr.name; + uassert(5807019, + str::stream() << "Unsupported Accumulator in SBE accumulator builder: " << accExprName, + kAggCombinerBuilders.find(accExprName) != kAggCombinerBuilders.end()); + + return std::invoke(kAggCombinerBuilders.at(accExprName), + acc.expr, + inputSlots, + std::move(argExprs), + collatorSlot, + frameIdGenerator); +} + std::unique_ptr<sbe::EExpression> buildFinalize(StageBuilderState& state, const AccumulationStatement& acc, - const sbe::value::SlotVector& aggSlots) { + const sbe::value::SlotVector& aggSlots, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { using BuildFinalizeFn = std::function<std::unique_ptr<sbe::EExpression>( StageBuilderState&, const AccumulationExpression&, sbe::value::SlotVector)>; @@ -777,6 +1005,42 @@ std::unique_ptr<sbe::EExpression> buildFinalize(StageBuilderState& state, } } +std::unique_ptr<sbe::EExpression> buildFinalize( + StageBuilderState& state, + const AccumulationStatement& acc, + const sbe::value::SlotVector& aggSlots, + StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator) { + using BuildFinalizeFn = std::function<std::unique_ptr<sbe::EExpression>( + StageBuilderState&, + const AccumulationExpression&, + sbe::value::SlotVector, + StringDataMap<std::unique_ptr<sbe::EExpression>>, + boost::optional<sbe::value::SlotId>, + sbe::value::FrameIdGenerator&)>; + + static const StringDataMap<BuildFinalizeFn> kAccumulatorBuilders = { + {AccumulatorTopBottomN<kTop, true /* single */>::getName(), &buildFinalizeTopBottom}, + {AccumulatorTopBottomN<kBottom, true /* single */>::getName(), &buildFinalizeTopBottom}, + {AccumulatorTopBottomN<kTop, false /* single */>::getName(), &buildFinalizeTopBottomN}, + {AccumulatorTopBottomN<kBottom, false /* single */>::getName(), &buildFinalizeTopBottomN}, + }; + + auto accExprName = acc.expr.name; + uassert(5807020, + str::stream() << "Unsupported Accumulator in SBE accumulator builder: " << accExprName, + kAccumulatorBuilders.find(accExprName) != kAccumulatorBuilders.end()); + + return std::invoke(kAccumulatorBuilders.at(accExprName), + state, + acc.expr, + aggSlots, + std::move(argExprs), + collatorSlot, + frameIdGenerator); +} + std::vector<std::unique_ptr<sbe::EExpression>> buildInitialize( const AccumulationStatement& acc, std::unique_ptr<sbe::EExpression> initExpr, @@ -797,6 +1061,14 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildInitialize( {AccumulatorStdDevPop::kName, &emptyInitializer<1>}, {AccumulatorStdDevSamp::kName, &emptyInitializer<1>}, {AccumulatorFirstN::kName, &buildInitializeAccumulatorMulti}, + {AccumulatorTopBottomN<kTop, true /* single */>::getName(), + &buildInitializeAccumulatorMulti}, + {AccumulatorTopBottomN<kBottom, true /* single */>::getName(), + &buildInitializeAccumulatorMulti}, + {AccumulatorTopBottomN<kTop, false /* single */>::getName(), + &buildInitializeAccumulatorMulti}, + {AccumulatorTopBottomN<kBottom, false /* single */>::getName(), + &buildInitializeAccumulatorMulti}, }; auto accExprName = acc.expr.name; diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator.h b/src/mongo/db/query/sbe_stage_builder_accumulator.h index 58daf6cd866..33d9008302b 100644 --- a/src/mongo/db/query/sbe_stage_builder_accumulator.h +++ b/src/mongo/db/query/sbe_stage_builder_accumulator.h @@ -39,6 +39,12 @@ namespace mongo::stage_builder { class PlanStageSlots; +namespace AccArgs { +const StringData kTopBottomNSortSpec = "sortSpec"_sd; +const StringData kTopBottomNKey = "key"_sd; +const StringData kTopBottomNValue = "value"_sd; +} // namespace AccArgs + /** * Translates an input AccumulationStatement into an SBE EExpression for accumulation expressions. */ @@ -49,6 +55,15 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator( sbe::value::FrameIdGenerator&); /** + * Similar to above but takes multiple arguments. + */ +std::vector<std::unique_ptr<sbe::EExpression>> buildAccumulator( + const AccumulationStatement& acc, + StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator&); + +/** * When SBE hash aggregation spills to disk, it spills partial aggregates which need to be combined * later. This function returns the expressions that can be used to combine partial aggregates for * the given accumulator 'acc'. The aggregate-of-aggregates will be stored in a slots owned by the @@ -62,13 +77,36 @@ std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates( sbe::value::FrameIdGenerator&); /** + * Similar to above but takes multiple arguments. + */ +std::vector<std::unique_ptr<sbe::EExpression>> buildCombinePartialAggregates( + const AccumulationStatement& acc, + const sbe::value::SlotVector& inputSlots, + StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator&); + +/** * Translates an input AccumulationStatement into an SBE EExpression that represents an * AccumulationStatement's finalization step. The 'stage' parameter provides the input subtree to * build on top of. */ std::unique_ptr<sbe::EExpression> buildFinalize(StageBuilderState& state, const AccumulationStatement& acc, - const sbe::value::SlotVector& aggSlots); + const sbe::value::SlotVector& aggSlots, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator); + +/** + * Similar to above but takes multiple arguments. + */ +std::unique_ptr<sbe::EExpression> buildFinalize( + StageBuilderState& state, + const AccumulationStatement& acc, + const sbe::value::SlotVector& aggSlots, + StringDataMap<std::unique_ptr<sbe::EExpression>> argExprs, + boost::optional<sbe::value::SlotId> collatorSlot, + sbe::value::FrameIdGenerator& frameIdGenerator); /** * Translates an input AccumulationStatement into an SBE EExpression for the initialization of the diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp index d5b52d3489e..b4ca88e919a 100644 --- a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp +++ b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp @@ -33,6 +33,7 @@ #include "mongo/bson/bsonobj.h" #include "mongo/db/exec/sbe/expression_test_base.h" +#include "mongo/db/exec/sbe/vm/vm.h" #include "mongo/db/pipeline/document_source_group.h" #include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/collation/collator_interface_mock.h" @@ -1678,6 +1679,179 @@ TEST_F(SbeStageBuilderGroupTest, FirstNAccumulatorInvalidDynamicN) { static_cast<ErrorCodes::Error>(7548607)); } +TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorSingleGroup) { + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 44 << "s" << 4)), + BSON_ARRAY(BSON("a" << 33 << "s" << 3)), + BSON_ARRAY(BSON("a" << 22 << "s" << 2)), + BSON_ARRAY(BSON("a" << 11 << "s" << 1))}; + runGroupAggregationTest("{_id: null, x: {$top: {output: '$a', sortBy: {s: 1}}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 11))); + runGroupAggregationTest("{_id: null, x: {$bottom: {output: '$a', sortBy: {s: 1}}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 44))); + runGroupAggregationTest( + "{_id: null, x: {$topN: {output: '$a', sortBy: {s: 1}, n: 3}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(11 << 22 << 33)))); + runGroupAggregationTest( + "{_id: null, x: {$bottomN: {output: '$a', sortBy: {s: 1}, n: 3}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(22 << 33 << 44)))); +} + +TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorCompoundSort) { + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 11 << "s1" << 1 << "s2" << 1)), + BSON_ARRAY(BSON("a" << 12 << "s1" << 1 << "s2" << 2)), + BSON_ARRAY(BSON("a" << 21 << "s1" << 2 << "s2" << 1)), + BSON_ARRAY(BSON("a" << 22 << "s1" << 2 << "s2" << 2))}; + runGroupAggregationTest("{_id: null, x: {$top: {output: '$a', sortBy: {s1: 1, s2: -1}}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 12))); + runGroupAggregationTest("{_id: null, x: {$bottom: {output: '$a', sortBy: {s1: 1, s2: -1}}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 21))); + runGroupAggregationTest( + "{_id: null, x: {$topN: {output: '$a', sortBy: {s1: 1, s2: -1}, n: 3}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(12 << 11 << 22)))); + runGroupAggregationTest( + "{_id: null, x: {$bottomN: {output: '$a', sortBy: {s1: 1, s2: -1}, n: 3}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(11 << 22 << 21)))); +} + +TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorCollation) { + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 41 << "s" + << "41")), + BSON_ARRAY(BSON("a" << 32 << "s" + << "32")), + BSON_ARRAY(BSON("a" << 23 << "s" + << "23")), + BSON_ARRAY(BSON("a" << 14 << "s" + << "14"))}; + runGroupAggregationTest( + "{_id: null, x: {$top: {output: '$a', sortBy: {s: 1}}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 41)), + std::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kReverseString)); + runGroupAggregationTest( + "{_id: null, x: {$bottom: {output: '$a', sortBy: {s: 1}}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << 14)), + std::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kReverseString)); + runGroupAggregationTest( + "{_id: null, x: {$topN: {output: '$a', sortBy: {s: 1}, n: 3}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(41 << 32 << 23))), + std::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kReverseString)); + runGroupAggregationTest( + "{_id: null, x: {$bottomN: {output: '$a', sortBy: {s: 1}, n: 3}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(32 << 23 << 14))), + std::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kReverseString)); +} + +TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorNotEnoughElement) { + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 22 << "s" << 2)), + BSON_ARRAY(BSON("a" << 11 << "s" << 1))}; + runGroupAggregationTest("{_id: null, x: {$topN: {output: '$a', sortBy: {s: 1}, n: 3}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(11 << 22)))); + runGroupAggregationTest("{_id: null, x: {$bottomN: {output: '$a', sortBy: {s: 1}, n: 3}}}", + docs, + BSON_ARRAY(BSON("_id" << BSONNULL << "x" << BSON_ARRAY(11 << 22)))); +} + +TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorMultiGroup) { + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 44 << "s" << 4 << "n" << 1)), + BSON_ARRAY(BSON("a" << 33 << "s" << 3 << "n" << 1)), + BSON_ARRAY(BSON("a" << 22 << "s" << 2 << "n" << 1)), + BSON_ARRAY(BSON("a" << 11 << "s" << 1 << "n" << 1)), + BSON_ARRAY(BSON("a" << 88 << "s" << 8 << "n" << 2)), + BSON_ARRAY(BSON("a" << 77 << "s" << 7 << "n" << 2)), + BSON_ARRAY(BSON("a" << 66 << "s" << 6 << "n" << 2)), + BSON_ARRAY(BSON("a" << 55 << "s" << 5 << "n" << 2))}; + runGroupAggregationTest( + "{_id: '$n', x: {$top: {output: '$a', sortBy: {s: 1}}}}", + docs, + BSON_ARRAY(BSON("_id" << 1 << "x" << 11) << BSON("_id" << 2 << "x" << 55))); + runGroupAggregationTest( + "{_id: '$n', x: {$bottom: {output: '$a', sortBy: {s: 1}}}}", + docs, + BSON_ARRAY(BSON("_id" << 1 << "x" << 44) << BSON("_id" << 2 << "x" << 88))); + runGroupAggregationTest("{_id: '$n', x: {$topN: {output: '$a', sortBy: {s: 1}, n: 3}}}", + docs, + BSON_ARRAY(BSON("_id" << 1 << "x" << BSON_ARRAY(11 << 22 << 33)) + << BSON("_id" << 2 << "x" << BSON_ARRAY(55 << 66 << 77)))); + runGroupAggregationTest("{_id: '$n', x: {$bottomN: {output: '$a', sortBy: {s: 1}, n: 3}}}", + docs, + BSON_ARRAY(BSON("_id" << 1 << "x" << BSON_ARRAY(22 << 33 << 44)) + << BSON("_id" << 2 << "x" << BSON_ARRAY(66 << 77 << 88)))); +} + +TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorDynamicN) { + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 44 << "s" << 4 << "n" << 2)), + BSON_ARRAY(BSON("a" << 33 << "s" << 3 << "n" << 2)), + BSON_ARRAY(BSON("a" << 22 << "s" << 2 << "n" << 2)), + BSON_ARRAY(BSON("a" << 11 << "s" << 1 << "n" << 2)), + BSON_ARRAY(BSON("a" << 88 << "s" << 8 << "n" << 3)), + BSON_ARRAY(BSON("a" << 77 << "s" << 7 << "n" << 3)), + BSON_ARRAY(BSON("a" << 66 << "s" << 6 << "n" << 3)), + BSON_ARRAY(BSON("a" << 55 << "s" << 5 << "n" << 3))}; + runGroupAggregationTest( + "{_id: {n1: '$n'}, x: {$topN: {output: '$a', sortBy: {s: 1}, n: '$n1'}}}", + docs, + BSON_ARRAY(BSON("_id" << BSON("n1" << 2) << "x" << BSON_ARRAY(11 << 22)) + << BSON("_id" << BSON("n1" << 3) << "x" << BSON_ARRAY(55 << 66 << 77)))); + runGroupAggregationTest( + "{_id: {n1: '$n'}, x: {$bottomN: {output: '$a', sortBy: {s: 1}, n: '$n1'}}}", + docs, + BSON_ARRAY(BSON("_id" << BSON("n1" << 2) << "x" << BSON_ARRAY(33 << 44)) + << BSON("_id" << BSON("n1" << 3) << "x" << BSON_ARRAY(66 << 77 << 88)))); +} + +TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorInvalidConstantN) { + const std::vector<std::string> accumulators{"$topN", "$bottomN"}; + const std::vector<std::string> testCases{"'string'", "4.2", "-1", "0"}; + auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 11 << "s" << 1))}; + for (const auto& acc : accumulators) { + for (const auto& testCase : testCases) { + runGroupAggregationToFail(str::stream() << "{_id: null, x: {" << acc + << ": {output: '$a', sortBy: {s: 1}, n: " + << testCase << "}}}", + docs, + static_cast<ErrorCodes::Error>(7548606)); + } + } +} + +TEST_F(SbeStageBuilderGroupTest, TopBottomNAccumulatorInvalidDynamicN) { + const std::vector<std::string> accumulators{"$topN", "$bottomN"}; + const std::vector<BSONObj> testCases{BSON("n" + << "string"), + BSON("n" << 4.2), + BSON("n" << -1), + BSON("n" << 0)}; + for (const auto& acc : accumulators) { + for (const auto& testCase : testCases) { + auto docs = + std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 11 << "s" << 1 << "n1" << testCase))}; + + runGroupAggregationToFail(str::stream() + << "{_id: null, x: {" << acc + << ": {output: '$a', sortBy: {s: 1}, n: '$n'}}}", + docs, + static_cast<ErrorCodes::Error>(7548607)); + runGroupAggregationToFail(str::stream() + << "{_id: {n: '$n1.n'}, x: {" << acc + << ": {output: '$a', sortBy: {s: 1}, n: '$n'}}}", + docs, + static_cast<ErrorCodes::Error>(7548607)); + } + } +} + class AccumulatorSBEIncompatible final : public AccumulatorState { public: static constexpr auto kName = "$incompatible"_sd; @@ -1947,6 +2121,45 @@ public: return {resultTag, resultVal}; } + std::pair<sbe::value::TypeTags, sbe::value::Value> bsonArrayToSbe(BSONArray arr) { + auto [arrTag, arrVal] = sbe::value::makeNewArray(); + auto arrView = sbe::value::getArrayView(arrVal); + + for (auto elem : arr) { + auto [tag, val] = sbe::bson::convertFrom<false>(elem); + arrView->push_back(tag, val); + } + return {arrTag, arrVal}; + } + + /** + * Create an accumulator state for $topN/$bottomN, given heap in the format of BSONArray. + */ + std::pair<sbe::value::TypeTags, sbe::value::Value> makeTopBottomNAccumulatorState( + BSONArray valuesBson, long maxSize, const sbe::value::SortSpec* sortSpec) { + auto [stateTag, stateVal] = sbe::value::makeNewArray(); + sbe::value::ValueGuard stateGuard{stateTag, stateVal}; + auto state = sbe::value::getArrayView(stateVal); + + auto [valuesTag, valuesVal] = bsonArrayToSbe(valuesBson); + sbe::value::ValueGuard valuesGuard{valuesTag, valuesVal}; + // Heap + state->push_back(valuesTag, valuesVal); + + // Max size + state->push_back(sbe::value::TypeTags::NumberInt64, maxSize); + + // Memory usage + state->push_back(sbe::value::TypeTags::NumberInt32, 0); + + // Memory limit + state->push_back(sbe::value::TypeTags::NumberInt32, INT_MAX); + + valuesGuard.reset(); + stateGuard.reset(); + return {stateTag, stateVal}; + } + /** * Given the name of an SBE agg function ('aggFuncName') and an array of values expressed as a * BSON array, aggregates the values inside the array and returns the resulting SBE value. @@ -2621,4 +2834,58 @@ TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsFirstNInputArrayEm ASSERT_EQ(compareVal, 0); sbe::value::releaseValue(resultTag, resultVal); } + +TEST_F(SbeStageBuilderGroupAggCombinerTest, CombinePartialAggsTopBottomN) { + auto sortPattern = BSON("x" << 1); + auto sortSpec = new sbe::value::SortSpec(sortPattern); + auto sortSpecConstant = stage_builder::makeConstant( + sbe::value::TypeTags::sortSpec, sbe::value::bitcastFrom<sbe::value::SortSpec*>(sortSpec)); + + auto topNExpr = stage_builder::makeFunction( + "aggTopNMerge", stage_builder::makeVariable(_inputSlotId), sortSpecConstant->clone()); + auto bottomNExpr = stage_builder::makeFunction( + "aggBottomNMerge", stage_builder::makeVariable(_inputSlotId), sortSpecConstant->clone()); + + auto aggSlot = bindAccessor(&_aggAccessor); + auto topNFinalExpr = stage_builder::makeFunction( + "aggTopNFinalize", stage_builder::makeVariable(aggSlot), sortSpecConstant->clone()); + auto bottomNFinalExpr = stage_builder::makeFunction( + "aggBottomNFinalize", stage_builder::makeVariable(aggSlot), sortSpecConstant->clone()); + + std::vector<std::tuple<sbe::EExpression*, sbe::EExpression*, BSONArray, BSONArray, BSONArray>> + testCases{{topNExpr.get(), + topNFinalExpr.get(), + BSON_ARRAY(BSON_ARRAY(5 << 5) << BSON_ARRAY(3 << 3) << BSON_ARRAY(1 << 1)), + BSON_ARRAY(BSON_ARRAY(6 << 6) << BSON_ARRAY(4 << 4) << BSON_ARRAY(2 << 2)), + BSON_ARRAY(1 << 2 << 3)}, + {bottomNExpr.get(), + bottomNFinalExpr.get(), + BSON_ARRAY(BSON_ARRAY(1 << 1) << BSON_ARRAY(3 << 3) << BSON_ARRAY(5 << 5)), + BSON_ARRAY(BSON_ARRAY(2 << 2) << BSON_ARRAY(4 << 4) << BSON_ARRAY(6 << 6)), + BSON_ARRAY(4 << 5 << 6)}}; + + for (auto& [expr, finalExpr, heapMerge, heapIncoming, expected] : testCases) { + auto [accTag, accVal] = makeTopBottomNAccumulatorState(heapMerge, 3, sortSpec); + _aggAccessor.reset(true, accTag, accVal); + + auto [inputTag, inputVal] = makeTopBottomNAccumulatorState(heapIncoming, 3, sortSpec); + _inputAccessor.reset(true, inputTag, inputVal); + + auto compiledExpr = compileAggExpression(*expr, &_aggAccessor); + + auto [newAccTag, newAccVal] = runCompiledExpression(compiledExpr.get()); + _aggAccessor.reset(true, newAccTag, newAccVal); + + auto compiledFinalExpr = compileExpression(*finalExpr); + + auto [resultTag, resultVal] = runCompiledExpression(compiledFinalExpr.get()); + + auto [expectedTag, expectedVal] = bsonArrayToSbe(expected); + auto [compareTag, compareVal] = + sbe::value::compareValue(resultTag, resultVal, expectedTag, expectedVal); + + ASSERT_EQ(compareTag, sbe::value::TypeTags::NumberInt32); + ASSERT_EQ(compareVal, 0); + } +} } // namespace mongo |