summary | refs | log | tree | commit | diff
path: root/src/mongo/db/pipeline/accumulator_avg.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/pipeline/accumulator_avg.cpp')
-rw-r--r-- src/mongo/db/pipeline/accumulator_avg.cpp | 87
1 files changed, 67 insertions, 20 deletions
diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp
index daddd56f502..a12b0da4361 100644
--- a/src/mongo/db/pipeline/accumulator_avg.cpp
+++ b/src/mongo/db/pipeline/accumulator_avg.cpp
@@ -29,6 +29,7 @@
#include "mongo/platform/basic.h"
+#include "mongo/db/exec/sbe/accumulator_sum_value_enum.h"
#include "mongo/db/pipeline/accumulator.h"
#include "mongo/db/exec/document_value/document.h"
@@ -49,33 +50,73 @@ REGISTER_STABLE_EXPRESSION(avg, ExpressionFromAccumulator<AccumulatorAvg>::parse
REGISTER_REMOVABLE_WINDOW_FUNCTION(avg, AccumulatorAvg, WindowFunctionAvg);
// Field names of the intermediate document that getValue(true) emits and that
// processInternal(..., merging = true) reads back when combining partial averages.
namespace {
+// TODO SERVER-64227 Remove 'subTotal' and 'subTotalError' fields when we branch for 6.1 because all
+// nodes in a sharded cluster would use the new data format.
const char subTotalName[] = "subTotal";
const char subTotalErrorName[] = "subTotalError"; // Used for extra precision
+const char partialSumName[] = "ps"; // Used for the full state of partial sum
const char countName[] = "count"; // Number of values folded into the partial result
} // namespace
// Forward declaration only; the definition is not visible in this file view.
// Called on the merge path of AccumulatorAvg::processInternal with the array stored under
// the 'ps' field and mutable references to the accumulator's type trackers and running totals.
// NOTE(review): presumably folds the serialized partial sum into those out-parameters --
// confirm against the definition.
+void applyPartialSum(const std::vector<Value>& arr,
+ BSONType& nonDecimalTotalType,
+ BSONType& totalType,
+ DoubleDoubleSummation& nonDecimalTotal,
+ Decimal128& decimalTotal);
+
// Forward declaration only; the definition is not visible in this file view.
// AccumulatorAvg::getValue(true) stores the returned Value under the 'ps' field of the
// document it hands to the merging side.
// NOTE(review): presumably produces the array form that applyPartialSum consumes -- confirm
// against the definition.
+Value serializePartialSum(BSONType nonDecimalTotalType,
+ BSONType totalType,
+ const DoubleDoubleSummation& nonDecimalTotal,
+ const Decimal128& decimalTotal);
+
// Folds one value into the running average state.
//
// merging == true:  'input' must be an Object (enforced with verify). If it carries no 'ps'
//                   field, the legacy path feeds 'subTotal' (and 'subTotalError', when present)
//                   back through processInternal as ordinary values and then corrects _count.
//                   Otherwise the 'ps' array is handed to applyPartialSum and 'count' is added
//                   to _count.
// merging == false: in the new version non-numeric inputs are ignored; numeric inputs widen the
//                   tracked result types, are added to the decimal or non-decimal total, and
//                   increment _count.
//
// Note: this is a diff view; the hunk boundary further down elides part of the switch.
void AccumulatorAvg::processInternal(const Value& input, bool merging) {
if (merging) {
// We expect an object that contains a count plus either the full partial-sum state ('ps')
// or, from nodes still using the older format, a subtotal with an optional error value
// that allows for additional precision.
// 'input' is what getValue(true) produced below.
verify(input.getType() == Object);
- // We're recursively adding the subtotal to get the proper type treatment, but this only
- // increments the count by one, so adjust the count afterwards. Similarly for 'error'.
- processInternal(input[subTotalName], false);
- _count += input[countName].getLong() - 1;
- Value error = input[subTotalErrorName];
- if (!error.missing()) {
- processInternal(error, false);
- _count--; // The error correction only adjusts the total, not the number of items.
+
+ // TODO SERVER-64227 Remove 'if' block when we branch for 6.1 because all nodes in a sharded
+ // cluster would use the new data format.
+ if (auto partialSumVal = input[partialSumName]; partialSumVal.missing()) {
+ // We're recursively adding the subtotal to get the proper type treatment, but this only
+ // increments the count by one, so adjust the count afterwards. Similarly for 'error'.
+ processInternal(input[subTotalName], false);
+ _count += input[countName].getLong() - 1;
+ Value error = input[subTotalErrorName];
+ if (!error.missing()) {
+ processInternal(error, false);
+ _count--; // The error correction only adjusts the total, not the number of items.
+ }
+ } else {
+ // The merge-side must be ready to process the full state of a partial sum from a
+ // shard-side if a shard chooses to do so. See Accumulator::getValue() for details.
+ applyPartialSum(partialSumVal.getArray(),
+ _nonDecimalTotalType,
+ _totalType,
+ _nonDecimalTotal,
+ _decimalTotal);
// _count is not passed to applyPartialSum, so the shard's full count is added here
// (no "- 1" correction as in the legacy path, which goes through _count++).
+ _count += input[countName].getLong();
}
+
return;
}
// Non-numeric inputs contribute neither to the totals nor to the count.
+ if (!input.numeric()) {
+ return;
+ }
+
+ _totalType = Value::getWidestNumeric(_totalType, input.getType());
+
+ // Keep the nonDecimalTotal's type so that the type information can be serialized too for
+ // 'toBeMerged' scenarios.
+ if (input.getType() != NumberDecimal) {
+ _nonDecimalTotalType = Value::getWidestNumeric(_nonDecimalTotalType, input.getType());
+ }
+
switch (input.getType()) {
case NumberDecimal:
_decimalTotal = _decimalTotal.add(input.getDecimal());
- _isDecimal = true;
break;
case NumberLong:
// Avoid summation using double as that loses precision.
@@ -88,8 +129,7 @@ void AccumulatorAvg::processInternal(const Value& input, bool merging) {
_nonDecimalTotal.addDouble(input.getDouble());
break;
default:
- dassert(!input.numeric());
- return;
// Non-numeric inputs already returned above; the cases are expected to cover every
// numeric type (some cases are elided by the diff hunk boundary -- confirm).
+ MONGO_UNREACHABLE;
}
_count++;
}
@@ -104,32 +144,39 @@ Decimal128 AccumulatorAvg::_getDecimalTotal() const {
// Produces the accumulator's result.
//
// toBeMerged == true:  returns a document for the merging side. In the new version it always
//                      carries 'subTotal', 'count' and the serialized partial sum under 'ps';
//                      when the total is not decimal it also carries 'subTotalError'.
// toBeMerged == false: returns null when nothing has been counted, otherwise total / count,
//                      computed in Decimal128 when the total type is NumberDecimal and in
//                      double otherwise.
//
// The diff replaces the '_isDecimal' flag with the check '_totalType == NumberDecimal'.
Value AccumulatorAvg::getValue(bool toBeMerged) {
if (toBeMerged) {
- if (_isDecimal)
- return Value(Document{{subTotalName, _getDecimalTotal()}, {countName, _count}});
+ auto partialSumVal =
+ serializePartialSum(_nonDecimalTotalType, _totalType, _nonDecimalTotal, _decimalTotal);
// The legacy 'subTotal' field is still emitted next to 'ps' (see the SERVER-64227 TODO on
// the field-name constants).
+ if (_totalType == NumberDecimal) {
+ return Value(Document{{subTotalName, _getDecimalTotal()},
+ {countName, _count},
+ {partialSumName, partialSumVal}});
+ }
- double total, error;
- std::tie(total, error) = _nonDecimalTotal.getDoubleDouble();
- return Value(
- Document{{subTotalName, total}, {countName, _count}, {subTotalErrorName, error}});
+ auto [total, error] = _nonDecimalTotal.getDoubleDouble();
+ return Value(Document{{subTotalName, total},
+ {countName, _count},
+ {subTotalErrorName, error},
+ {partialSumName, partialSumVal}});
}
// No counted inputs: the result is null rather than a division by zero.
if (_count == 0)
return Value(BSONNULL);
- if (_isDecimal)
+ if (_totalType == NumberDecimal)
return Value(_getDecimalTotal().divide(Decimal128(static_cast<int64_t>(_count))));
return Value(_nonDecimalTotal.getDouble() / static_cast<double>(_count));
}
// Constructs an empty accumulator: forwards 'expCtx' to AccumulatorState, starts the count at
// zero and records the object's own size as its memory usage.
// NOTE(review): the new version no longer initializes '_isDecimal' and does not initialize
// '_totalType' / '_nonDecimalTotalType' here; presumably they have default member
// initializers in the header -- confirm.
AccumulatorAvg::AccumulatorAvg(ExpressionContext* const expCtx)
- : AccumulatorState(expCtx), _isDecimal(false), _count(0) {
+ : AccumulatorState(expCtx), _count(0) {
// This is a fixed size AccumulatorState so we never need to update this
_memUsageBytes = sizeof(*this);
}
void AccumulatorAvg::reset() {
- _isDecimal = false;
+ _totalType = NumberInt;
+ _nonDecimalTotalType = NumberInt;
_nonDecimalTotal = {};
_decimalTotal = {};
_count = 0;