summaryrefslogtreecommitdiff
path: root/src/mongo/db/exec/sbe/vm/arith.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/exec/sbe/vm/arith.cpp')
-rw-r--r--src/mongo/db/exec/sbe/vm/arith.cpp112
1 files changed, 112 insertions, 0 deletions
diff --git a/src/mongo/db/exec/sbe/vm/arith.cpp b/src/mongo/db/exec/sbe/vm/arith.cpp
index e4c67775ad8..1d41ff59ce1 100644
--- a/src/mongo/db/exec/sbe/vm/arith.cpp
+++ b/src/mongo/db/exec/sbe/vm/arith.cpp
@@ -503,6 +503,57 @@ void ByteCode::aggDoubleDoubleSumImpl(value::Array* arr,
}
}
+void ByteCode::aggMergeDoubleDoubleSumsImpl(value::Array* accumulator,
+ value::TypeTags rhsTag,
+ value::Value rhsValue) {
+ auto [accumWidestType, _1] = accumulator->getAt(AggSumValueElems::kNonDecimalTotalTag);
+
+ tassert(7039532, "value must be of type 'Array'", rhsTag == value::TypeTags::Array);
+ auto nextDoubleDoubleArr = value::getArrayView(rhsValue);
+
+ tassert(7039533,
+ "array does not have enough elements",
+ nextDoubleDoubleArr->size() >= AggSumValueElems::kMaxSizeOfArray - 1);
+
+ // First aggregate the non-decimal sum, then the non-decimal addend. Both should be doubles.
+ auto [sumTag, sum] = nextDoubleDoubleArr->getAt(AggSumValueElems::kNonDecimalTotalSum);
+ tassert(7039534, "expected 'NumberDouble'", sumTag == value::TypeTags::NumberDouble);
+ aggDoubleDoubleSumImpl(accumulator, sumTag, sum);
+
+ auto [addendTag, addend] = nextDoubleDoubleArr->getAt(AggSumValueElems::kNonDecimalTotalAddend);
+ tassert(7039535, "expected 'NumberDouble'", addendTag == value::TypeTags::NumberDouble);
+ // There is a special case when the 'sum' is infinite and the 'addend' is NaN. This DoubleDouble
+ // value represents infinity, not NaN. Therefore, we avoid incorporating the NaN 'addend' value
+ // into the sum.
+ if (std::isfinite(value::bitcastTo<double>(sum)) ||
+ !std::isnan(value::bitcastTo<double>(addend))) {
+ aggDoubleDoubleSumImpl(accumulator, addendTag, addend);
+ }
+
+ // Determine the widest non-decimal type that we've seen so far, and set the accumulator state
+ // accordingly. We do this after computing the sums, since 'aggDoubleDoubleSumImpl()' will
+ // set the widest type to 'NumberDouble' when we call it above.
+ auto [newValWidestType, _2] = nextDoubleDoubleArr->getAt(AggSumValueElems::kNonDecimalTotalTag);
+ tassert(
+ 7039536, "unexpected 'NumberDecimal'", newValWidestType != value::TypeTags::NumberDecimal);
+ tassert(
+ 7039537, "unexpected 'NumberDecimal'", accumWidestType != value::TypeTags::NumberDecimal);
+ auto widestType = getWidestNumericalType(newValWidestType, accumWidestType);
+ accumulator->setAt(
+ AggSumValueElems::kNonDecimalTotalTag, widestType, value::bitcastFrom<int32_t>(0));
+
+ // If there's a decimal128 sum as part of the incoming DoubleDouble sum, incorporate it into the
+ // accumulator.
+ if (nextDoubleDoubleArr->size() == AggSumValueElems::kMaxSizeOfArray) {
+ auto [decimalTotalTag, decimalTotalVal] =
+ nextDoubleDoubleArr->getAt(AggSumValueElems::kDecimalTotal);
+ tassert(7039538,
+ "The decimalTotal must be 'NumberDecimal'",
+ decimalTotalTag == TypeTags::NumberDecimal);
+ aggDoubleDoubleSumImpl(accumulator, decimalTotalTag, decimalTotalVal);
+ }
+}
+
void ByteCode::aggStdDevImpl(value::Array* arr, value::TypeTags rhsTag, value::Value rhsValue) {
if (!isNumber(rhsTag)) {
return;
@@ -551,6 +602,67 @@ void ByteCode::aggStdDevImpl(value::Array* arr, value::TypeTags rhsTag, value::V
return setStdDevArray(newCountVal, newMeanVal, newM2Val, arr);
}
+void ByteCode::aggMergeStdDevsImpl(value::Array* accumulator,
+ value::TypeTags rhsTag,
+ value::Value rhsValue) {
+ tassert(7039542, "expected value of type 'Array'", rhsTag == value::TypeTags::Array);
+ auto nextArr = value::getArrayView(rhsValue);
+
+ tassert(7039543,
+ "expected array to have exactly 3 elements",
+ accumulator->size() == AggStdDevValueElems::kSizeOfArray);
+ tassert(7039544,
+ "expected array to have exactly 3 elements",
+ nextArr->size() == AggStdDevValueElems::kSizeOfArray);
+
+ auto [newCountTag, newCountVal] = nextArr->getAt(AggStdDevValueElems::kCount);
+ tassert(7039545, "expected 64-bit int", newCountTag == value::TypeTags::NumberInt64);
+ int64_t newCount = value::bitcastTo<int64_t>(newCountVal);
+
+ // If the incoming partial aggregate has a count of zero, then it represents the partial
+ // standard deviation of no data points. This means that it can be safely ignored, and we return
+ // the accumulator as is.
+ if (newCount == 0) {
+ return;
+ }
+
+ auto [oldCountTag, oldCountVal] = accumulator->getAt(AggStdDevValueElems::kCount);
+ tassert(7039546, "expected 64-bit int", oldCountTag == value::TypeTags::NumberInt64);
+ int64_t oldCount = value::bitcastTo<int64_t>(oldCountVal);
+
+ auto [oldMeanTag, oldMeanVal] = accumulator->getAt(AggStdDevValueElems::kRunningMean);
+ tassert(7039547, "expected double", oldMeanTag == value::TypeTags::NumberDouble);
+ double oldMean = value::bitcastTo<double>(oldMeanVal);
+
+ auto [newMeanTag, newMeanVal] = nextArr->getAt(AggStdDevValueElems::kRunningMean);
+ tassert(7039548, "expected double", newMeanTag == value::TypeTags::NumberDouble);
+ double newMean = value::bitcastTo<double>(newMeanVal);
+
+ auto [oldM2Tag, oldM2Val] = accumulator->getAt(AggStdDevValueElems::kRunningM2);
+ tassert(7039531, "expected double", oldM2Tag == value::TypeTags::NumberDouble);
+ double oldM2 = value::bitcastTo<double>(oldM2Val);
+
+ auto [newM2Tag, newM2Val] = nextArr->getAt(AggStdDevValueElems::kRunningM2);
+ tassert(7039541, "expected double", newM2Tag == value::TypeTags::NumberDouble);
+ double newM2 = value::bitcastTo<double>(newM2Val);
+
+ const double delta = newMean - oldMean;
+ // We've already handled the case where 'newCount' is zero above. This means that 'totalCount'
+ // must be positive, and prevents us from ever dividing by zero in the subsequent calculation.
+ int64_t totalCount = oldCount + newCount;
+ if (delta != 0) {
+ newMean = ((oldCount * oldMean) + (newCount * newMean)) / totalCount;
+ newM2 += delta * delta *
+ (static_cast<double>(oldCount) * static_cast<double>(newCount) / totalCount);
+ }
+ newM2 += oldM2;
+
+ setStdDevArray(value::bitcastFrom<int64_t>(totalCount),
+ value::bitcastFrom<double>(newMean),
+ value::bitcastFrom<double>(newM2),
+ accumulator);
+}
+
std::tuple<bool, value::TypeTags, value::Value> ByteCode::aggStdDevFinalizeImpl(
value::Value fieldValue, bool isSamp) {
auto arr = value::getArrayView(fieldValue);