summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--jstests/aggregation/sources/setWindowFields/exp_moving_avg.js221
-rw-r--r--src/mongo/db/pipeline/SConscript1
-rw-r--r--src/mongo/db/pipeline/accumulator.h19
-rw-r--r--src/mongo/db/pipeline/accumulator_exp_moving_avg.cpp95
-rw-r--r--src/mongo/db/pipeline/window_function/window_function_expression.cpp56
-rw-r--r--src/mongo/db/pipeline/window_function/window_function_expression.h62
6 files changed, 453 insertions, 1 deletions
diff --git a/jstests/aggregation/sources/setWindowFields/exp_moving_avg.js b/jstests/aggregation/sources/setWindowFields/exp_moving_avg.js
new file mode 100644
index 00000000000..a0a5c517c91
--- /dev/null
+++ b/jstests/aggregation/sources/setWindowFields/exp_moving_avg.js
@@ -0,0 +1,221 @@
+/**
+ * Test that exponential moving average works as a window function.
+ */
+(function() {
+"use strict";
+
+load("jstests/aggregation/extras/window_function_helpers.js");
+
+const featureEnabled =
+ assert.commandWorked(db.adminCommand({getParameter: 1, featureFlagWindowFunctions: 1}))
+ .featureFlagWindowFunctions.value;
+if (!featureEnabled) {
+ jsTestLog("Skipping test because the window function feature flag is disabled");
+ return;
+}
+
+const coll = db[jsTestName()];
+coll.drop();
+
+// Create a collection of tickers and prices.
+const nDocsPerTicker = 10;
+seedWithTickerData(coll, nDocsPerTicker);
+
+const origDocs = coll.find().sort({_id: 1}).toArray();
+
+// startIndex inclusive, endIndex exclusive.
+function movingAvgForDocs(alpha) {
+ let results = [];
+ let lastVal = null;
+ for (let i = 0; i < origDocs.length; i++) {
+ if (!lastVal) {
+ lastVal = origDocs[i].price;
+ } else {
+ lastVal = origDocs[i].price * alpha + lastVal * (1 - alpha);
+ }
+ results.push(lastVal);
+ }
+ return results;
+}
+
+// Test that $expMovingAvg returns null for windows which do not contain numeric values.
+let results = coll.aggregate([
+ {$addFields: {str: "hiya"}},
+ {
+ $setWindowFields: {
+ sortBy: {_id: 1},
+ output: {
+ expMovAvg: {$expMovingAvg: {input: "$str", N: 2}},
+ }
+ }
+ }
+ ])
+ .toArray();
+for (let index = 0; index < results.length; index++) {
+ assert.eq(null, results[index].expMovAvg);
+}
+// Test simple case with N specified.
+results = coll.aggregate([
+ {
+ $setWindowFields: {
+ sortBy: {_id: 1},
+ output: {
+ expMovAvg: {
+ $expMovingAvg: {input: "$price", N: 3},
+ },
+ }
+ }
+ },
+ // Working with NumberDecimals in JS is difficult. Compare doubles instead.
+ {$addFields: {doubleVal: {$toDouble: "$expMovAvg"}}}
+ ])
+ .toArray();
+const simpleNResult = movingAvgForDocs(2.0 / (3.0 + 1));
+
+// Same test with manual alpha
+for (let index = 0; index < results.length; index++) {
+ assert.close(simpleNResult[index], results[index].doubleVal, "Disagreement at index " + index);
+}
+
+results = coll.aggregate([
+ {
+ $setWindowFields: {
+ sortBy: {_id: 1},
+ output: {
+ expMovAvg: {
+ $expMovingAvg: {input: "$price", alpha: .5},
+ },
+ }
+ }
+ },
+ // Working with NumberDecimals in JS is difficult. Compare doubles instead.
+ {$addFields: {doubleVal: {$toDouble: "$expMovAvg"}}}
+ ])
+ .toArray();
+const simpleAlphaResult = movingAvgForDocs(.5);
+for (let index = 0; index < results.length; index++) {
+ assert.close(
+ simpleAlphaResult[index], results[index].doubleVal, "Disagreement at index " + index);
+}
+
+// Succeed with more interesting alpha.
+results = coll.aggregate([
+ {
+ $setWindowFields: {
+ sortBy: {_id: 1},
+ output: {
+ expMovAvg: {
+ $expMovingAvg: {input: "$price", alpha: .279},
+ },
+ }
+ }
+ },
+ // Working with NumberDecimals in JS is difficult. Compare doubles instead.
+ {$addFields: {doubleVal: {$toDouble: "$expMovAvg"}}}
+ ])
+ .toArray();
+const complexAlphaResult = movingAvgForDocs(.279);
+for (let index = 0; index < results.length; index++) {
+ assert.close(
+ complexAlphaResult[index], results[index].doubleVal, "Disagreement at index " + index);
+}
+
+// Fails if argument type or contents are incorrect.
+assert.commandFailedWithCode(db.runCommand({
+ aggregate: coll.getName(),
+ pipeline: [
+ {
+ $setWindowFields: {
+ sortBy: {_id: 1},
+ output: {
+ expMovAvg: {
+ $expMovingAvg: {input: "$price", alpha: .5, N: 2},
+ },
+ }
+ }
+ },
+ ]
+}),
+ ErrorCodes.FailedToParse);
+assert.commandFailedWithCode(db.runCommand({
+ aggregate: coll.getName(),
+ pipeline: [
+ {
+ $setWindowFields: {
+ sortBy: {_id: 1},
+ output: {
+ expMovAvg: {
+ $expMovingAvg: {input: "$price", N: .5},
+ },
+ }
+ }
+ },
+ ]
+}),
+ ErrorCodes.FailedToParse);
+assert.commandFailedWithCode(db.runCommand({
+ aggregate: coll.getName(),
+ pipeline: [
+ {
+ $setWindowFields: {
+ sortBy: {_id: 1},
+ output: {
+ expMovAvg: {
+ $expMovingAvg: {input: "$price", N: "food"},
+ },
+ }
+ }
+ },
+ ]
+}),
+ ErrorCodes.FailedToParse);
+assert.commandFailedWithCode(db.runCommand({
+ aggregate: coll.getName(),
+ pipeline: [
+ {
+ $setWindowFields: {
+ sortBy: {_id: 1},
+ output: {
+ expMovAvg: {
+ $expMovingAvg: {str: "$price", N: 2},
+ },
+ }
+ }
+ },
+ ]
+}),
+ ErrorCodes.FailedToParse);
+assert.commandFailedWithCode(db.runCommand({
+ aggregate: coll.getName(),
+ pipeline: [
+ {
+ $setWindowFields: {
+ sortBy: {_id: 1},
+ output: {
+ expMovAvg: {
+ $expMovingAvg: {str: "$price", N: 2},
+ randomArg: 2,
+ },
+ }
+ }
+ },
+ ]
+}),
+ ErrorCodes.FailedToParse);
+assert.commandFailedWithCode(db.runCommand({
+ aggregate: coll.getName(),
+ pipeline: [
+ {
+ $setWindowFields: {
+ sortBy: {_id: 1},
+ output: {
+ expMovAvg: {
+ $expMovingAvg: "$price",
+ },
+ }
+ }
+ },
+ ]
+}),
+ ErrorCodes.FailedToParse);
+})();
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index 14cdc65b850..c9c8ffae144 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -119,6 +119,7 @@ env.Library(
'accumulator_add_to_set.cpp',
'accumulator_avg.cpp',
'accumulator_covariance.cpp',
+ 'accumulator_exp_moving_avg.cpp',
'accumulator_first.cpp',
'accumulator_js_reduce.cpp',
'accumulator_last.cpp',
diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h
index 5107445ed25..71f2dcb7f5a 100644
--- a/src/mongo/db/pipeline/accumulator.h
+++ b/src/mongo/db/pipeline/accumulator.h
@@ -473,4 +473,23 @@ private:
MutableDocument _output;
};
+class AccumulatorExpMovingAvg : public AccumulatorState {
+public:
+ AccumulatorExpMovingAvg(ExpressionContext* const expCtx, Decimal128 alpha);
+
+ void processInternal(const Value& input, bool merging) final;
+ Value getValue(bool toBeMerged) final;
+ const char* getOpName() const final;
+ void reset() final;
+
+ static boost::intrusive_ptr<AccumulatorState> create(ExpressionContext* const expCtx,
+ Decimal128 alpha);
+
+private:
+ Decimal128 _alpha;
+ Decimal128 _currentResult;
+ bool _init = false;
+ bool _isDecimal = false;
+};
+
} // namespace mongo
diff --git a/src/mongo/db/pipeline/accumulator_exp_moving_avg.cpp b/src/mongo/db/pipeline/accumulator_exp_moving_avg.cpp
new file mode 100644
index 00000000000..f4393828f82
--- /dev/null
+++ b/src/mongo/db/pipeline/accumulator_exp_moving_avg.cpp
@@ -0,0 +1,95 @@
+/**
+ * Copyright (C) 2018-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include <cmath>
+#include <limits>
+
+#include "mongo/db/pipeline/accumulator.h"
+
+#include "mongo/db/exec/document_value/value.h"
+#include "mongo/db/pipeline/accumulation_statement.h"
+#include "mongo/db/pipeline/expression.h"
+#include "mongo/db/pipeline/window_function/window_function_expression.h"
+
+namespace mongo {
+
+using boost::intrusive_ptr;
+
+REGISTER_NON_REMOVABLE_WINDOW_FUNCTION(expMovingAvg,
+ mongo::window_function::ExpressionExpMovingAvg::parse);
+const char* AccumulatorExpMovingAvg::getOpName() const {
+ return "$expMovingAvg";
+}
+
+void AccumulatorExpMovingAvg::processInternal(const Value& input, bool merging) {
+ tassert(5433600, "$expMovingAvg can't be merged", !merging);
+ if (!input.numeric()) {
+ return;
+ }
+ if (input.getType() == BSONType::NumberDecimal) {
+ _isDecimal = true;
+ }
+ auto decimalVal = input.coerceToDecimal();
+ if (!_init) {
+ _currentResult = decimalVal;
+ _init = true;
+ } else {
+ _currentResult = decimalVal.multiply(_alpha).add(
+ _currentResult.multiply(Decimal128(1).subtract(_alpha)));
+ }
+}
+
+intrusive_ptr<AccumulatorState> AccumulatorExpMovingAvg::create(ExpressionContext* const expCtx,
+ Decimal128 alpha) {
+ return new AccumulatorExpMovingAvg(expCtx, alpha);
+}
+
+Value AccumulatorExpMovingAvg::getValue(bool toBeMerged) {
+ tassert(5433601, "$expMovingAvg can't be merged", !toBeMerged);
+ if (!_init) {
+ return Value(BSONNULL);
+ }
+ if (!_isDecimal) {
+ return Value(_currentResult.toDouble());
+ }
+ return Value(_currentResult);
+}
+
+AccumulatorExpMovingAvg::AccumulatorExpMovingAvg(ExpressionContext* const expCtx, Decimal128 alpha)
+ : AccumulatorState(expCtx), _alpha(alpha) {
+ _memUsageBytes = sizeof(*this);
+}
+
+void AccumulatorExpMovingAvg::reset() {
+ _memUsageBytes = sizeof(*this);
+ _init = false;
+}
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.cpp b/src/mongo/db/pipeline/window_function/window_function_expression.cpp
index 27522e9b790..9f495295c2e 100644
--- a/src/mongo/db/pipeline/window_function/window_function_expression.cpp
+++ b/src/mongo/db/pipeline/window_function/window_function_expression.cpp
@@ -66,6 +66,62 @@ void Expression::registerParser(std::string functionName, Parser parser) {
parserMap.emplace(std::move(functionName), std::move(parser));
}
+
+boost::intrusive_ptr<Expression> ExpressionExpMovingAvg::parse(
+ BSONObj obj, const boost::optional<SortPattern>& sortBy, ExpressionContext* expCtx) {
+ // 'obj' is something like '{$expMovingAvg: {input: <arg>, <N/alpha>: <int/float>}}'
+ boost::optional<StringData> accumulatorName;
+ boost::intrusive_ptr<::mongo::Expression> input;
+ uassert(ErrorCodes::FailedToParse,
+ "$expMovingAvg must have exactly one argument that is an object",
+ obj.nFields() == 1 && obj.hasField(kAccName) &&
+ obj[kAccName].type() == BSONType::Object);
+ auto subObj = obj[kAccName].embeddedObject();
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << "$expMovingAvg sub object must have exactly two fields: An '"
+ << kInputArg << "' field, and either an '" << kNArg << "' field or an '"
+ << kAlphaArg << "' field",
+ subObj.nFields() == 2 && subObj.hasField(kInputArg));
+ input =
+ ::mongo::Expression::parseOperand(expCtx, subObj[kInputArg], expCtx->variablesParseState);
+ // ExpMovingAvg is always unbounded to current.
+ WindowBounds bounds = WindowBounds{
+ WindowBounds::DocumentBased{WindowBounds::Unbounded{}, WindowBounds::Current{}}};
+ if (subObj.hasField(kNArg)) {
+ auto nVal = subObj[kNArg];
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << "'" << kNArg << "' field must be an integer, but found type "
+ << nVal.type(),
+ nVal.isNumber());
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << "'" << kNArg << "' field must be an integer, but found " << nVal
+ << ". To use a non-integer, use the '" << kAlphaArg
+ << "' argument instead",
+ nVal.safeNumberDouble() == floor(nVal.safeNumberDouble()));
+ auto nNum = nVal.safeNumberLong();
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << "'" << kNArg << "' must be greater than zero. Got " << nNum,
+ nNum > 0);
+ return make_intrusive<ExpressionExpMovingAvg>(
+ expCtx, std::string(kAccName), std::move(input), std::move(bounds), nNum);
+ } else if (subObj.hasField(kAlphaArg)) {
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << "'" << kAlphaArg << "' must be a number",
+ subObj[kAlphaArg].isNumber());
+ return make_intrusive<ExpressionExpMovingAvg>(expCtx,
+ std::string(kAccName),
+ std::move(input),
+ std::move(bounds),
+ subObj[kAlphaArg].numberDecimal());
+ } else {
+ uasserted(ErrorCodes::FailedToParse,
+ str::stream() << "Got unrecognized field in $expMovingAvg"
+ << "$expMovingAvg sub object must have exactly two fields: An '"
+ << kInputArg << "' field, and either an '" << kNArg
+ << "' field or an '" << kAlphaArg << "' field");
+ }
+}
+
MONGO_INITIALIZER(windowFunctionExpressionMap)(InitializerContext*) {
// Nothing to do. This initializer exists to tie together all the individual initializers
// defined by REGISTER_NON_REMOVABLE_WINDOW_FUNCTION and REGISTER_REMOVABLE_WINDOW_FUNCTION
diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.h b/src/mongo/db/pipeline/window_function/window_function_expression.h
index ead4de18fc1..ae0634dc294 100644
--- a/src/mongo/db/pipeline/window_function/window_function_expression.h
+++ b/src/mongo/db/pipeline/window_function/window_function_expression.h
@@ -256,7 +256,6 @@ public:
// Rank based accumulators are always unbounded to current.
WindowBounds bounds = WindowBounds{
WindowBounds::DocumentBased{WindowBounds::Unbounded{}, WindowBounds::Current{}}};
- boost::intrusive_ptr<::mongo::Expression> input;
auto arg = obj.firstElement();
auto argName = arg.fieldNameStringData();
if (parserMap.find(argName) != parserMap.end()) {
@@ -311,4 +310,65 @@ public:
}
};
+class ExpressionExpMovingAvg : public Expression {
+public:
+ static constexpr StringData kAccName = "$expMovingAvg"_sd;
+ static constexpr StringData kInputArg = "input"_sd;
+ static constexpr StringData kNArg = "N"_sd;
+ static constexpr StringData kAlphaArg = "alpha"_sd;
+ static boost::intrusive_ptr<Expression> parse(BSONObj obj,
+ const boost::optional<SortPattern>& sortBy,
+ ExpressionContext* expCtx);
+
+ ExpressionExpMovingAvg(ExpressionContext* expCtx,
+ std::string accumulatorName,
+ boost::intrusive_ptr<::mongo::Expression> input,
+ WindowBounds bounds,
+ long long nValue)
+ : Expression(expCtx, std::move(accumulatorName), std::move(input), std::move(bounds)),
+ _N(nValue) {}
+
+ ExpressionExpMovingAvg(ExpressionContext* expCtx,
+ std::string accumulatorName,
+ boost::intrusive_ptr<::mongo::Expression> input,
+ WindowBounds bounds,
+ Decimal128 alpha)
+ : Expression(expCtx, std::move(accumulatorName), std::move(input), std::move(bounds)),
+ _alpha(alpha) {}
+
+ boost::intrusive_ptr<AccumulatorState> buildAccumulatorOnly() const final {
+ if (_N) {
+ return AccumulatorExpMovingAvg::create(
+ _expCtx, Decimal128(2).divide(Decimal128(_N.get()).add(Decimal128(1))));
+ } else if (_alpha) {
+ return AccumulatorExpMovingAvg::create(_expCtx, _alpha.get());
+ }
+ tasserted(5433602, "ExpMovingAvg neither N nor alpha was set");
+ }
+
+ std::unique_ptr<WindowFunctionState> buildRemovable() const final {
+ tasserted(5433603,
+ str::stream() << "Window function " << _accumulatorName
+ << " is not supported with a removable window");
+ }
+
+ Value serialize(boost::optional<ExplainOptions::Verbosity> explain) const final {
+ MutableDocument subObj;
+ tassert(5433604, "ExpMovingAvg neither N nor alpha was set", _N || _alpha);
+ if (_N) {
+ subObj[kNArg] = Value(_N.get());
+ } else {
+ subObj[kAlphaArg] = Value(_alpha.get());
+ }
+ subObj[kInputArg] = _input->serialize(static_cast<bool>(explain));
+ MutableDocument outerObj;
+ outerObj[kAccName] = subObj.freezeToValue();
+ return outerObj.freezeToValue();
+ }
+
+protected:
+ boost::optional<long long> _N;
+ boost::optional<Decimal128> _alpha;
+};
+
} // namespace mongo::window_function