diff options
Diffstat (limited to 'src/mongo/db/exec/sbe/vm/vm.cpp')
-rw-r--r-- | src/mongo/db/exec/sbe/vm/vm.cpp | 136 |
1 files changed, 136 insertions, 0 deletions
diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp index 57687160f24..ea9695c63b4 100644 --- a/src/mongo/db/exec/sbe/vm/vm.cpp +++ b/src/mongo/db/exec/sbe/vm/vm.cpp @@ -112,6 +112,8 @@ int Instruction::stackOffset[Instruction::Tags::lastInstruction] = { 0, // getArraySize -1, // aggSum + -1, // aggDoubleDoubleSum + 0, // doubleDoubleSumFinalize -1, // aggMin -1, // aggMax -1, // aggFirst @@ -364,6 +366,14 @@ void CodeFragment::appendSum() { appendSimpleInstruction(Instruction::aggSum); } +void CodeFragment::appendDoubleDoubleSum() { + appendSimpleInstruction(Instruction::aggDoubleDoubleSum); +} + +void CodeFragment::appendDoubleDoubleSumFinalize() { + appendSimpleInstruction(Instruction::doubleDoubleSumFinalize); +} + void CodeFragment::appendMin() { appendSimpleInstruction(Instruction::aggMin); } @@ -899,6 +909,103 @@ std::tuple<bool, value::TypeTags, value::Value> ByteCode::aggSum(value::TypeTags return genericAdd(accTag, accValue, fieldTag, fieldValue); } +std::tuple<bool, value::TypeTags, value::Value> ByteCode::aggDoubleDoubleSum( + value::TypeTags accTag, + value::Value accValue, + value::TypeTags fieldTag, + value::Value fieldValue) { + // Skip aggregation step if we don't have the input. + if (fieldTag == value::TypeTags::Nothing) { + auto [tag, val] = value::copyValue(accTag, accValue); + return {true, tag, val}; + } + + // Initialize the accumulator. + if (accTag == value::TypeTags::Nothing) { + auto [accTagN, accValueN] = value::makeNewArray(); + value::ValueGuard guard{accTagN, accValueN}; + auto arr = value::getArrayView(accValueN); + + // The order of the following three elements should match to 'AggSumValueElems'. + arr->push_back(value::TypeTags::NumberInt32, value::bitcastFrom<int32_t>(0)); + arr->push_back(value::TypeTags::NumberDouble, value::bitcastFrom<double>(0.0)); + arr->push_back(value::TypeTags::NumberDouble, value::bitcastFrom<double>(0.0)); + // The absent 'kDecimalTotal' element means that we've not seen any decimal value. So, we're + // not adding 'kDecimalTotal' element yet. + return aggDoubleDoubleSumImpl(accTagN, accValueN, fieldTag, fieldValue); + } + + return aggDoubleDoubleSumImpl(accTag, accValue, fieldTag, fieldValue); +} + +std::tuple<bool, value::TypeTags, value::Value> ByteCode::doubleDoubleSumFinalize( + value::TypeTags fieldTag, value::Value fieldValue) { + auto arr = value::getArrayView(fieldValue); + tassert(5755321, + str::stream() << "The result slot must have at least " + << AggSumValueElems::kMaxSizeOfArray - 1 + << " elements but got: " << arr->size(), + arr->size() >= AggSumValueElems::kMaxSizeOfArray - 1); + + auto nonDecimalTotalTag = arr->getAt(AggSumValueElems::kNonDecimalTotalTag).first; + tassert(5755322, + "The nonDecimalTag can't be NumberDecimal", + nonDecimalTotalTag != value::TypeTags::NumberDecimal); + auto [sumTag, sum] = arr->getAt(AggSumValueElems::kNonDecimalTotalSum); + auto [addendTag, addend] = arr->getAt(AggSumValueElems::kNonDecimalTotalAddend); + tassert(5755323, + "The sum and addend must be NumbetDouble", + sumTag == addendTag && sumTag == value::TypeTags::NumberDouble); + + // We're guaranteed to always have a valid nonDecimalTotal value. + auto nonDecimalTotal = DoubleDoubleSummation::create(value::bitcastTo<double>(sum), + value::bitcastTo<double>(addend)); + + if (auto nElems = arr->size(); nElems < AggSumValueElems::kMaxSizeOfArray) { + // We've not seen any decimal value. + switch (nonDecimalTotalTag) { + case value::TypeTags::NumberInt32: + case value::TypeTags::NumberInt64: + if (nonDecimalTotal.fitsLong()) { + auto longVal = nonDecimalTotal.getLong(); + if (int intVal = longVal; + nonDecimalTotalTag == value::TypeTags::NumberInt32 && intVal == longVal) { + return {true, + value::TypeTags::NumberInt32, + value::bitcastFrom<int32_t>(intVal)}; + } else { + return {true, + value::TypeTags::NumberInt64, + value::bitcastFrom<int64_t>(longVal)}; + } + } + // Sum doesn't fit a NumberLong, so return a NumberDouble instead. + [[fallthrough]]; + case value::TypeTags::NumberDouble: + return {true, + value::TypeTags::NumberDouble, + value::bitcastFrom<double>(nonDecimalTotal.getDouble())}; + default: + MONGO_UNREACHABLE_TASSERT(5755324); + } + } else { + // We've seen a decimal value. + tassert(5755325, + str::stream() << "The result slot must have at most " + << AggSumValueElems::kMaxSizeOfArray + << " elements but got: " << arr->size(), + nElems == AggSumValueElems::kMaxSizeOfArray); + auto [decimalTotalTag, decimalTotalVal] = arr->getAt(AggSumValueElems::kDecimalTotal); + tassert(5755326, + "The decimalTotal must be NumberDecimal", + decimalTotalTag == value::TypeTags::NumberDecimal); + + auto decimalTotal = value::bitcastTo<Decimal128>(decimalTotalVal); + auto [tag, val] = value::makeCopyDecimal(decimalTotal.add(nonDecimalTotal.getDecimal())); + return {true, tag, val}; + } +} + std::tuple<bool, value::TypeTags, value::Value> ByteCode::aggMin(value::TypeTags accTag, value::Value accValue, value::TypeTags fieldTag, @@ -4381,6 +4488,35 @@ void ByteCode::runInternal(const CodeFragment* code, int64_t position) { } break; } + case Instruction::aggDoubleDoubleSum: { + auto [rhsOwned, rhsTag, rhsVal] = getFromStack(0); + popStack(); + auto [lhsOwned, lhsTag, lhsVal] = getFromStack(0); + + auto [owned, tag, val] = aggDoubleDoubleSum(lhsTag, lhsVal, rhsTag, rhsVal); + + topStack(owned, tag, val); + + if (rhsOwned) { + value::releaseValue(rhsTag, rhsVal); + } + if (lhsOwned) { + value::releaseValue(lhsTag, lhsVal); + } + break; + } + case Instruction::doubleDoubleSumFinalize: { + auto [sumArrayOwned, sumArrayTag, sumArrayVal] = getFromStack(0); + auto [finalSumOwned, finalSumTag, finalSumVal] = + doubleDoubleSumFinalize(sumArrayTag, sumArrayVal); + + topStack(finalSumOwned, finalSumTag, finalSumVal); + + if (sumArrayOwned) { + value::releaseValue(sumArrayTag, sumArrayVal); + } + break; + } case Instruction::aggMin: { auto [rhsOwned, rhsTag, rhsVal] = getFromStack(0); popStack(); |