diff options
author | David Percy <david.percy@mongodb.com> | 2020-01-17 16:20:06 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2020-02-27 20:44:41 +0000 |
commit | 606fbf8eac896b0b4ed26e921b7f6bf1f73f5511 (patch) | |
tree | 4855ab6890e429ff79ffdf867d2b973361b62b00 /src/mongo | |
parent | 5e57c0b0f7505035c37179d100fdd43ef2b6cc36 (diff) | |
download | mongo-606fbf8eac896b0b4ed26e921b7f6bf1f73f5511.tar.gz |
SERVER-45447 Add $accumulator for user-defined Javascript accumulators
Diffstat (limited to 'src/mongo')
23 files changed, 572 insertions, 162 deletions
diff --git a/src/mongo/db/commands/mr_common.cpp b/src/mongo/db/commands/mr_common.cpp index 6d30a985125..7343a63ac89 100644 --- a/src/mongo/db/commands/mr_common.cpp +++ b/src/mongo/db/commands/mr_common.cpp @@ -123,15 +123,19 @@ auto translateMap(boost::intrusive_ptr<ExpressionContext> expCtx, std::string co } auto translateReduce(boost::intrusive_ptr<ExpressionContext> expCtx, std::string code) { - auto accumulatorArgument = - ExpressionFieldPath::parse(expCtx, "$emits", expCtx->variablesParseState); - auto reduceFactory = [expCtx, funcSource = code]() { + auto initializer = ExpressionArray::create(expCtx, {}); + auto argument = ExpressionFieldPath::parse(expCtx, "$emits", expCtx->variablesParseState); + auto reduceFactory = [expCtx, funcSource = std::move(code)]() { return AccumulatorInternalJsReduce::create(expCtx, funcSource); }; - AccumulationStatement jsReduce("value", std::move(accumulatorArgument), reduceFactory); - auto groupExpr = ExpressionFieldPath::parse(expCtx, "$emits.k", expCtx->variablesParseState); + AccumulationStatement jsReduce("value", + AccumulationExpression(std::move(initializer), + std::move(argument), + std::move(reduceFactory))); + auto groupKeyExpression = + ExpressionFieldPath::parse(expCtx, "$emits.k", expCtx->variablesParseState); return DocumentSourceGroup::create(expCtx, - std::move(groupExpr), + std::move(groupKeyExpression), make_vector<AccumulationStatement>(std::move(jsReduce)), boost::none); } diff --git a/src/mongo/db/pipeline/accumulation_statement.cpp b/src/mongo/db/pipeline/accumulation_statement.cpp index 18595477227..72b2b80a6c6 100644 --- a/src/mongo/db/pipeline/accumulation_statement.cpp +++ b/src/mongo/db/pipeline/accumulation_statement.cpp @@ -65,8 +65,8 @@ AccumulationStatement::Parser& AccumulationStatement::getParser(StringData name) return it->second; } -boost::intrusive_ptr<Accumulator> AccumulationStatement::makeAccumulator() const { - return _factory(); +boost::intrusive_ptr<AccumulatorState> 
AccumulationStatement::makeAccumulator() const { + return expr.factory(); } AccumulationStatement AccumulationStatement::parseAccumulationStatement( @@ -98,9 +98,10 @@ AccumulationStatement AccumulationStatement::parseAccumulationStatement( specElem.type() != BSONType::Array); auto&& parser = AccumulationStatement::getParser(accName); - auto [expression, factory] = parser(expCtx, specElem, vps); + auto [initializer, argument, factory] = parser(expCtx, specElem, vps); - return AccumulationStatement(fieldName.toString(), expression, factory); + return AccumulationStatement(fieldName.toString(), + AccumulationExpression(initializer, argument, factory)); } } // namespace mongo diff --git a/src/mongo/db/pipeline/accumulation_statement.h b/src/mongo/db/pipeline/accumulation_statement.h index 91ac3a1aff3..fff666141d3 100644 --- a/src/mongo/db/pipeline/accumulation_statement.h +++ b/src/mongo/db/pipeline/accumulation_statement.h @@ -38,8 +38,8 @@ namespace mongo { /** - * Registers an Accumulator to have the name 'key'. When an accumulator with name '$key' is found - * during parsing, 'factory' will be called to construct the Accumulator. + * Registers an AccumulatorState to have the name 'key'. When an accumulator with name '$key' is + * found during parsing, 'factory' will be called to construct the AccumulatorState. * * As an example, if your accumulator looks like {"$foo": <args>}, with a factory method 'create', * you would add this line: @@ -52,26 +52,102 @@ namespace mongo { } /** + * AccumulationExpression represents the right-hand side of an AccumulationStatement. Note this is + * different from Expression; they are different nonterminals in the grammar. 
+ * + * For example, in + * {$group: { + * _id: 1, + * count: {$sum: {$size: "$tags"}} + * }} + * + * we would say: + * The AccumulationStatement is count: {$sum: {$size: "$tags"}} + * The AccumulationExpression is {$sum: {$size: "$tags"}} + * The AccumulatorState::Factory is $sum + * The argument Expression is {$size: "$tags"} + * There is no initializer Expression. + * + * "$sum" corresponds to an AccumulatorState::Factory rather than AccumulatorState because + * AccumulatorState is an execution concept, not an AST concept: each instance of AccumulatorState + * contains intermediate values being accumulated. + * + * Like most accumulators, $sum does not require or accept an initializer Expression. At time of + * writing, only user-defined accumulators accept an initializer. + * + * For example, in: + * {$group: { + * _id: {cc: "$country_code"}, + * top_stories: {$accumulator: { + * init: function(cc) { ... }, + * initArgs: ["$cc"], + * accumulate: function(state, title, upvotes) { ... }, + * accumulateArgs: ["$title", "$upvotes"], + * merge: function(state1, state2) { ... }, + * lang: "js", + * }} + * }} + * + * we would say: + * The AccumulationStatement is top_stories: {$accumulator: ... } + * The AccumulationExpression is {$accumulator: ... } + * The argument Expression is ["$title", "$upvotes"] + * The initializer Expression is ["$cc"] + * The AccumulatorState::Factory holds all the other arguments to $accumulator. + * + */ +struct AccumulationExpression { + AccumulationExpression(boost::intrusive_ptr<Expression> initializer, + boost::intrusive_ptr<Expression> argument, + AccumulatorState::Factory factory) + : initializer(initializer), argument(argument), factory(factory) { + invariant(this->initializer); + invariant(this->argument); + } + + // An expression evaluated once per group, and passed to AccumulatorState::startNewGroup. + boost::intrusive_ptr<Expression> initializer; + + // An expression evaluated once per input document, and passed to AccumulatorState::process. 
+ boost::intrusive_ptr<Expression> argument; + + // Constructs an AccumulatorState to do actual accumulation. + boost::intrusive_ptr<AccumulatorState> makeAccumulator() const; + + // A no argument function object that can be called to create an AccumulatorState. + const AccumulatorState::Factory factory; +}; + +/** + * A default parser for any accumulator that only takes a single expression as an argument. Returns + * the expression to be evaluated by the accumulator and an AccumulatorState::Factory. + */ +template <class AccName> +AccumulationExpression genericParseSingleExpressionAccumulator( + boost::intrusive_ptr<ExpressionContext> expCtx, BSONElement elem, VariablesParseState vps) { + auto initializer = ExpressionConstant::create(expCtx, Value(BSONNULL)); + auto argument = Expression::parseOperand(expCtx, elem, vps); + return {initializer, argument, [expCtx]() { return AccName::create(expCtx); }}; +} + +/** * A class representing a user-specified accumulation, including the field name to put the * accumulated result in, which accumulator to use, and the expression used to obtain the input to - * the Accumulator. + * the AccumulatorState. */ class AccumulationStatement { public: - using Parser = std::function<std::pair<boost::intrusive_ptr<Expression>, Accumulator::Factory>( + using Parser = std::function<AccumulationExpression( boost::intrusive_ptr<ExpressionContext>, BSONElement, VariablesParseState)>; - AccumulationStatement(std::string fieldName, - boost::intrusive_ptr<Expression> expression, - Accumulator::Factory factory) - : fieldName(std::move(fieldName)), - expression(std::move(expression)), - _factory(std::move(factory)) {} + + AccumulationStatement(std::string fieldName, AccumulationExpression expr) + : fieldName(std::move(fieldName)), expr(std::move(expr)) {} /** * Parses a BSONElement that is an accumulated field, and returns an AccumulationStatement for * that accumulated field. * - * Throws a AssertionException if parsing fails. 
+ * Throws an AssertionException if parsing fails. */ static AccumulationStatement parseAccumulationStatement( const boost::intrusive_ptr<ExpressionContext>& expCtx, @@ -79,9 +155,9 @@ public: const VariablesParseState& vps); /** - * Registers an Accumulator with a parsing function, so that when an accumulator with the given - * name is encountered during parsing, we will know to call 'factory' to construct that - * Accumulator. + * Registers an AccumulatorState with a parsing function, so that when an accumulator with the + * given name is encountered during parsing, we will know to call 'factory' to construct that + * AccumulatorState. * * DO NOT call this method directly. Instead, use the REGISTER_ACCUMULATOR macro defined in this * file. @@ -90,22 +166,17 @@ public: /** * Retrieves the Parser for the accumulator specified by the given name, and raises an error if - * there is no such Accumulator registered. + * there is no such AccumulatorState registered. */ static Parser& getParser(StringData name); // The field name is used to store the results of the accumulation in a result document. std::string fieldName; - // The expression to use to obtain the input to the accumulator. - boost::intrusive_ptr<Expression> expression; - - // Constructs an Accumulator to do actual accumulation. - boost::intrusive_ptr<Accumulator> makeAccumulator() const; + AccumulationExpression expr; -private: - // A no argument function object that can be called to create an Accumulator. - const Accumulator::Factory _factory; + // Constructs an AccumulatorState to do actual accumulation. 
+ boost::intrusive_ptr<AccumulatorState> makeAccumulator() const; }; diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h index cb58bc7a3b1..e59ccb543da 100644 --- a/src/mongo/db/pipeline/accumulator.h +++ b/src/mongo/db/pipeline/accumulator.h @@ -51,23 +51,28 @@ namespace mongo { * This enum indicates which documents an accumulator needs to see in order to compute its output. */ enum class AccumulatorDocumentsNeeded { - // Accumulator needs to see all documents in a group. + // AccumulatorState needs to see all documents in a group. kAllDocuments, - // Accumulator only needs to see one document in a group, and when there is a sort order, that - // document must be the first document. + // AccumulatorState only needs to see one document in a group, and when there is a sort order, + // that document must be the first document. kFirstDocument, - // Accumulator only needs to see one document in a group, and when there is a sort order, that - // document must be the last document. + // AccumulatorState only needs to see one document in a group, and when there is a sort order, + // that document must be the last document. kLastDocument, }; -class Accumulator : public RefCountable { +class AccumulatorState : public RefCountable { public: - using Factory = std::function<boost::intrusive_ptr<Accumulator>()>; + using Factory = std::function<boost::intrusive_ptr<AccumulatorState>()>; - Accumulator(const boost::intrusive_ptr<ExpressionContext>& expCtx) : _expCtx(expCtx) {} + AccumulatorState(const boost::intrusive_ptr<ExpressionContext>& expCtx) : _expCtx(expCtx) {} + + /** Marks the beginning of a new group. The input is the result of evaluating + * AccumulationExpression::initializer, which can read from the group key. + */ + virtual void startNewGroup(const Value& input) {} /** Process input and update internal state. * merging should be true when processing outputs from getValue(true). 
@@ -89,7 +94,7 @@ public: return _memUsageBytes; } - /// Reset this accumulator to a fresh state ready to receive input. + /// Reset this accumulator to a fresh state, ready for a new call to startNewGroup. virtual void reset() = 0; virtual bool isAssociative() const { @@ -109,9 +114,19 @@ public: * * When executing on a sharded cluster, the result of this function will be sent to each * individual shard. + * + * This implementation assumes the accumulator has the simple syntax { <name>: <argument> }, + * such as { $sum: <argument> }. This syntax has no room for an initializer. Subclasses with a + * more elaborate syntax should override this method. */ - virtual Document serialize(boost::intrusive_ptr<Expression> expression, bool explain) const { - return DOC(getOpName() << expression->serialize(explain)); + virtual Document serialize(boost::intrusive_ptr<Expression> initializer, + boost::intrusive_ptr<Expression> argument, + bool explain) const { + ExpressionConstant const* ec = dynamic_cast<ExpressionConstant const*>(initializer.get()); + invariant(ec); + invariant(ec->getValue().nullish()); + + return DOC(getOpName() << argument->serialize(explain)); } virtual AccumulatorDocumentsNeeded documentsNeeded() const { @@ -133,20 +148,7 @@ private: boost::intrusive_ptr<ExpressionContext> _expCtx; }; -/** - * A default parser for any accumulator that only takes a single expression as an argument. Returns - * the expression to be evaluated by the accumulator and an Accumulator::Factory. 
- */ -template <class AccName> -std::pair<boost::intrusive_ptr<Expression>, Accumulator::Factory> -genericParseSingleExpressionAccumulator(boost::intrusive_ptr<ExpressionContext> expCtx, - BSONElement elem, - VariablesParseState vps) { - auto exprValue = Expression::parseOperand(expCtx, elem, vps); - return {exprValue, [expCtx]() { return AccName::create(expCtx); }}; -} - -class AccumulatorAddToSet final : public Accumulator { +class AccumulatorAddToSet final : public AccumulatorState { public: /** * Creates a new $addToSet accumulator. If no memory limit is given, defaults to the value of @@ -160,7 +162,7 @@ public: const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); bool isAssociative() const final { @@ -176,7 +178,7 @@ private: int _maxMemUsageBytes; }; -class AccumulatorFirst final : public Accumulator { +class AccumulatorFirst final : public AccumulatorState { public: explicit AccumulatorFirst(const boost::intrusive_ptr<ExpressionContext>& expCtx); @@ -185,7 +187,7 @@ public: const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); AccumulatorDocumentsNeeded documentsNeeded() const final { @@ -197,7 +199,7 @@ private: Value _first; }; -class AccumulatorLast final : public Accumulator { +class AccumulatorLast final : public AccumulatorState { public: explicit AccumulatorLast(const boost::intrusive_ptr<ExpressionContext>& expCtx); @@ -206,7 +208,7 @@ public: const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); AccumulatorDocumentsNeeded documentsNeeded() const final { @@ 
-217,7 +219,7 @@ private: Value _last; }; -class AccumulatorSum final : public Accumulator { +class AccumulatorSum final : public AccumulatorState { public: explicit AccumulatorSum(const boost::intrusive_ptr<ExpressionContext>& expCtx); @@ -226,7 +228,7 @@ public: const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); bool isAssociative() const final { @@ -243,7 +245,7 @@ private: Decimal128 decimalTotal; }; -class AccumulatorMinMax : public Accumulator { +class AccumulatorMinMax : public AccumulatorState { public: enum Sense : int { MIN = 1, @@ -274,7 +276,7 @@ class AccumulatorMax final : public AccumulatorMinMax { public: explicit AccumulatorMax(const boost::intrusive_ptr<ExpressionContext>& expCtx) : AccumulatorMinMax(expCtx, MAX) {} - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); }; @@ -282,11 +284,11 @@ class AccumulatorMin final : public AccumulatorMinMax { public: explicit AccumulatorMin(const boost::intrusive_ptr<ExpressionContext>& expCtx) : AccumulatorMinMax(expCtx, MIN) {} - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); }; -class AccumulatorPush final : public Accumulator { +class AccumulatorPush final : public AccumulatorState { public: /** * Creates a new $push accumulator. 
If no memory limit is given, defaults to the value of the @@ -300,7 +302,7 @@ public: const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); private: @@ -308,7 +310,7 @@ private: int _maxMemUsageBytes; }; -class AccumulatorAvg final : public Accumulator { +class AccumulatorAvg final : public AccumulatorState { public: explicit AccumulatorAvg(const boost::intrusive_ptr<ExpressionContext>& expCtx); @@ -317,7 +319,7 @@ public: const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); private: @@ -333,7 +335,7 @@ private: long long _count; }; -class AccumulatorStdDev : public Accumulator { +class AccumulatorStdDev : public AccumulatorState { public: AccumulatorStdDev(const boost::intrusive_ptr<ExpressionContext>& expCtx, bool isSamp); @@ -353,7 +355,7 @@ class AccumulatorStdDevPop final : public AccumulatorStdDev { public: explicit AccumulatorStdDevPop(const boost::intrusive_ptr<ExpressionContext>& expCtx) : AccumulatorStdDev(expCtx, false) {} - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); }; @@ -361,11 +363,11 @@ class AccumulatorStdDevSamp final : public AccumulatorStdDev { public: explicit AccumulatorStdDevSamp(const boost::intrusive_ptr<ExpressionContext>& expCtx) : AccumulatorStdDev(expCtx, true) {} - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); }; -class AccumulatorMergeObjects : public Accumulator { +class AccumulatorMergeObjects : public AccumulatorState { public: AccumulatorMergeObjects(const 
boost::intrusive_ptr<ExpressionContext>& expCtx); @@ -374,7 +376,7 @@ public: const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx); private: diff --git a/src/mongo/db/pipeline/accumulator_add_to_set.cpp b/src/mongo/db/pipeline/accumulator_add_to_set.cpp index 3a6a7b22944..7af05eda870 100644 --- a/src/mongo/db/pipeline/accumulator_add_to_set.cpp +++ b/src/mongo/db/pipeline/accumulator_add_to_set.cpp @@ -81,7 +81,7 @@ Value AccumulatorAddToSet::getValue(bool toBeMerged) { AccumulatorAddToSet::AccumulatorAddToSet(const boost::intrusive_ptr<ExpressionContext>& expCtx, boost::optional<int> maxMemoryUsageBytes) - : Accumulator(expCtx), + : AccumulatorState(expCtx), _set(expCtx->getValueComparator().makeUnorderedValueSet()), _maxMemUsageBytes(maxMemoryUsageBytes.value_or(internalQueryMaxAddToSetBytes.load())) { _memUsageBytes = sizeof(*this); @@ -92,7 +92,7 @@ void AccumulatorAddToSet::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr<Accumulator> AccumulatorAddToSet::create( +intrusive_ptr<AccumulatorState> AccumulatorAddToSet::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorAddToSet(expCtx, boost::none); } diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp index 2efc2cee191..1e7f55f5ab3 100644 --- a/src/mongo/db/pipeline/accumulator_avg.cpp +++ b/src/mongo/db/pipeline/accumulator_avg.cpp @@ -93,7 +93,7 @@ void AccumulatorAvg::processInternal(const Value& input, bool merging) { _count++; } -intrusive_ptr<Accumulator> AccumulatorAvg::create( +intrusive_ptr<AccumulatorState> AccumulatorAvg::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorAvg(expCtx); } @@ -123,8 +123,8 @@ Value AccumulatorAvg::getValue(bool toBeMerged) { } AccumulatorAvg::AccumulatorAvg(const 
boost::intrusive_ptr<ExpressionContext>& expCtx) - : Accumulator(expCtx), _isDecimal(false), _count(0) { - // This is a fixed size Accumulator so we never need to update this + : AccumulatorState(expCtx), _isDecimal(false), _count(0) { + // This is a fixed size AccumulatorState so we never need to update this _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_first.cpp b/src/mongo/db/pipeline/accumulator_first.cpp index aed285cf5db..b48b9e2330e 100644 --- a/src/mongo/db/pipeline/accumulator_first.cpp +++ b/src/mongo/db/pipeline/accumulator_first.cpp @@ -59,7 +59,7 @@ Value AccumulatorFirst::getValue(bool toBeMerged) { } AccumulatorFirst::AccumulatorFirst(const boost::intrusive_ptr<ExpressionContext>& expCtx) - : Accumulator(expCtx), _haveFirst(false) { + : AccumulatorState(expCtx), _haveFirst(false) { _memUsageBytes = sizeof(*this); } @@ -70,7 +70,7 @@ void AccumulatorFirst::reset() { } -intrusive_ptr<Accumulator> AccumulatorFirst::create( +intrusive_ptr<AccumulatorState> AccumulatorFirst::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorFirst(expCtx); } diff --git a/src/mongo/db/pipeline/accumulator_js_reduce.cpp b/src/mongo/db/pipeline/accumulator_js_reduce.cpp index 877ac1ca41e..41183c7ac40 100644 --- a/src/mongo/db/pipeline/accumulator_js_reduce.cpp +++ b/src/mongo/db/pipeline/accumulator_js_reduce.cpp @@ -37,10 +37,8 @@ namespace mongo { REGISTER_ACCUMULATOR(_internalJsReduce, AccumulatorInternalJsReduce::parseInternalJsReduce); -std::pair<boost::intrusive_ptr<Expression>, Accumulator::Factory> -AccumulatorInternalJsReduce::parseInternalJsReduce(boost::intrusive_ptr<ExpressionContext> expCtx, - BSONElement elem, - VariablesParseState vps) { +AccumulationExpression AccumulatorInternalJsReduce::parseInternalJsReduce( + boost::intrusive_ptr<ExpressionContext> expCtx, BSONElement elem, VariablesParseState vps) { uassert(31326, str::stream() << kAccumulatorName << " requires a document argument, 
but found " << elem.type(), @@ -48,13 +46,13 @@ AccumulatorInternalJsReduce::parseInternalJsReduce(boost::intrusive_ptr<Expressi BSONObj obj = elem.embeddedObject(); std::string funcSource; - boost::intrusive_ptr<Expression> dataExpr; + boost::intrusive_ptr<Expression> argument; for (auto&& element : obj) { if (element.fieldNameStringData() == "eval") { funcSource = parseReduceFunction(element); } else if (element.fieldNameStringData() == "data") { - dataExpr = Expression::parseOperand(expCtx, element, vps); + argument = Expression::parseOperand(expCtx, element, vps); } else { uasserted(31243, str::stream() << "Invalid argument specified to " << kAccumulatorName << ": " @@ -68,13 +66,14 @@ AccumulatorInternalJsReduce::parseInternalJsReduce(boost::intrusive_ptr<Expressi uassert(31349, str::stream() << kAccumulatorName << " requires 'data' argument, recieved input: " << obj.toString(false), - dataExpr); + argument); auto factory = [expCtx, funcSource = funcSource]() { return AccumulatorInternalJsReduce::create(expCtx, funcSource); }; - return {std::move(dataExpr), std::move(factory)}; + auto initializer = ExpressionConstant::create(expCtx, Value(BSONNULL)); + return {std::move(initializer), std::move(argument), std::move(factory)}; } std::string AccumulatorInternalJsReduce::parseReduceFunction(BSONElement func) { @@ -168,7 +167,7 @@ Value AccumulatorInternalJsReduce::getValue(bool toBeMerged) { } } -boost::intrusive_ptr<Accumulator> AccumulatorInternalJsReduce::create( +boost::intrusive_ptr<AccumulatorState> AccumulatorInternalJsReduce::create( const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData funcSource) { return make_intrusive<AccumulatorInternalJsReduce>(expCtx, funcSource); @@ -181,9 +180,233 @@ void AccumulatorInternalJsReduce::reset() { } // Returns this accumulator serialized as a Value along with the reduce function. 
-Document AccumulatorInternalJsReduce::serialize(boost::intrusive_ptr<Expression> expression, +Document AccumulatorInternalJsReduce::serialize(boost::intrusive_ptr<Expression> initializer, + boost::intrusive_ptr<Expression> argument, bool explain) const { - return DOC( - getOpName() << DOC("data" << expression->serialize(explain) << "eval" << _funcSource)); + return DOC(getOpName() << DOC("data" << argument->serialize(explain) << "eval" << _funcSource)); } + +REGISTER_ACCUMULATOR(accumulator, AccumulatorJs::parse); + +boost::intrusive_ptr<AccumulatorState> AccumulatorJs::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + std::string init, + std::string accumulate, + std::string merge, + std::string finalize) { + return new AccumulatorJs( + expCtx, std::move(init), std::move(accumulate), std::move(merge), std::move(finalize)); +} + +namespace { +// Parses a constant expression of type String or Code. +std::string parseFunction(StringData fieldName, + boost::intrusive_ptr<ExpressionContext> expCtx, + BSONElement elem, + VariablesParseState vps) { + boost::intrusive_ptr<Expression> expr = Expression::parseOperand(expCtx, elem, vps); + expr = expr->optimize(); + ExpressionConstant* ec = dynamic_cast<ExpressionConstant*>(expr.get()); + uassert(4544701, + str::stream() << "$accumulator '" << fieldName << "' must be a constant expression", + ec); + Value v = ec->getValue(); + uassert(4544702, + str::stream() << "$accumulator '" << fieldName << "' must be a String or Code", + v.getType() == BSONType::String || v.getType() == BSONType::Code); + return v.coerceToString(); +} +} // namespace + + +Document AccumulatorJs::serialize(boost::intrusive_ptr<Expression> initializer, + boost::intrusive_ptr<Expression> argument, + bool explain) const { + MutableDocument args; + args.addField("init", Value(_init)); + args.addField("initArgs", Value(initializer->serialize(explain))); + args.addField("accumulate", Value(_accumulate)); + args.addField("accumulateArgs", 
Value(argument->serialize(explain))); + args.addField("merge", Value(_merge)); + args.addField("finalize", Value(_finalize)); + args.addField("lang", Value("js"_sd)); + return DOC(getOpName() << args.freeze()); +} + +AccumulationExpression AccumulatorJs::parse(boost::intrusive_ptr<ExpressionContext> expCtx, + BSONElement elem, + VariablesParseState vps) { + /* + * {$accumulator: { + * init: <code>, + * accumulate: <code>, + * merge: <code>, + * finalize: <code>, + * + * accumulateArgs: <expr>, // evaluated once per document + * + * initArgs: <expr>, // evaluated once per group + * + * lang: 'js', + * }} + */ + uassert(4544703, + str::stream() << "$accumulator expects an object as an argument; found: " + << typeName(elem.type()), + elem.type() == BSONType::Object); + BSONObj obj = elem.embeddedObject(); + + std::string init, accumulate, merge, finalize; + boost::intrusive_ptr<Expression> initArgs, accumulateArgs; + + for (auto&& element : obj) { + auto name = element.fieldNameStringData(); + if (name == "init") { + init = parseFunction("init", expCtx, element, vps); + } else if (name == "accumulate") { + accumulate = parseFunction("accumulate", expCtx, element, vps); + } else if (name == "merge") { + merge = parseFunction("merge", expCtx, element, vps); + } else if (name == "finalize") { + finalize = parseFunction("finalize", expCtx, element, vps); + } else if (name == "initArgs") { + initArgs = Expression::parseOperand(expCtx, element, vps); + } else if (name == "accumulateArgs") { + accumulateArgs = Expression::parseOperand(expCtx, element, vps); + } else if (name == "lang") { + uassert(4544704, + str::stream() << "$accumulator lang must be a string; found: " + << element.type(), + element.type() == BSONType::String); + uassert(4544705, + "$accumulator only supports lang: 'js'", + element.valueStringData() == "js"); + } else { + // unexpected field + uassert( + 4544706, str::stream() << "$accumulator got an unexpected field: " << name, false); + } + } + 
uassert(4544707, "$accumulator missing required argument 'init'", !init.empty()); + uassert(4544708, "$accumulator missing required argument 'accumulate'", !accumulate.empty()); + uassert(4544709, "$accumulator missing required argument 'merge'", !merge.empty()); + if (finalize.empty()) { + // finalize is optional because many custom accumulators will return the final state + // unchanged. + finalize = "function(state) { return state; }"; + } + if (!initArgs) { + // initArgs is optional because most custom accumulators don't need the state to depend on + // the group key. + initArgs = ExpressionConstant::create(expCtx, Value(BSONArray())); + } + // accumulateArgs is required because it's the only way to communicate a value from the input + // stream into the accumulator state. + uassert(4544710, "$accumulator missing required argument 'accumulateArgs'", accumulateArgs); + + auto factory = [expCtx = expCtx, + init = std::move(init), + accumulate = std::move(accumulate), + merge = std::move(merge), + finalize = std::move(finalize)]() { + return new AccumulatorJs(expCtx, init, accumulate, merge, finalize); + }; + return {std::move(initArgs), std::move(accumulateArgs), std::move(factory)}; +} + +Value AccumulatorJs::getValue(bool toBeMerged) { + // _state is initialized when we encounter the first document in each group. We never create + // empty groups: even in a {$group: {_id: 1, ...}}, we will return zero groups rather than one + // empty group. + invariant(_state); + + // If toBeMerged then we return the current state, to be fed back in to accumulate / merge / + // finalize later. If not toBeMerged then we return the final value, by calling finalize. + if (toBeMerged) { + return *_state; + } + + // Get the final value given the current accumulator state. 
+ + auto& expCtx = getExpressionContext(); + auto jsExec = expCtx->getJsExecWithScope(); + auto func = makeJsFunc(expCtx, _finalize); + + return jsExec->callFunction(func, BSON_ARRAY(*_state), {}); +} + +void AccumulatorJs::startNewGroup(Value const& input) { + // Between groups the _state should be empty: we initialize it to be empty in the + // constructor, and we clear it at the end of each group (in .reset()). + invariant(!_state); + + auto& expCtx = getExpressionContext(); + auto jsExec = expCtx->getJsExecWithScope(); + auto func = makeJsFunc(expCtx, _init); + + // input is a value produced by our AccumulationExpression::initializer. + uassert(4544711, + str::stream() << "$accumulator initArgs must evaluate to an array: " + << input.toString(), + input.getType() == BSONType::Array); + + size_t index = 0; + BSONArrayBuilder bob; + for (auto&& arg : input.getArray()) { + arg.addToBsonArray(&bob, index++); + } + + _state = jsExec->callFunction(func, bob.arr(), {}); + + recomputeMemUsageBytes(); +} + +void AccumulatorJs::reset() { + _state = std::nullopt; + recomputeMemUsageBytes(); +} + +void AccumulatorJs::processInternal(const Value& input, bool merging) { + // _state should be nonempty because we populate it in startNewGroup. + invariant(_state); + + auto& expCtx = getExpressionContext(); + auto jsExec = expCtx->getJsExecWithScope(); + + if (merging) { + // input is an intermediate state from another instance of this kind of accumulator. Call + // the user's merge function. + auto func = makeJsFunc(expCtx, _merge); + _state = jsExec->callFunction(func, BSON_ARRAY(*_state << input), {}); + recomputeMemUsageBytes(); + } else { + // input is a value produced by our AccumulationExpression::argument. Call the user's + // accumulate function. 
+ auto func = makeJsFunc(expCtx, _accumulate); + uassert(4544712, + str::stream() << "$accumulator accumulateArgs must evaluate to an array: " + << input.toString(), + input.getType() == BSONType::Array); + + size_t index = 0; + BSONArrayBuilder bob; + _state->addToBsonArray(&bob, index++); + for (auto&& arg : input.getArray()) { + arg.addToBsonArray(&bob, index++); + } + + _state = jsExec->callFunction(func, bob.done(), {}); + recomputeMemUsageBytes(); + } +} + +void AccumulatorJs::recomputeMemUsageBytes() { + auto stateSize = _state.value_or(Value{}).getApproximateSize(); + uassert(4544713, + str::stream() << "$accumulator state exceeded max BSON size: " << stateSize, + stateSize <= BSONObjMaxUserSize); + _memUsageBytes = sizeof(*this) + stateSize + _init.capacity() + _accumulate.capacity() + + _merge.capacity() + _finalize.capacity(); +} + } // namespace mongo diff --git a/src/mongo/db/pipeline/accumulator_js_reduce.h b/src/mongo/db/pipeline/accumulator_js_reduce.h index bc2725a5567..1bb15f948e8 100644 --- a/src/mongo/db/pipeline/accumulator_js_reduce.h +++ b/src/mongo/db/pipeline/accumulator_js_reduce.h @@ -38,19 +38,19 @@ namespace mongo { -class AccumulatorInternalJsReduce final : public Accumulator { +class AccumulatorInternalJsReduce final : public AccumulatorState { public: static constexpr auto kAccumulatorName = "$_internalJsReduce"_sd; - static boost::intrusive_ptr<Accumulator> create( + static boost::intrusive_ptr<AccumulatorState> create( const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData funcSource); - static std::pair<boost::intrusive_ptr<Expression>, Accumulator::Factory> parseInternalJsReduce( + static AccumulationExpression parseInternalJsReduce( boost::intrusive_ptr<ExpressionContext> expCtx, BSONElement elem, VariablesParseState vps); AccumulatorInternalJsReduce(const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData funcSource) - : Accumulator(expCtx), _funcSource(funcSource) { + : AccumulatorState(expCtx), 
_funcSource(funcSource) { _memUsageBytes = sizeof(*this); } @@ -64,7 +64,8 @@ public: void reset() final; - virtual Document serialize(boost::intrusive_ptr<Expression> expression, + virtual Document serialize(boost::intrusive_ptr<Expression> initializer, + boost::intrusive_ptr<Expression> argument, bool explain) const override; private: @@ -75,4 +76,59 @@ private: Value _key; }; +class AccumulatorJs final : public AccumulatorState { +public: + static constexpr auto kAccumulatorName = "$accumulator"_sd; + const char* getOpName() const final { + return kAccumulatorName.rawData(); + } + + // An AccumulatorState instance only owns its "static" arguments: those that don't need to be + // evaluated per input document. + static boost::intrusive_ptr<AccumulatorState> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + std::string init, + std::string accumulate, + std::string merge, + std::string finalize); + + static AccumulationExpression parse(boost::intrusive_ptr<ExpressionContext> expCtx, + BSONElement elem, + VariablesParseState vps); + + Value getValue(bool toBeMerged) final; + void reset() final; + void processInternal(const Value& input, bool merging) final; + + Document serialize(boost::intrusive_ptr<Expression> initializer, + boost::intrusive_ptr<Expression> argument, + bool explain) const final; + void startNewGroup(Value const& input) final; + +private: + AccumulatorJs(const boost::intrusive_ptr<ExpressionContext>& expCtx, + std::string init, + std::string accumulate, + std::string merge, + std::string finalize) + : AccumulatorState(expCtx), + _init(init), + _accumulate(accumulate), + _merge(merge), + _finalize(finalize) { + recomputeMemUsageBytes(); + } + void recomputeMemUsageBytes(); + + // static arguments + std::string _init, _accumulate, _merge, _finalize; + + // accumulator state during execution + // - When the accumulator is first created, _state is empty. 
+ // - When the accumulator is fed its first input Value, it runs the user init and accumulate + // functions, and _state gets a Value. + // - When the accumulator is reset, _state becomes empty again. + std::optional<Value> _state; +}; + } // namespace mongo diff --git a/src/mongo/db/pipeline/accumulator_last.cpp b/src/mongo/db/pipeline/accumulator_last.cpp index 150360d4fdd..14362e42cab 100644 --- a/src/mongo/db/pipeline/accumulator_last.cpp +++ b/src/mongo/db/pipeline/accumulator_last.cpp @@ -55,7 +55,7 @@ Value AccumulatorLast::getValue(bool toBeMerged) { } AccumulatorLast::AccumulatorLast(const boost::intrusive_ptr<ExpressionContext>& expCtx) - : Accumulator(expCtx) { + : AccumulatorState(expCtx) { _memUsageBytes = sizeof(*this); } @@ -64,7 +64,7 @@ void AccumulatorLast::reset() { _last = Value(); } -intrusive_ptr<Accumulator> AccumulatorLast::create( +intrusive_ptr<AccumulatorState> AccumulatorLast::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorLast(expCtx); } diff --git a/src/mongo/db/pipeline/accumulator_merge_objects.cpp b/src/mongo/db/pipeline/accumulator_merge_objects.cpp index d1e6310ea23..6b23ae528a1 100644 --- a/src/mongo/db/pipeline/accumulator_merge_objects.cpp +++ b/src/mongo/db/pipeline/accumulator_merge_objects.cpp @@ -49,14 +49,14 @@ const char* AccumulatorMergeObjects::getOpName() const { return "$mergeObjects"; } -intrusive_ptr<Accumulator> AccumulatorMergeObjects::create( +intrusive_ptr<AccumulatorState> AccumulatorMergeObjects::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorMergeObjects(expCtx); } AccumulatorMergeObjects::AccumulatorMergeObjects( const boost::intrusive_ptr<ExpressionContext>& expCtx) - : Accumulator(expCtx) { + : AccumulatorState(expCtx) { _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_min_max.cpp b/src/mongo/db/pipeline/accumulator_min_max.cpp index c1af0ff8226..265a84b7075 100644 --- 
a/src/mongo/db/pipeline/accumulator_min_max.cpp +++ b/src/mongo/db/pipeline/accumulator_min_max.cpp @@ -71,7 +71,7 @@ Value AccumulatorMinMax::getValue(bool toBeMerged) { AccumulatorMinMax::AccumulatorMinMax(const boost::intrusive_ptr<ExpressionContext>& expCtx, Sense sense) - : Accumulator(expCtx), _sense(sense) { + : AccumulatorState(expCtx), _sense(sense) { _memUsageBytes = sizeof(*this); } @@ -80,12 +80,12 @@ void AccumulatorMinMax::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr<Accumulator> AccumulatorMin::create( +intrusive_ptr<AccumulatorState> AccumulatorMin::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorMin(expCtx); } -intrusive_ptr<Accumulator> AccumulatorMax::create( +intrusive_ptr<AccumulatorState> AccumulatorMax::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorMax(expCtx); } diff --git a/src/mongo/db/pipeline/accumulator_push.cpp b/src/mongo/db/pipeline/accumulator_push.cpp index d13a942ef62..3a004dd9731 100644 --- a/src/mongo/db/pipeline/accumulator_push.cpp +++ b/src/mongo/db/pipeline/accumulator_push.cpp @@ -83,7 +83,7 @@ Value AccumulatorPush::getValue(bool toBeMerged) { AccumulatorPush::AccumulatorPush(const boost::intrusive_ptr<ExpressionContext>& expCtx, boost::optional<int> maxMemoryUsageBytes) - : Accumulator(expCtx), + : AccumulatorState(expCtx), _maxMemUsageBytes(maxMemoryUsageBytes.value_or(internalQueryMaxPushBytes.load())) { _memUsageBytes = sizeof(*this); } @@ -93,7 +93,7 @@ void AccumulatorPush::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr<Accumulator> AccumulatorPush::create( +intrusive_ptr<AccumulatorState> AccumulatorPush::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorPush(expCtx, boost::none); } diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp index 55367d766be..cdd31b2c897 100644 --- a/src/mongo/db/pipeline/accumulator_std_dev.cpp 
+++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp @@ -96,20 +96,20 @@ Value AccumulatorStdDev::getValue(bool toBeMerged) { } } -intrusive_ptr<Accumulator> AccumulatorStdDevSamp::create( +intrusive_ptr<AccumulatorState> AccumulatorStdDevSamp::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorStdDevSamp(expCtx); } -intrusive_ptr<Accumulator> AccumulatorStdDevPop::create( +intrusive_ptr<AccumulatorState> AccumulatorStdDevPop::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorStdDevPop(expCtx); } AccumulatorStdDev::AccumulatorStdDev(const boost::intrusive_ptr<ExpressionContext>& expCtx, bool isSamp) - : Accumulator(expCtx), _isSamp(isSamp), _count(0), _mean(0), _m2(0) { - // This is a fixed size Accumulator so we never need to update this + : AccumulatorState(expCtx), _isSamp(isSamp), _count(0), _mean(0), _m2(0) { + // This is a fixed size AccumulatorState so we never need to update this _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp index 6cd34d8c76f..182f592ce3f 100644 --- a/src/mongo/db/pipeline/accumulator_sum.cpp +++ b/src/mongo/db/pipeline/accumulator_sum.cpp @@ -85,7 +85,7 @@ void AccumulatorSum::processInternal(const Value& input, bool merging) { } } -intrusive_ptr<Accumulator> AccumulatorSum::create( +intrusive_ptr<AccumulatorState> AccumulatorSum::create( const boost::intrusive_ptr<ExpressionContext>& expCtx) { return new AccumulatorSum(expCtx); } @@ -128,8 +128,8 @@ Value AccumulatorSum::getValue(bool toBeMerged) { } AccumulatorSum::AccumulatorSum(const boost::intrusive_ptr<ExpressionContext>& expCtx) - : Accumulator(expCtx) { - // This is a fixed size Accumulator so we never need to update this. + : AccumulatorState(expCtx) { + // This is a fixed size AccumulatorState so we never need to update this. 
_memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp index 373a24a4155..724e6a6838a 100644 --- a/src/mongo/db/pipeline/accumulator_test.cpp +++ b/src/mongo/db/pipeline/accumulator_test.cpp @@ -46,9 +46,9 @@ using std::numeric_limits; using std::string; /** - * Takes the name of an Accumulator as its template argument and a list of pairs of arguments and - * expected results as its second argument, and asserts that for the given Accumulator the arguments - * evaluate to the expected results. + * Takes the name of an AccumulatorState as its template argument and a list of pairs of arguments + * and expected results as its second argument, and asserts that for the given AccumulatorState the + * arguments evaluate to the expected results. */ template <typename AccName> static void assertExpectedResults( diff --git a/src/mongo/db/pipeline/document_source_bucket_auto.cpp b/src/mongo/db/pipeline/document_source_bucket_auto.cpp index eba38c5c58b..c6decbca69c 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto.cpp +++ b/src/mongo/db/pipeline/document_source_bucket_auto.cpp @@ -109,13 +109,25 @@ DocumentSource::GetNextResult DocumentSourceBucketAuto::doGetNext() { return makeDocument(*(_bucketsIterator++)); } +boost::intrusive_ptr<DocumentSource> DocumentSourceBucketAuto::optimize() { + _groupByExpression = _groupByExpression->optimize(); + for (auto&& accumulatedField : _accumulatedFields) { + accumulatedField.expr.argument = accumulatedField.expr.argument->optimize(); + accumulatedField.expr.initializer = accumulatedField.expr.initializer->optimize(); + } + return this; +} + DepsTracker::State DocumentSourceBucketAuto::getDependencies(DepsTracker* deps) const { // Add the 'groupBy' expression. _groupByExpression->addDependencies(deps); // Add the 'output' fields. 
for (auto&& accumulatedField : _accumulatedFields) { - accumulatedField.expression->addDependencies(deps); + // Anything the per-doc expression depends on, the whole stage depends on. + accumulatedField.expr.argument->addDependencies(deps); + // The initializer should be an ExpressionConstant, or something that optimizes to one. + // ExpressionConstant doesn't have dependencies. } // We know exactly which fields will be present in the output document. Future stages cannot @@ -189,7 +201,8 @@ void DocumentSourceBucketAuto::addDocumentToBucket(const pair<Value, Document>& const size_t numAccumulators = _accumulatedFields.size(); for (size_t k = 0; k < numAccumulators; k++) { bucket._accums[k]->process( - _accumulatedFields[k].expression->evaluate(entry.second, &pExpCtx->variables), false); + _accumulatedFields[k].expr.argument->evaluate(entry.second, &pExpCtx->variables), + false); } } @@ -234,6 +247,16 @@ void DocumentSourceBucketAuto::populateBuckets() { // Initialize the current bucket. Bucket currentBucket(pExpCtx, currentValue.first, currentValue.first, _accumulatedFields); + // Evaluate each initializer against an empty document. Normally the + // initializer can refer to the group key, but in $bucketAuto there is no single + // group key per bucket. + Document emptyDoc; + for (size_t k = 0; k < _accumulatedFields.size(); ++k) { + Value initializerValue = + _accumulatedFields[k].expr.initializer->evaluate(emptyDoc, &pExpCtx->variables); + currentBucket._accums[k]->startNewGroup(initializerValue); + } + // Add the first value into the current bucket. 
addDocumentToBucket(currentValue, currentBucket); @@ -382,10 +405,11 @@ Value DocumentSourceBucketAuto::serialize( MutableDocument outputSpec(_accumulatedFields.size()); for (auto&& accumulatedField : _accumulatedFields) { - intrusive_ptr<Accumulator> accum = accumulatedField.makeAccumulator(); + intrusive_ptr<AccumulatorState> accum = accumulatedField.makeAccumulator(); outputSpec[accumulatedField.fieldName] = - Value{Document{{accum->getOpName(), - accumulatedField.expression->serialize(static_cast<bool>(explain))}}}; + Value(accum->serialize(accumulatedField.expr.initializer, + accumulatedField.expr.argument, + static_cast<bool>(explain))); } insides["output"] = outputSpec.freezeToValue(); @@ -405,9 +429,11 @@ intrusive_ptr<DocumentSourceBucketAuto> DocumentSourceBucketAuto::create( numBuckets > 0); // If there is no output field specified, then add the default one. if (accumulationStatements.empty()) { - accumulationStatements.emplace_back("count", - ExpressionConstant::create(pExpCtx, Value(1)), - [pExpCtx] { return AccumulatorSum::create(pExpCtx); }); + accumulationStatements.emplace_back( + "count", + AccumulationExpression(ExpressionConstant::create(pExpCtx, Value(BSONNULL)), + ExpressionConstant::create(pExpCtx, Value(1)), + [pExpCtx] { return AccumulatorSum::create(pExpCtx); })); } return new DocumentSourceBucketAuto(pExpCtx, groupByExpression, @@ -486,8 +512,13 @@ intrusive_ptr<DocumentSource> DocumentSourceBucketAuto::createFromBson( argument.type() == BSONType::Object); for (auto&& outputField : argument.embeddedObject()) { - accumulationStatements.push_back( - AccumulationStatement::parseAccumulationStatement(pExpCtx, outputField, vps)); + auto stmt = + AccumulationStatement::parseAccumulationStatement(pExpCtx, outputField, vps); + stmt.expr.initializer = stmt.expr.initializer->optimize(); + uassert(4544714, + "Can't refer to the group key in $bucketAuto", + ExpressionConstant::isNullOrConstant(stmt.expr.initializer)); + 
accumulationStatements.push_back(std::move(stmt)); } } else if ("granularity" == argName) { uassert(40261, diff --git a/src/mongo/db/pipeline/document_source_bucket_auto.h b/src/mongo/db/pipeline/document_source_bucket_auto.h index 8804b0df6c1..fcc35a130d3 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto.h +++ b/src/mongo/db/pipeline/document_source_bucket_auto.h @@ -49,6 +49,7 @@ public: DepsTracker::State getDependencies(DepsTracker* deps) const final; const char* getSourceName() const final; + boost::intrusive_ptr<DocumentSource> optimize() final; StageConstraints constraints(Pipeline::SplitState pipeState) const final { return {StreamType::kBlocking, @@ -114,7 +115,7 @@ private: const std::vector<AccumulationStatement>& accumulationStatements); Value _min; Value _max; - std::vector<boost::intrusive_ptr<Accumulator>> _accums; + std::vector<boost::intrusive_ptr<AccumulatorState>> _accums; }; /** diff --git a/src/mongo/db/pipeline/document_source_group.cpp b/src/mongo/db/pipeline/document_source_group.cpp index 0f961f8c363..8268f0a1dd1 100644 --- a/src/mongo/db/pipeline/document_source_group.cpp +++ b/src/mongo/db/pipeline/document_source_group.cpp @@ -158,6 +158,17 @@ DocumentSource::GetNextResult DocumentSourceGroup::getNextSpilled() { _currentId = _firstPartOfNextGroup.first; const size_t numAccumulators = _accumulatedFields.size(); + + // Call startNewGroup on every accumulator. + Value expandedId = expandId(_currentId); + Document idDoc = + expandedId.getType() == BSONType::Object ? expandedId.getDocument() : Document(); + for (size_t i = 0; i < numAccumulators; ++i) { + Value initializerValue = + _accumulatedFields[i].expr.initializer->evaluate(idDoc, &pExpCtx->variables); + _currentAccumulators[i]->startNewGroup(initializerValue); + } + while (pExpCtx->getValueComparator().evaluate(_currentId == _firstPartOfNextGroup.first)) { // Inside of this loop, _firstPartOfNextGroup is the current data being processed. 
// At loop exit, it is the first value to be processed in the next group. @@ -216,7 +227,8 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::optimize() { } for (auto&& accumulatedField : _accumulatedFields) { - accumulatedField.expression = accumulatedField.expression->optimize(); + accumulatedField.expr.initializer = accumulatedField.expr.initializer->optimize(); + accumulatedField.expr.argument = accumulatedField.expr.argument->optimize(); } return this; @@ -241,9 +253,11 @@ Value DocumentSourceGroup::serialize(boost::optional<ExplainOptions::Verbosity> // Add the remaining fields. for (auto&& accumulatedField : _accumulatedFields) { - intrusive_ptr<Accumulator> accum = accumulatedField.makeAccumulator(); + intrusive_ptr<AccumulatorState> accum = accumulatedField.makeAccumulator(); insides[accumulatedField.fieldName] = - Value(accum->serialize(accumulatedField.expression, static_cast<bool>(explain))); + Value(accum->serialize(accumulatedField.expr.initializer, + accumulatedField.expr.argument, + static_cast<bool>(explain))); } if (_doingMerge) { @@ -263,7 +277,8 @@ DepsTracker::State DocumentSourceGroup::getDependencies(DepsTracker* deps) const // add the rest for (auto&& accumulatedField : _accumulatedFields) { - accumulatedField.expression->addDependencies(deps); + accumulatedField.expr.argument->addDependencies(deps); + // Don't add initializer, because it doesn't refer to docs from the input stream. } return DepsTracker::State::EXHAUSTIVE_ALL; @@ -485,16 +500,23 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { // accumulator. This is done in a somewhat odd way in order to avoid hashing 'id' and // looking it up in '_groups' multiple times. 
const size_t oldSize = _groups->size(); - vector<intrusive_ptr<Accumulator>>& group = (*_groups)[id]; + vector<intrusive_ptr<AccumulatorState>>& group = (*_groups)[id]; const bool inserted = _groups->size() != oldSize; if (inserted) { _memoryUsageBytes += id.getApproximateSize(); - // Add the accumulators + // Initialize and add the accumulators + Value expandedId = expandId(id); + Document idDoc = + expandedId.getType() == BSONType::Object ? expandedId.getDocument() : Document(); group.reserve(numAccumulators); for (auto&& accumulatedField : _accumulatedFields) { - group.push_back(accumulatedField.makeAccumulator()); + auto accum = accumulatedField.makeAccumulator(); + Value initializerValue = + accumulatedField.expr.initializer->evaluate(idDoc, &pExpCtx->variables); + accum->startNewGroup(initializerValue); + group.push_back(accum); } } else { for (auto&& groupObj : group) { @@ -508,7 +530,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { for (size_t i = 0; i < numAccumulators; i++) { group[i]->process( - _accumulatedFields[i].expression->evaluate(rootDocument, &pExpCtx->variables), + _accumulatedFields[i].expr.argument->evaluate(rootDocument, &pExpCtx->variables), _doingMerge); _memoryUsageBytes += group[i]->memUsageForSorter(); @@ -693,7 +715,7 @@ boost::optional<DocumentSource::DistributedPlanLogic> DocumentSourceGroup::distr // original accumulator may be collecting an expression based on a field expression or // constant. Here, we accumulate the output of the same name from the prior group. auto copiedAccumulatedField = accumulatedField; - copiedAccumulatedField.expression = + copiedAccumulatedField.expr.argument = ExpressionFieldPath::parse(pExpCtx, "$$ROOT." 
+ copiedAccumulatedField.fieldName, vps); mergingGroup->addAccumulator(copiedAccumulatedField); } @@ -775,7 +797,10 @@ DocumentSourceGroup::rewriteGroupAsTransformOnFirstDocument() const { fields.push_back(std::make_pair("_id", ExpressionFieldPath::create(pExpCtx, groupId))); for (auto&& accumulator : _accumulatedFields) { - fields.push_back(std::make_pair(accumulator.fieldName, accumulator.expression)); + fields.push_back(std::make_pair(accumulator.fieldName, accumulator.expr.argument)); + + // Since we don't attempt this transformation for non-$first accumulators, + // the initializer should always be trivial. } return GroupFromFirstDocumentTransformation::create(pExpCtx, groupId, std::move(fields)); diff --git a/src/mongo/db/pipeline/document_source_group.h b/src/mongo/db/pipeline/document_source_group.h index 88b0a669b6e..5d885e84a1f 100644 --- a/src/mongo/db/pipeline/document_source_group.h +++ b/src/mongo/db/pipeline/document_source_group.h @@ -88,7 +88,7 @@ private: class DocumentSourceGroup final : public DocumentSource { public: - using Accumulators = std::vector<boost::intrusive_ptr<Accumulator>>; + using Accumulators = std::vector<boost::intrusive_ptr<AccumulatorState>>; using GroupsMap = ValueUnorderedMap<Accumulators>; static constexpr StringData kStageName = "$group"_sd; diff --git a/src/mongo/db/pipeline/document_source_group_test.cpp b/src/mongo/db/pipeline/document_source_group_test.cpp index a2ed4358248..c3c764c4f6a 100644 --- a/src/mongo/db/pipeline/document_source_group_test.cpp +++ b/src/mongo/db/pipeline/document_source_group_test.cpp @@ -75,9 +75,8 @@ TEST_F(DocumentSourceGroupTest, ShouldBeAbleToPauseLoading) { // This is the only way to do this in a debug build. 
auto&& parser = AccumulationStatement::getParser("$sum"); auto accumulatorArg = BSON("" << 1); - auto [expression, factory] = - parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState); - AccumulationStatement countStatement{"count", expression, factory}; + auto accExpr = parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState); + AccumulationStatement countStatement{"count", accExpr}; auto group = DocumentSourceGroup::create( expCtx, ExpressionConstant::create(expCtx, Value(BSONNULL)), {countStatement}); auto mock = @@ -113,9 +112,8 @@ TEST_F(DocumentSourceGroupTest, ShouldBeAbleToPauseLoadingWhileSpilled) { auto&& parser = AccumulationStatement::getParser("$push"); auto accumulatorArg = BSON("" << "$largeStr"); - auto [expression, factory] = - parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState); - AccumulationStatement pushStatement{"spaceHog", expression, factory}; + auto accExpr = parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState); + AccumulationStatement pushStatement{"spaceHog", accExpr}; auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", expCtx->variablesParseState); auto group = DocumentSourceGroup::create( @@ -156,9 +154,8 @@ TEST_F(DocumentSourceGroupTest, ShouldErrorIfNotAllowedToSpillToDiskAndResultSet auto&& parser = AccumulationStatement::getParser("$push"); auto accumulatorArg = BSON("" << "$largeStr"); - auto [expression, factory] = - parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState); - AccumulationStatement pushStatement{"spaceHog", expression, factory}; + auto accExpr = parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState); + AccumulationStatement pushStatement{"spaceHog", accExpr}; auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", expCtx->variablesParseState); auto group = DocumentSourceGroup::create( @@ -181,9 +178,8 @@ TEST_F(DocumentSourceGroupTest, 
ShouldCorrectlyTrackMemoryUsageBetweenPauses) { auto&& parser = AccumulationStatement::getParser("$push"); auto accumulatorArg = BSON("" << "$largeStr"); - auto [expression, factory] = - parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState); - AccumulationStatement pushStatement{"spaceHog", expression, factory}; + auto accExpr = parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState); + AccumulationStatement pushStatement{"spaceHog", accExpr}; auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", expCtx->variablesParseState); auto group = DocumentSourceGroup::create( diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index b05b4dcd918..4f772ca2363 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -399,15 +399,15 @@ public: * Used to make Accumulators available as Expressions, e.g., to make $sum available as an Expression * use "REGISTER_EXPRESSION(sum, ExpressionAccumulator<AccumulatorSum>::parse);". */ -template <typename Accumulator> +template <typename AccumulatorState> class ExpressionFromAccumulator - : public ExpressionVariadic<ExpressionFromAccumulator<Accumulator>> { + : public ExpressionVariadic<ExpressionFromAccumulator<AccumulatorState>> { public: explicit ExpressionFromAccumulator(const boost::intrusive_ptr<ExpressionContext>& expCtx) - : ExpressionVariadic<ExpressionFromAccumulator<Accumulator>>(expCtx) {} + : ExpressionVariadic<ExpressionFromAccumulator<AccumulatorState>>(expCtx) {} Value evaluate(const Document& root, Variables* variables) const final { - Accumulator accum(this->getExpressionContext()); + AccumulatorState accum(this->getExpressionContext()); const auto n = this->_children.size(); // If a single array arg is given, loop through it passing each member to the accumulator. // If a single, non-array arg is given, pass it directly to the accumulator. 
@@ -435,15 +435,15 @@ public: if (this->_children.size() == 1) { return false; } - return Accumulator(this->getExpressionContext()).isAssociative(); + return AccumulatorState(this->getExpressionContext()).isAssociative(); } bool isCommutative() const final { - return Accumulator(this->getExpressionContext()).isCommutative(); + return AccumulatorState(this->getExpressionContext()).isCommutative(); } const char* getOpName() const final { - return Accumulator(this->getExpressionContext()).getOpName(); + return AccumulatorState(this->getExpressionContext()).getOpName(); } void acceptVisitor(ExpressionVisitor* visitor) final { diff --git a/src/mongo/db/pipeline/expression_visitor.h b/src/mongo/db/pipeline/expression_visitor.h index 326c8e2e6ce..393778e5f21 100644 --- a/src/mongo/db/pipeline/expression_visitor.h +++ b/src/mongo/db/pipeline/expression_visitor.h @@ -158,7 +158,7 @@ class AccumulatorStdDevPop; class AccumulatorStdDevSamp; class AccumulatorSum; class AccumulatorMergeObjects; -template <typename Accumulator> +template <typename AccumulatorState> class ExpressionFromAccumulator; /** |