diff options
Diffstat (limited to 'src/mongo/db/pipeline/accumulator_avg.cpp')
-rw-r--r-- | src/mongo/db/pipeline/accumulator_avg.cpp | 87 |
1 files changed, 67 insertions, 20 deletions
diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp index daddd56f502..a12b0da4361 100644 --- a/src/mongo/db/pipeline/accumulator_avg.cpp +++ b/src/mongo/db/pipeline/accumulator_avg.cpp @@ -29,6 +29,7 @@ #include "mongo/platform/basic.h" +#include "mongo/db/exec/sbe/accumulator_sum_value_enum.h" #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/exec/document_value/document.h" @@ -49,33 +50,73 @@ REGISTER_STABLE_EXPRESSION(avg, ExpressionFromAccumulator<AccumulatorAvg>::parse REGISTER_REMOVABLE_WINDOW_FUNCTION(avg, AccumulatorAvg, WindowFunctionAvg); namespace { +// TODO SERVER-64227 Remove 'subTotal' and 'subTotalError' fields when we branch for 6.1 because all +// nodes in a sharded cluster would use the new data format. const char subTotalName[] = "subTotal"; const char subTotalErrorName[] = "subTotalError"; // Used for extra precision +const char partialSumName[] = "ps"; // Used for the full state of partial sum const char countName[] = "count"; } // namespace +void applyPartialSum(const std::vector<Value>& arr, + BSONType& nonDecimalTotalType, + BSONType& totalType, + DoubleDoubleSummation& nonDecimalTotal, + Decimal128& decimalTotal); + +Value serializePartialSum(BSONType nonDecimalTotalType, + BSONType totalType, + const DoubleDoubleSummation& nonDecimalTotal, + const Decimal128& decimalTotal); + void AccumulatorAvg::processInternal(const Value& input, bool merging) { if (merging) { // We expect an object that contains both a subtotal and a count. Additionally there may // be an error value, that allows for additional precision. // 'input' is what getValue(true) produced below. verify(input.getType() == Object); - // We're recursively adding the subtotal to get the proper type treatment, but this only - // increments the count by one, so adjust the count afterwards. Similarly for 'error'. - processInternal(input[subTotalName], false); - _count += input[countName].getLong() - 1; - Value error = input[subTotalErrorName]; - if (!error.missing()) { - processInternal(error, false); - _count--; // The error correction only adjusts the total, not the number of items. + + // TODO SERVER-64227 Remove 'if' block when we branch for 6.1 because all nodes in a sharded + // cluster would use the new data format. + if (auto partialSumVal = input[partialSumName]; partialSumVal.missing()) { + // We're recursively adding the subtotal to get the proper type treatment, but this only + // increments the count by one, so adjust the count afterwards. Similarly for 'error'. + processInternal(input[subTotalName], false); + _count += input[countName].getLong() - 1; + Value error = input[subTotalErrorName]; + if (!error.missing()) { + processInternal(error, false); + _count--; // The error correction only adjusts the total, not the number of items. + } + } else { + // The merge-side must be ready to process the full state of a partial sum from a + // shard-side if a shard chooses to do so. See Accumulator::getValue() for details. + applyPartialSum(partialSumVal.getArray(), + _nonDecimalTotalType, + _totalType, + _nonDecimalTotal, + _decimalTotal); + _count += input[countName].getLong(); } + return; } + if (!input.numeric()) { + return; + } + + _totalType = Value::getWidestNumeric(_totalType, input.getType()); + + // Keep the nonDecimalTotal's type so that the type information can be serialized too for + // 'toBeMerged' scenarios. + if (input.getType() != NumberDecimal) { + _nonDecimalTotalType = Value::getWidestNumeric(_nonDecimalTotalType, input.getType()); + } + switch (input.getType()) { case NumberDecimal: _decimalTotal = _decimalTotal.add(input.getDecimal()); - _isDecimal = true; break; case NumberLong: // Avoid summation using double as that loses precision. @@ -88,8 +129,7 @@ void AccumulatorAvg::processInternal(const Value& input, bool merging) { _nonDecimalTotal.addDouble(input.getDouble()); break; default: - dassert(!input.numeric()); - return; + MONGO_UNREACHABLE; } _count++; } @@ -104,32 +144,39 @@ Decimal128 AccumulatorAvg::_getDecimalTotal() const { Value AccumulatorAvg::getValue(bool toBeMerged) { if (toBeMerged) { - if (_isDecimal) - return Value(Document{{subTotalName, _getDecimalTotal()}, {countName, _count}}); + auto partialSumVal = + serializePartialSum(_nonDecimalTotalType, _totalType, _nonDecimalTotal, _decimalTotal); + if (_totalType == NumberDecimal) { + return Value(Document{{subTotalName, _getDecimalTotal()}, + {countName, _count}, + {partialSumName, partialSumVal}}); + } - double total, error; - std::tie(total, error) = _nonDecimalTotal.getDoubleDouble(); - return Value( - Document{{subTotalName, total}, {countName, _count}, {subTotalErrorName, error}}); + auto [total, error] = _nonDecimalTotal.getDoubleDouble(); + return Value(Document{{subTotalName, total}, + {countName, _count}, + {subTotalErrorName, error}, + {partialSumName, partialSumVal}}); } if (_count == 0) return Value(BSONNULL); - if (_isDecimal) + if (_totalType == NumberDecimal) return Value(_getDecimalTotal().divide(Decimal128(static_cast<int64_t>(_count)))); return Value(_nonDecimalTotal.getDouble() / static_cast<double>(_count)); } AccumulatorAvg::AccumulatorAvg(ExpressionContext* const expCtx) - : AccumulatorState(expCtx), _isDecimal(false), _count(0) { + : AccumulatorState(expCtx), _count(0) { // This is a fixed size AccumulatorState so we never need to update this _memUsageBytes = sizeof(*this); } void AccumulatorAvg::reset() { - _isDecimal = false; + _totalType = NumberInt; + _nonDecimalTotalType = NumberInt; _nonDecimalTotal = {}; _decimalTotal = {}; _count = 0; |