diff options
6 files changed, 453 insertions, 1 deletions
diff --git a/jstests/aggregation/sources/setWindowFields/exp_moving_avg.js b/jstests/aggregation/sources/setWindowFields/exp_moving_avg.js
new file mode 100644
index 00000000000..a0a5c517c91
--- /dev/null
+++ b/jstests/aggregation/sources/setWindowFields/exp_moving_avg.js
@@ -0,0 +1,231 @@
+/**
+ * Test that exponential moving average works as a window function.
+ */
+(function() {
+"use strict";
+
+load("jstests/aggregation/extras/window_function_helpers.js");
+
+const featureEnabled =
+    assert.commandWorked(db.adminCommand({getParameter: 1, featureFlagWindowFunctions: 1}))
+        .featureFlagWindowFunctions.value;
+if (!featureEnabled) {
+    jsTestLog("Skipping test because the window function feature flag is disabled");
+    return;
+}
+
+const coll = db[jsTestName()];
+coll.drop();
+
+// Create a collection of tickers and prices.
+const nDocsPerTicker = 10;
+seedWithTickerData(coll, nDocsPerTicker);
+
+const origDocs = coll.find().sort({_id: 1}).toArray();
+
+// Reference implementation: returns the expected exponential moving average of 'price' for every
+// document in 'origDocs' (taken in _id order), using the given smoothing factor 'alpha'.
+function movingAvgForDocs(alpha) {
+    let results = [];
+    let lastVal = null;
+    for (let i = 0; i < origDocs.length; i++) {
+        // Compare against null explicitly so that a running average of exactly 0 is not mistaken
+        // for "no value seen yet" and re-seeded.
+        if (lastVal === null) {
+            lastVal = origDocs[i].price;
+        } else {
+            lastVal = origDocs[i].price * alpha + lastVal * (1 - alpha);
+        }
+        results.push(lastVal);
+    }
+    return results;
+}
+
+// Test that $expMovingAvg returns null for windows which do not contain numeric values.
+let results = coll.aggregate([
+    {$addFields: {str: "hiya"}},
+    {
+        $setWindowFields: {
+            sortBy: {_id: 1},
+            output: {
+                expMovAvg: {$expMovingAvg: {input: "$str", N: 2}},
+            }
+        }
+    }
+])
+                  .toArray();
+for (let index = 0; index < results.length; index++) {
+    assert.eq(null, results[index].expMovAvg);
+}
+// Test simple case with N specified.
+results = coll.aggregate([
+    {
+        $setWindowFields: {
+            sortBy: {_id: 1},
+            output: {
+                expMovAvg: {
+                    $expMovingAvg: {input: "$price", N: 3},
+                },
+            }
+        }
+    },
+    // Working with NumberDecimals in JS is difficult. Compare doubles instead.
+    {$addFields: {doubleVal: {$toDouble: "$expMovAvg"}}}
+])
+              .toArray();
+const simpleNResult = movingAvgForDocs(2.0 / (3.0 + 1));
+
+// Compare the results computed from N against the reference implementation.
+for (let index = 0; index < results.length; index++) {
+    assert.close(simpleNResult[index], results[index].doubleVal, "Disagreement at index " + index);
+}
+
+// Same test with manual alpha.
+results = coll.aggregate([
+    {
+        $setWindowFields: {
+            sortBy: {_id: 1},
+            output: {
+                expMovAvg: {
+                    $expMovingAvg: {input: "$price", alpha: .5},
+                },
+            }
+        }
+    },
+    // Working with NumberDecimals in JS is difficult. Compare doubles instead.
+    {$addFields: {doubleVal: {$toDouble: "$expMovAvg"}}}
+])
+              .toArray();
+const simpleAlphaResult = movingAvgForDocs(.5);
+for (let index = 0; index < results.length; index++) {
+    assert.close(
+        simpleAlphaResult[index], results[index].doubleVal, "Disagreement at index " + index);
+}
+
+// Succeed with more interesting alpha.
+results = coll.aggregate([
+    {
+        $setWindowFields: {
+            sortBy: {_id: 1},
+            output: {
+                expMovAvg: {
+                    $expMovingAvg: {input: "$price", alpha: .279},
+                },
+            }
+        }
+    },
+    // Working with NumberDecimals in JS is difficult. Compare doubles instead.
+    {$addFields: {doubleVal: {$toDouble: "$expMovAvg"}}}
+])
+              .toArray();
+const complexAlphaResult = movingAvgForDocs(.279);
+for (let index = 0; index < results.length; index++) {
+    assert.close(
+        complexAlphaResult[index], results[index].doubleVal, "Disagreement at index " + index);
+}
+
+// Fails if argument type or contents are incorrect.
+// Both 'alpha' and 'N' are specified.
+assert.commandFailedWithCode(db.runCommand({
+    aggregate: coll.getName(),
+    pipeline: [
+        {
+            $setWindowFields: {
+                sortBy: {_id: 1},
+                output: {
+                    expMovAvg: {
+                        $expMovingAvg: {input: "$price", alpha: .5, N: 2},
+                    },
+                }
+            }
+        },
+    ]
+}),
+                             ErrorCodes.FailedToParse);
+// 'N' is not an integer.
+assert.commandFailedWithCode(db.runCommand({
+    aggregate: coll.getName(),
+    pipeline: [
+        {
+            $setWindowFields: {
+                sortBy: {_id: 1},
+                output: {
+                    expMovAvg: {
+                        $expMovingAvg: {input: "$price", N: .5},
+                    },
+                }
+            }
+        },
+    ]
+}),
+                             ErrorCodes.FailedToParse);
+// 'N' is not a number.
+assert.commandFailedWithCode(db.runCommand({
+    aggregate: coll.getName(),
+    pipeline: [
+        {
+            $setWindowFields: {
+                sortBy: {_id: 1},
+                output: {
+                    expMovAvg: {
+                        $expMovingAvg: {input: "$price", N: "food"},
+                    },
+                }
+            }
+        },
+    ]
+}),
+                             ErrorCodes.FailedToParse);
+// The 'input' field is missing.
+assert.commandFailedWithCode(db.runCommand({
+    aggregate: coll.getName(),
+    pipeline: [
+        {
+            $setWindowFields: {
+                sortBy: {_id: 1},
+                output: {
+                    expMovAvg: {
+                        $expMovingAvg: {str: "$price", N: 2},
+                    },
+                }
+            }
+        },
+    ]
+}),
+                             ErrorCodes.FailedToParse);
+// An extra argument sits next to an otherwise valid $expMovingAvg.
+assert.commandFailedWithCode(db.runCommand({
+    aggregate: coll.getName(),
+    pipeline: [
+        {
+            $setWindowFields: {
+                sortBy: {_id: 1},
+                output: {
+                    expMovAvg: {
+                        $expMovingAvg: {input: "$price", N: 2},
+                        randomArg: 2,
+                    },
+                }
+            }
+        },
+    ]
+}),
+                             ErrorCodes.FailedToParse);
+// The argument to $expMovingAvg is not an object.
+assert.commandFailedWithCode(db.runCommand({
+    aggregate: coll.getName(),
+    pipeline: [
+        {
+            $setWindowFields: {
+                sortBy: {_id: 1},
+                output: {
+                    expMovAvg: {
+                        $expMovingAvg: "$price",
+                    },
+                }
+            }
+        },
+    ]
+}),
+                             ErrorCodes.FailedToParse);
+})();
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index 14cdc65b850..c9c8ffae144 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -119,6 +119,7 @@ env.Library(
         'accumulator_add_to_set.cpp',
         'accumulator_avg.cpp',
         'accumulator_covariance.cpp',
+        'accumulator_exp_moving_avg.cpp',
         'accumulator_first.cpp',
         'accumulator_js_reduce.cpp',
'accumulator_last.cpp', diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h index 5107445ed25..71f2dcb7f5a 100644 --- a/src/mongo/db/pipeline/accumulator.h +++ b/src/mongo/db/pipeline/accumulator.h @@ -473,4 +473,23 @@ private: MutableDocument _output; }; +class AccumulatorExpMovingAvg : public AccumulatorState { +public: + AccumulatorExpMovingAvg(ExpressionContext* const expCtx, Decimal128 alpha); + + void processInternal(const Value& input, bool merging) final; + Value getValue(bool toBeMerged) final; + const char* getOpName() const final; + void reset() final; + + static boost::intrusive_ptr<AccumulatorState> create(ExpressionContext* const expCtx, + Decimal128 alpha); + +private: + Decimal128 _alpha; + Decimal128 _currentResult; + bool _init = false; + bool _isDecimal = false; +}; + } // namespace mongo diff --git a/src/mongo/db/pipeline/accumulator_exp_moving_avg.cpp b/src/mongo/db/pipeline/accumulator_exp_moving_avg.cpp new file mode 100644 index 00000000000..f4393828f82 --- /dev/null +++ b/src/mongo/db/pipeline/accumulator_exp_moving_avg.cpp @@ -0,0 +1,95 @@ +/** + * Copyright (C) 2018-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. 
+ * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include <cmath> +#include <limits> + +#include "mongo/db/pipeline/accumulator.h" + +#include "mongo/db/exec/document_value/value.h" +#include "mongo/db/pipeline/accumulation_statement.h" +#include "mongo/db/pipeline/expression.h" +#include "mongo/db/pipeline/window_function/window_function_expression.h" + +namespace mongo { + +using boost::intrusive_ptr; + +REGISTER_NON_REMOVABLE_WINDOW_FUNCTION(expMovingAvg, + mongo::window_function::ExpressionExpMovingAvg::parse); +const char* AccumulatorExpMovingAvg::getOpName() const { + return "$expMovingAvg"; +} + +void AccumulatorExpMovingAvg::processInternal(const Value& input, bool merging) { + tassert(5433600, "$expMovingAvg can't be merged", !merging); + if (!input.numeric()) { + return; + } + if (input.getType() == BSONType::NumberDecimal) { + _isDecimal = true; + } + auto decimalVal = input.coerceToDecimal(); + if (!_init) { + _currentResult = decimalVal; + _init = true; + } else { + _currentResult = decimalVal.multiply(_alpha).add( + _currentResult.multiply(Decimal128(1).subtract(_alpha))); + } +} + +intrusive_ptr<AccumulatorState> AccumulatorExpMovingAvg::create(ExpressionContext* const 
expCtx, + Decimal128 alpha) { + return new AccumulatorExpMovingAvg(expCtx, alpha); +} + +Value AccumulatorExpMovingAvg::getValue(bool toBeMerged) { + tassert(5433601, "$expMovingAvg can't be merged", !toBeMerged); + if (!_init) { + return Value(BSONNULL); + } + if (!_isDecimal) { + return Value(_currentResult.toDouble()); + } + return Value(_currentResult); +} + +AccumulatorExpMovingAvg::AccumulatorExpMovingAvg(ExpressionContext* const expCtx, Decimal128 alpha) + : AccumulatorState(expCtx), _alpha(alpha) { + _memUsageBytes = sizeof(*this); +} + +void AccumulatorExpMovingAvg::reset() { + _memUsageBytes = sizeof(*this); + _init = false; +} +} // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.cpp b/src/mongo/db/pipeline/window_function/window_function_expression.cpp index 27522e9b790..9f495295c2e 100644 --- a/src/mongo/db/pipeline/window_function/window_function_expression.cpp +++ b/src/mongo/db/pipeline/window_function/window_function_expression.cpp @@ -66,6 +66,62 @@ void Expression::registerParser(std::string functionName, Parser parser) { parserMap.emplace(std::move(functionName), std::move(parser)); } + +boost::intrusive_ptr<Expression> ExpressionExpMovingAvg::parse( + BSONObj obj, const boost::optional<SortPattern>& sortBy, ExpressionContext* expCtx) { + // 'obj' is something like '{$expMovingAvg: {input: <arg>, <N/alpha>: <int/float>}}' + boost::optional<StringData> accumulatorName; + boost::intrusive_ptr<::mongo::Expression> input; + uassert(ErrorCodes::FailedToParse, + "$expMovingAvg must have exactly one argument that is an object", + obj.nFields() == 1 && obj.hasField(kAccName) && + obj[kAccName].type() == BSONType::Object); + auto subObj = obj[kAccName].embeddedObject(); + uassert(ErrorCodes::FailedToParse, + str::stream() << "$expMovingAvg sub object must have exactly two fields: An '" + << kInputArg << "' field, and either an '" << kNArg << "' field or an '" + << kAlphaArg << "' field", + subObj.nFields() == 
2 && subObj.hasField(kInputArg)); + input = + ::mongo::Expression::parseOperand(expCtx, subObj[kInputArg], expCtx->variablesParseState); + // ExpMovingAvg is always unbounded to current. + WindowBounds bounds = WindowBounds{ + WindowBounds::DocumentBased{WindowBounds::Unbounded{}, WindowBounds::Current{}}}; + if (subObj.hasField(kNArg)) { + auto nVal = subObj[kNArg]; + uassert(ErrorCodes::FailedToParse, + str::stream() << "'" << kNArg << "' field must be an integer, but found type " + << nVal.type(), + nVal.isNumber()); + uassert(ErrorCodes::FailedToParse, + str::stream() << "'" << kNArg << "' field must be an integer, but found " << nVal + << ". To use a non-integer, use the '" << kAlphaArg + << "' argument instead", + nVal.safeNumberDouble() == floor(nVal.safeNumberDouble())); + auto nNum = nVal.safeNumberLong(); + uassert(ErrorCodes::FailedToParse, + str::stream() << "'" << kNArg << "' must be greater than zero. Got " << nNum, + nNum > 0); + return make_intrusive<ExpressionExpMovingAvg>( + expCtx, std::string(kAccName), std::move(input), std::move(bounds), nNum); + } else if (subObj.hasField(kAlphaArg)) { + uassert(ErrorCodes::FailedToParse, + str::stream() << "'" << kAlphaArg << "' must be a number", + subObj[kAlphaArg].isNumber()); + return make_intrusive<ExpressionExpMovingAvg>(expCtx, + std::string(kAccName), + std::move(input), + std::move(bounds), + subObj[kAlphaArg].numberDecimal()); + } else { + uasserted(ErrorCodes::FailedToParse, + str::stream() << "Got unrecognized field in $expMovingAvg" + << "$expMovingAvg sub object must have exactly two fields: An '" + << kInputArg << "' field, and either an '" << kNArg + << "' field or an '" << kAlphaArg << "' field"); + } +} + MONGO_INITIALIZER(windowFunctionExpressionMap)(InitializerContext*) { // Nothing to do. 
This initializer exists to tie together all the individual initializers
     // defined by REGISTER_NON_REMOVABLE_WINDOW_FUNCTION and REGISTER_REMOVABLE_WINDOW_FUNCTION
diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.h b/src/mongo/db/pipeline/window_function/window_function_expression.h
index ead4de18fc1..ae0634dc294 100644
--- a/src/mongo/db/pipeline/window_function/window_function_expression.h
+++ b/src/mongo/db/pipeline/window_function/window_function_expression.h
@@ -256,7 +256,6 @@ public:
         // Rank based accumulators are always unbounded to current.
         WindowBounds bounds = WindowBounds{
             WindowBounds::DocumentBased{WindowBounds::Unbounded{}, WindowBounds::Current{}}};
-        boost::intrusive_ptr<::mongo::Expression> input;
         auto arg = obj.firstElement();
         auto argName = arg.fieldNameStringData();
         if (parserMap.find(argName) != parserMap.end()) {
@@ -311,4 +310,76 @@ public:
     }
 };
 
+/**
+ * Window function expression for $expMovingAvg. An instance carries either 'N' or 'alpha',
+ * depending on which constructor built it.
+ */
+class ExpressionExpMovingAvg : public Expression {
+public:
+    static constexpr StringData kAccName = "$expMovingAvg"_sd;
+    static constexpr StringData kInputArg = "input"_sd;
+    static constexpr StringData kNArg = "N"_sd;
+    static constexpr StringData kAlphaArg = "alpha"_sd;
+
+    static boost::intrusive_ptr<Expression> parse(BSONObj obj,
+                                                  const boost::optional<SortPattern>& sortBy,
+                                                  ExpressionContext* expCtx);
+
+    // Builds an expression configured through 'N'.
+    ExpressionExpMovingAvg(ExpressionContext* expCtx,
+                           std::string accumulatorName,
+                           boost::intrusive_ptr<::mongo::Expression> input,
+                           WindowBounds bounds,
+                           long long nValue)
+        : Expression(expCtx, std::move(accumulatorName), std::move(input), std::move(bounds)),
+          _N(nValue) {}
+
+    // Builds an expression configured through an explicit 'alpha'.
+    ExpressionExpMovingAvg(ExpressionContext* expCtx,
+                           std::string accumulatorName,
+                           boost::intrusive_ptr<::mongo::Expression> input,
+                           WindowBounds bounds,
+                           Decimal128 alpha)
+        : Expression(expCtx, std::move(accumulatorName), std::move(input), std::move(bounds)),
+          _alpha(alpha) {}
+
+    // Creates the accumulator, deriving alpha as 2 / (N + 1) when 'N' was given.
+    boost::intrusive_ptr<AccumulatorState> buildAccumulatorOnly() const final {
+        if (_N) {
+            const auto alphaFromN = Decimal128(2).divide(Decimal128(*_N).add(Decimal128(1)));
+            return AccumulatorExpMovingAvg::create(_expCtx, alphaFromN);
+        }
+        if (_alpha) {
+            return AccumulatorExpMovingAvg::create(_expCtx, *_alpha);
+        }
+        tasserted(5433602, "ExpMovingAvg neither N nor alpha was set");
+    }
+
+    // Removable windows are not supported; calling this always trips a tassert.
+    std::unique_ptr<WindowFunctionState> buildRemovable() const final {
+        tasserted(5433603,
+                  str::stream() << "Window function " << _accumulatorName
+                                << " is not supported with a removable window");
+    }
+
+    // Serializes to '{$expMovingAvg: {<N or alpha>: <value>, input: <serialized input>}}'.
+    Value serialize(boost::optional<ExplainOptions::Verbosity> explain) const final {
+        tassert(5433604, "ExpMovingAvg neither N nor alpha was set", _N || _alpha);
+        MutableDocument args;
+        if (_N) {
+            args[kNArg] = Value(*_N);
+        } else {
+            args[kAlphaArg] = Value(*_alpha);
+        }
+        args[kInputArg] = _input->serialize(static_cast<bool>(explain));
+        MutableDocument spec;
+        spec[kAccName] = args.freezeToValue();
+        return spec.freezeToValue();
+    }
+
+protected:
+    boost::optional<long long> _N;
+    boost::optional<Decimal128> _alpha;
+};
+
 } // namespace mongo::window_function