summaryrefslogtreecommitdiff
path: root/src/mongo/db/pipeline/accumulator_sum.cpp
diff options
context:
space:
mode:
authorHartek Sabharwal <hartek.sabharwal@mongodb.com>2021-03-17 16:17:27 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2021-03-17 18:27:38 +0000
commit0afd76f0c80b7e33de15fd2bce7e409dd88318bc (patch)
tree5ab987247700536e45c748cc6494dea6edd87c49 /src/mongo/db/pipeline/accumulator_sum.cpp
parent8ae44d4f3557e08d8a2e54f92e7395b46ea62859 (diff)
downloadmongo-0afd76f0c80b7e33de15fd2bce7e409dd88318bc.tar.gz
SERVER-53713 Implement removable $sum and $avg window function
Diffstat (limited to 'src/mongo/db/pipeline/accumulator_sum.cpp')
-rw-r--r--src/mongo/db/pipeline/accumulator_sum.cpp10
1 files changed, 6 insertions, 4 deletions
diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp
index 0a1d36feb2b..bc1eb451d9a 100644
--- a/src/mongo/db/pipeline/accumulator_sum.cpp
+++ b/src/mongo/db/pipeline/accumulator_sum.cpp
@@ -38,6 +38,7 @@
#include "mongo/db/pipeline/accumulation_statement.h"
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/pipeline/window_function/window_function_expression.h"
+#include "mongo/db/pipeline/window_function/window_function_sum.h"
#include "mongo/util/summation.h"
namespace mongo {
@@ -46,8 +47,7 @@ using boost::intrusive_ptr;
REGISTER_ACCUMULATOR(sum, genericParseSingleExpressionAccumulator<AccumulatorSum>);
REGISTER_EXPRESSION(sum, ExpressionFromAccumulator<AccumulatorSum>::parse);
-REGISTER_NON_REMOVABLE_WINDOW_FUNCTION(
- sum, window_function::ExpressionFromAccumulator<AccumulatorSum>::parse);
+REGISTER_REMOVABLE_WINDOW_FUNCTION(sum, AccumulatorSum, WindowFunctionSum);
REGISTER_ACCUMULATOR(count, parseCountAccumulator);
const char* AccumulatorSum::getOpName() const {
@@ -74,9 +74,11 @@ void AccumulatorSum::processInternal(const Value& input, bool merging) {
// Upgrade to the widest type required to hold the result.
totalType = Value::getWidestNumeric(totalType, input.getType());
switch (input.getType()) {
- case NumberInt:
case NumberLong:
- nonDecimalTotal.addLong(input.coerceToLong());
+ nonDecimalTotal.addLong(input.getLong());
+ break;
+ case NumberInt:
+ nonDecimalTotal.addInt(input.getInt());
break;
case NumberDouble:
nonDecimalTotal.addDouble(input.getDouble());