diff options
Diffstat (limited to 'src/mongo/db/exec/sbe/vm/arith.cpp')
-rw-r--r-- | src/mongo/db/exec/sbe/vm/arith.cpp | 112 |
1 files changed, 112 insertions, 0 deletions
diff --git a/src/mongo/db/exec/sbe/vm/arith.cpp b/src/mongo/db/exec/sbe/vm/arith.cpp index e4c67775ad8..1d41ff59ce1 100644 --- a/src/mongo/db/exec/sbe/vm/arith.cpp +++ b/src/mongo/db/exec/sbe/vm/arith.cpp @@ -503,6 +503,57 @@ void ByteCode::aggDoubleDoubleSumImpl(value::Array* arr, } } +void ByteCode::aggMergeDoubleDoubleSumsImpl(value::Array* accumulator, + value::TypeTags rhsTag, + value::Value rhsValue) { + auto [accumWidestType, _1] = accumulator->getAt(AggSumValueElems::kNonDecimalTotalTag); + + tassert(7039532, "value must be of type 'Array'", rhsTag == value::TypeTags::Array); + auto nextDoubleDoubleArr = value::getArrayView(rhsValue); + + tassert(7039533, + "array does not have enough elements", + nextDoubleDoubleArr->size() >= AggSumValueElems::kMaxSizeOfArray - 1); + + // First aggregate the non-decimal sum, then the non-decimal addend. Both should be doubles. + auto [sumTag, sum] = nextDoubleDoubleArr->getAt(AggSumValueElems::kNonDecimalTotalSum); + tassert(7039534, "expected 'NumberDouble'", sumTag == value::TypeTags::NumberDouble); + aggDoubleDoubleSumImpl(accumulator, sumTag, sum); + + auto [addendTag, addend] = nextDoubleDoubleArr->getAt(AggSumValueElems::kNonDecimalTotalAddend); + tassert(7039535, "expected 'NumberDouble'", addendTag == value::TypeTags::NumberDouble); + // There is a special case when the 'sum' is infinite and the 'addend' is NaN. This DoubleDouble + // value represents infinity, not NaN. Therefore, we avoid incorporating the NaN 'addend' value + // into the sum. + if (std::isfinite(value::bitcastTo<double>(sum)) || + !std::isnan(value::bitcastTo<double>(addend))) { + aggDoubleDoubleSumImpl(accumulator, addendTag, addend); + } + + // Determine the widest non-decimal type that we've seen so far, and set the accumulator state + // accordingly. We do this after computing the sums, since 'aggDoubleDoubleSumImpl()' will + // set the widest type to 'NumberDouble' when we call it above. + auto [newValWidestType, _2] = nextDoubleDoubleArr->getAt(AggSumValueElems::kNonDecimalTotalTag); + tassert( + 7039536, "unexpected 'NumberDecimal'", newValWidestType != value::TypeTags::NumberDecimal); + tassert( + 7039537, "unexpected 'NumberDecimal'", accumWidestType != value::TypeTags::NumberDecimal); + auto widestType = getWidestNumericalType(newValWidestType, accumWidestType); + accumulator->setAt( + AggSumValueElems::kNonDecimalTotalTag, widestType, value::bitcastFrom<int32_t>(0)); + + // If there's a decimal128 sum as part of the incoming DoubleDouble sum, incorporate it into the + // accumulator. + if (nextDoubleDoubleArr->size() == AggSumValueElems::kMaxSizeOfArray) { + auto [decimalTotalTag, decimalTotalVal] = + nextDoubleDoubleArr->getAt(AggSumValueElems::kDecimalTotal); + tassert(7039538, + "The decimalTotal must be 'NumberDecimal'", + decimalTotalTag == TypeTags::NumberDecimal); + aggDoubleDoubleSumImpl(accumulator, decimalTotalTag, decimalTotalVal); + } +} + void ByteCode::aggStdDevImpl(value::Array* arr, value::TypeTags rhsTag, value::Value rhsValue) { if (!isNumber(rhsTag)) { return; @@ -551,6 +602,67 @@ void ByteCode::aggStdDevImpl(value::Array* arr, value::TypeTags rhsTag, value::V return setStdDevArray(newCountVal, newMeanVal, newM2Val, arr); } +void ByteCode::aggMergeStdDevsImpl(value::Array* accumulator, + value::TypeTags rhsTag, + value::Value rhsValue) { + tassert(7039542, "expected value of type 'Array'", rhsTag == value::TypeTags::Array); + auto nextArr = value::getArrayView(rhsValue); + + tassert(7039543, + "expected array to have exactly 3 elements", + accumulator->size() == AggStdDevValueElems::kSizeOfArray); + tassert(7039544, + "expected array to have exactly 3 elements", + nextArr->size() == AggStdDevValueElems::kSizeOfArray); + + auto [newCountTag, newCountVal] = nextArr->getAt(AggStdDevValueElems::kCount); + tassert(7039545, "expected 64-bit int", newCountTag == value::TypeTags::NumberInt64); + int64_t newCount = value::bitcastTo<int64_t>(newCountVal); + + // If the incoming partial aggregate has a count of zero, then it represents the partial + // standard deviation of no data points. This means that it can be safely ignored, and we return + // the accumulator as is. + if (newCount == 0) { + return; + } + + auto [oldCountTag, oldCountVal] = accumulator->getAt(AggStdDevValueElems::kCount); + tassert(7039546, "expected 64-bit int", oldCountTag == value::TypeTags::NumberInt64); + int64_t oldCount = value::bitcastTo<int64_t>(oldCountVal); + + auto [oldMeanTag, oldMeanVal] = accumulator->getAt(AggStdDevValueElems::kRunningMean); + tassert(7039547, "expected double", oldMeanTag == value::TypeTags::NumberDouble); + double oldMean = value::bitcastTo<double>(oldMeanVal); + + auto [newMeanTag, newMeanVal] = nextArr->getAt(AggStdDevValueElems::kRunningMean); + tassert(7039548, "expected double", newMeanTag == value::TypeTags::NumberDouble); + double newMean = value::bitcastTo<double>(newMeanVal); + + auto [oldM2Tag, oldM2Val] = accumulator->getAt(AggStdDevValueElems::kRunningM2); + tassert(7039531, "expected double", oldM2Tag == value::TypeTags::NumberDouble); + double oldM2 = value::bitcastTo<double>(oldM2Val); + + auto [newM2Tag, newM2Val] = nextArr->getAt(AggStdDevValueElems::kRunningM2); + tassert(7039541, "expected double", newM2Tag == value::TypeTags::NumberDouble); + double newM2 = value::bitcastTo<double>(newM2Val); + + const double delta = newMean - oldMean; + // We've already handled the case where 'newCount' is zero above. This means that 'totalCount' + // must be positive, and prevents us from ever dividing by zero in the subsequent calculation. + int64_t totalCount = oldCount + newCount; + if (delta != 0) { + newMean = ((oldCount * oldMean) + (newCount * newMean)) / totalCount; + newM2 += delta * delta * + (static_cast<double>(oldCount) * static_cast<double>(newCount) / totalCount); + } + newM2 += oldM2; + + setStdDevArray(value::bitcastFrom<int64_t>(totalCount), + value::bitcastFrom<double>(newMean), + value::bitcastFrom<double>(newM2), + accumulator); +} + std::tuple<bool, value::TypeTags, value::Value> ByteCode::aggStdDevFinalizeImpl( value::Value fieldValue, bool isSamp) { auto arr = value::getArrayView(fieldValue); |