summaryrefslogtreecommitdiff
path: root/src/mongo/db/pipeline
diff options
context:
space:
mode:
authorDavid Percy <david.percy@mongodb.com>2020-01-17 16:20:06 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2020-02-27 20:44:41 +0000
commit606fbf8eac896b0b4ed26e921b7f6bf1f73f5511 (patch)
tree4855ab6890e429ff79ffdf867d2b973361b62b00 /src/mongo/db/pipeline
parent5e57c0b0f7505035c37179d100fdd43ef2b6cc36 (diff)
downloadmongo-606fbf8eac896b0b4ed26e921b7f6bf1f73f5511.tar.gz
SERVER-45447 Add $accumulator for user-defined Javascript accumulators
Diffstat (limited to 'src/mongo/db/pipeline')
-rw-r--r--src/mongo/db/pipeline/accumulation_statement.cpp9
-rw-r--r--src/mongo/db/pipeline/accumulation_statement.h117
-rw-r--r--src/mongo/db/pipeline/accumulator.h90
-rw-r--r--src/mongo/db/pipeline/accumulator_add_to_set.cpp4
-rw-r--r--src/mongo/db/pipeline/accumulator_avg.cpp6
-rw-r--r--src/mongo/db/pipeline/accumulator_first.cpp4
-rw-r--r--src/mongo/db/pipeline/accumulator_js_reduce.cpp247
-rw-r--r--src/mongo/db/pipeline/accumulator_js_reduce.h66
-rw-r--r--src/mongo/db/pipeline/accumulator_last.cpp4
-rw-r--r--src/mongo/db/pipeline/accumulator_merge_objects.cpp4
-rw-r--r--src/mongo/db/pipeline/accumulator_min_max.cpp6
-rw-r--r--src/mongo/db/pipeline/accumulator_push.cpp4
-rw-r--r--src/mongo/db/pipeline/accumulator_std_dev.cpp8
-rw-r--r--src/mongo/db/pipeline/accumulator_sum.cpp6
-rw-r--r--src/mongo/db/pipeline/accumulator_test.cpp6
-rw-r--r--src/mongo/db/pipeline/document_source_bucket_auto.cpp51
-rw-r--r--src/mongo/db/pipeline/document_source_bucket_auto.h3
-rw-r--r--src/mongo/db/pipeline/document_source_group.cpp45
-rw-r--r--src/mongo/db/pipeline/document_source_group.h2
-rw-r--r--src/mongo/db/pipeline/document_source_group_test.cpp20
-rw-r--r--src/mongo/db/pipeline/expression.h14
-rw-r--r--src/mongo/db/pipeline/expression_visitor.h2
22 files changed, 562 insertions, 156 deletions
diff --git a/src/mongo/db/pipeline/accumulation_statement.cpp b/src/mongo/db/pipeline/accumulation_statement.cpp
index 18595477227..72b2b80a6c6 100644
--- a/src/mongo/db/pipeline/accumulation_statement.cpp
+++ b/src/mongo/db/pipeline/accumulation_statement.cpp
@@ -65,8 +65,8 @@ AccumulationStatement::Parser& AccumulationStatement::getParser(StringData name)
return it->second;
}
-boost::intrusive_ptr<Accumulator> AccumulationStatement::makeAccumulator() const {
- return _factory();
+boost::intrusive_ptr<AccumulatorState> AccumulationStatement::makeAccumulator() const {
+ return expr.factory();
}
AccumulationStatement AccumulationStatement::parseAccumulationStatement(
@@ -98,9 +98,10 @@ AccumulationStatement AccumulationStatement::parseAccumulationStatement(
specElem.type() != BSONType::Array);
auto&& parser = AccumulationStatement::getParser(accName);
- auto [expression, factory] = parser(expCtx, specElem, vps);
+ auto [initializer, argument, factory] = parser(expCtx, specElem, vps);
- return AccumulationStatement(fieldName.toString(), expression, factory);
+ return AccumulationStatement(fieldName.toString(),
+ AccumulationExpression(initializer, argument, factory));
}
} // namespace mongo
diff --git a/src/mongo/db/pipeline/accumulation_statement.h b/src/mongo/db/pipeline/accumulation_statement.h
index 91ac3a1aff3..fff666141d3 100644
--- a/src/mongo/db/pipeline/accumulation_statement.h
+++ b/src/mongo/db/pipeline/accumulation_statement.h
@@ -38,8 +38,8 @@
namespace mongo {
/**
- * Registers an Accumulator to have the name 'key'. When an accumulator with name '$key' is found
- * during parsing, 'factory' will be called to construct the Accumulator.
+ * Registers an AccumulatorState to have the name 'key'. When an accumulator with name '$key' is
+ * found during parsing, 'factory' will be called to construct the AccumulatorState.
*
* As an example, if your accumulator looks like {"$foo": <args>}, with a factory method 'create',
* you would add this line:
@@ -52,26 +52,102 @@ namespace mongo {
}
/**
+ * AccumulatorExpression represents the right-hand side of an AccumulationStatement. Note this is
+ * different from Expression; they are different nonterminals in the grammar.
+ *
+ * For example, in
+ * {$group: {
+ * _id: 1,
+ * count: {$sum: {$size: "$tags"}}
+ * }}
+ *
+ * we would say:
+ * The AccumulationStatement is count: {$sum: {$size: "$tags"}}
+ * The AccumulationExpression is {$sum: {$size: "$tags"}}
+ * The AccumulatorState::Factory is $sum
+ * The argument Expression is {$size: "$tags"}
+ * There is no initializer Expression.
+ *
+ * "$sum" corresponds to an AccumulatorState::Factory rather than AccumulatorState because
+ * AccumulatorState is an execution concept, not an AST concept: each instance of AccumulatorState
+ * contains intermediate values being accumulated.
+ *
+ * Like most accumulators, $sum does not require or accept an initializer Expression. At time of
+ * writing, only user-defined accumulators accept an initializer.
+ *
+ * For example, in:
+ * {$group: {
+ * _id: {cc: "$country_code"},
+ * top_stories: {$accumulator: {
+ * init: function(cc) { ... },
+ * initArgs: ["$cc"],
+ * accumulate: function(state, title, upvotes) { ... },
+ * accumulateArgs: ["$title", "$upvotes"],
+ * merge: function(state1, state2) { ... },
+ * lang: "js",
+ * }}
+ * }}
+ *
+ * we would say:
+ * The AccumulationStatement is top_stories: {$accumulator: ... }
+ * The AccumulationExpression is {$accumulator: ... }
+ * The argument Expression is ["$cc"]
+ * The initializer Expression is ["$title", "$upvotes"]
+ * The AccumulatorState::Factory holds all the other arguments to $accumulator.
+ *
+ */
+struct AccumulationExpression {
+ AccumulationExpression(boost::intrusive_ptr<Expression> initializer,
+ boost::intrusive_ptr<Expression> argument,
+ AccumulatorState::Factory factory)
+ : initializer(initializer), argument(argument), factory(factory) {
+ invariant(this->initializer);
+ invariant(this->argument);
+ }
+
+ // The expression to use to obtain the input to the accumulator.
+ boost::intrusive_ptr<Expression> initializer;
+
+ // An expression evaluated once per input document, and passed to AccumulatorState::process.
+ boost::intrusive_ptr<Expression> argument;
+
+ // Constructs an AccumulatorState to do actual accumulation.
+ boost::intrusive_ptr<AccumulatorState> makeAccumulator() const;
+
+ // A no argument function object that can be called to create an AccumulatorState.
+ const AccumulatorState::Factory factory;
+};
+
+/**
+ * A default parser for any accumulator that only takes a single expression as an argument. Returns
+ * the expression to be evaluated by the accumulator and an AccumulatorState::Factory.
+ */
+template <class AccName>
+AccumulationExpression genericParseSingleExpressionAccumulator(
+ boost::intrusive_ptr<ExpressionContext> expCtx, BSONElement elem, VariablesParseState vps) {
+ auto initializer = ExpressionConstant::create(expCtx, Value(BSONNULL));
+ auto argument = Expression::parseOperand(expCtx, elem, vps);
+ return {initializer, argument, [expCtx]() { return AccName::create(expCtx); }};
+}
+
+/**
* A class representing a user-specified accumulation, including the field name to put the
* accumulated result in, which accumulator to use, and the expression used to obtain the input to
- * the Accumulator.
+ * the AccumulatorState.
*/
class AccumulationStatement {
public:
- using Parser = std::function<std::pair<boost::intrusive_ptr<Expression>, Accumulator::Factory>(
+ using Parser = std::function<AccumulationExpression(
boost::intrusive_ptr<ExpressionContext>, BSONElement, VariablesParseState)>;
- AccumulationStatement(std::string fieldName,
- boost::intrusive_ptr<Expression> expression,
- Accumulator::Factory factory)
- : fieldName(std::move(fieldName)),
- expression(std::move(expression)),
- _factory(std::move(factory)) {}
+
+ AccumulationStatement(std::string fieldName, AccumulationExpression expr)
+ : fieldName(std::move(fieldName)), expr(std::move(expr)) {}
/**
* Parses a BSONElement that is an accumulated field, and returns an AccumulationStatement for
* that accumulated field.
*
- * Throws a AssertionException if parsing fails.
+ * Throws an AssertionException if parsing fails.
*/
static AccumulationStatement parseAccumulationStatement(
const boost::intrusive_ptr<ExpressionContext>& expCtx,
@@ -79,9 +155,9 @@ public:
const VariablesParseState& vps);
/**
- * Registers an Accumulator with a parsing function, so that when an accumulator with the given
- * name is encountered during parsing, we will know to call 'factory' to construct that
- * Accumulator.
+ * Registers an AccumulatorState with a parsing function, so that when an accumulator with the
+ * given name is encountered during parsing, we will know to call 'factory' to construct that
+ * AccumulatorState.
*
* DO NOT call this method directly. Instead, use the REGISTER_ACCUMULATOR macro defined in this
* file.
@@ -90,22 +166,17 @@ public:
/**
* Retrieves the Parser for the accumulator specified by the given name, and raises an error if
- * there is no such Accumulator registered.
+ * there is no such AccumulatorState registered.
*/
static Parser& getParser(StringData name);
// The field name is used to store the results of the accumulation in a result document.
std::string fieldName;
- // The expression to use to obtain the input to the accumulator.
- boost::intrusive_ptr<Expression> expression;
-
- // Constructs an Accumulator to do actual accumulation.
- boost::intrusive_ptr<Accumulator> makeAccumulator() const;
+ AccumulationExpression expr;
-private:
- // A no argument function object that can be called to create an Accumulator.
- const Accumulator::Factory _factory;
+ // Constructs an AccumulatorState to do actual accumulation.
+ boost::intrusive_ptr<AccumulatorState> makeAccumulator() const;
};
diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h
index cb58bc7a3b1..e59ccb543da 100644
--- a/src/mongo/db/pipeline/accumulator.h
+++ b/src/mongo/db/pipeline/accumulator.h
@@ -51,23 +51,28 @@ namespace mongo {
* This enum indicates which documents an accumulator needs to see in order to compute its output.
*/
enum class AccumulatorDocumentsNeeded {
- // Accumulator needs to see all documents in a group.
+ // AccumulatorState needs to see all documents in a group.
kAllDocuments,
- // Accumulator only needs to see one document in a group, and when there is a sort order, that
- // document must be the first document.
+ // AccumulatorState only needs to see one document in a group, and when there is a sort order,
+ // that document must be the first document.
kFirstDocument,
- // Accumulator only needs to see one document in a group, and when there is a sort order, that
- // document must be the last document.
+ // AccumulatorState only needs to see one document in a group, and when there is a sort order,
+ // that document must be the last document.
kLastDocument,
};
-class Accumulator : public RefCountable {
+class AccumulatorState : public RefCountable {
public:
- using Factory = std::function<boost::intrusive_ptr<Accumulator>()>;
+ using Factory = std::function<boost::intrusive_ptr<AccumulatorState>()>;
- Accumulator(const boost::intrusive_ptr<ExpressionContext>& expCtx) : _expCtx(expCtx) {}
+ AccumulatorState(const boost::intrusive_ptr<ExpressionContext>& expCtx) : _expCtx(expCtx) {}
+
+ /** Marks the beginning of a new group. The input is the result of evaluating
+ * AccumulatorExpression::initializer, which can read from the group key.
+ */
+ virtual void startNewGroup(const Value& input) {}
/** Process input and update internal state.
* merging should be true when processing outputs from getValue(true).
@@ -89,7 +94,7 @@ public:
return _memUsageBytes;
}
- /// Reset this accumulator to a fresh state ready to receive input.
+ /// Reset this accumulator to a fresh state, ready for a new call to startNewGroup.
virtual void reset() = 0;
virtual bool isAssociative() const {
@@ -109,9 +114,19 @@ public:
*
* When executing on a sharded cluster, the result of this function will be sent to each
* individual shard.
+ *
+ * This implementation assumes the accumulator has the simple syntax { <name>: <argument> },
+ * such as { $sum: <argument> }. This syntax has no room for an initializer. Subclasses with a
+ * more elaborate syntax such should override this method.
*/
- virtual Document serialize(boost::intrusive_ptr<Expression> expression, bool explain) const {
- return DOC(getOpName() << expression->serialize(explain));
+ virtual Document serialize(boost::intrusive_ptr<Expression> initializer,
+ boost::intrusive_ptr<Expression> argument,
+ bool explain) const {
+ ExpressionConstant const* ec = dynamic_cast<ExpressionConstant const*>(initializer.get());
+ invariant(ec);
+ invariant(ec->getValue().nullish());
+
+ return DOC(getOpName() << argument->serialize(explain));
}
virtual AccumulatorDocumentsNeeded documentsNeeded() const {
@@ -133,20 +148,7 @@ private:
boost::intrusive_ptr<ExpressionContext> _expCtx;
};
-/**
- * A default parser for any accumulator that only takes a single expression as an argument. Returns
- * the expression to be evaluated by the accumulator and an Accumulator::Factory.
- */
-template <class AccName>
-std::pair<boost::intrusive_ptr<Expression>, Accumulator::Factory>
-genericParseSingleExpressionAccumulator(boost::intrusive_ptr<ExpressionContext> expCtx,
- BSONElement elem,
- VariablesParseState vps) {
- auto exprValue = Expression::parseOperand(expCtx, elem, vps);
- return {exprValue, [expCtx]() { return AccName::create(expCtx); }};
-}
-
-class AccumulatorAddToSet final : public Accumulator {
+class AccumulatorAddToSet final : public AccumulatorState {
public:
/**
* Creates a new $addToSet accumulator. If no memory limit is given, defaults to the value of
@@ -160,7 +162,7 @@ public:
const char* getOpName() const final;
void reset() final;
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
bool isAssociative() const final {
@@ -176,7 +178,7 @@ private:
int _maxMemUsageBytes;
};
-class AccumulatorFirst final : public Accumulator {
+class AccumulatorFirst final : public AccumulatorState {
public:
explicit AccumulatorFirst(const boost::intrusive_ptr<ExpressionContext>& expCtx);
@@ -185,7 +187,7 @@ public:
const char* getOpName() const final;
void reset() final;
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
AccumulatorDocumentsNeeded documentsNeeded() const final {
@@ -197,7 +199,7 @@ private:
Value _first;
};
-class AccumulatorLast final : public Accumulator {
+class AccumulatorLast final : public AccumulatorState {
public:
explicit AccumulatorLast(const boost::intrusive_ptr<ExpressionContext>& expCtx);
@@ -206,7 +208,7 @@ public:
const char* getOpName() const final;
void reset() final;
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
AccumulatorDocumentsNeeded documentsNeeded() const final {
@@ -217,7 +219,7 @@ private:
Value _last;
};
-class AccumulatorSum final : public Accumulator {
+class AccumulatorSum final : public AccumulatorState {
public:
explicit AccumulatorSum(const boost::intrusive_ptr<ExpressionContext>& expCtx);
@@ -226,7 +228,7 @@ public:
const char* getOpName() const final;
void reset() final;
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
bool isAssociative() const final {
@@ -243,7 +245,7 @@ private:
Decimal128 decimalTotal;
};
-class AccumulatorMinMax : public Accumulator {
+class AccumulatorMinMax : public AccumulatorState {
public:
enum Sense : int {
MIN = 1,
@@ -274,7 +276,7 @@ class AccumulatorMax final : public AccumulatorMinMax {
public:
explicit AccumulatorMax(const boost::intrusive_ptr<ExpressionContext>& expCtx)
: AccumulatorMinMax(expCtx, MAX) {}
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
};
@@ -282,11 +284,11 @@ class AccumulatorMin final : public AccumulatorMinMax {
public:
explicit AccumulatorMin(const boost::intrusive_ptr<ExpressionContext>& expCtx)
: AccumulatorMinMax(expCtx, MIN) {}
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
};
-class AccumulatorPush final : public Accumulator {
+class AccumulatorPush final : public AccumulatorState {
public:
/**
* Creates a new $push accumulator. If no memory limit is given, defaults to the value of the
@@ -300,7 +302,7 @@ public:
const char* getOpName() const final;
void reset() final;
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
private:
@@ -308,7 +310,7 @@ private:
int _maxMemUsageBytes;
};
-class AccumulatorAvg final : public Accumulator {
+class AccumulatorAvg final : public AccumulatorState {
public:
explicit AccumulatorAvg(const boost::intrusive_ptr<ExpressionContext>& expCtx);
@@ -317,7 +319,7 @@ public:
const char* getOpName() const final;
void reset() final;
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
private:
@@ -333,7 +335,7 @@ private:
long long _count;
};
-class AccumulatorStdDev : public Accumulator {
+class AccumulatorStdDev : public AccumulatorState {
public:
AccumulatorStdDev(const boost::intrusive_ptr<ExpressionContext>& expCtx, bool isSamp);
@@ -353,7 +355,7 @@ class AccumulatorStdDevPop final : public AccumulatorStdDev {
public:
explicit AccumulatorStdDevPop(const boost::intrusive_ptr<ExpressionContext>& expCtx)
: AccumulatorStdDev(expCtx, false) {}
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
};
@@ -361,11 +363,11 @@ class AccumulatorStdDevSamp final : public AccumulatorStdDev {
public:
explicit AccumulatorStdDevSamp(const boost::intrusive_ptr<ExpressionContext>& expCtx)
: AccumulatorStdDev(expCtx, true) {}
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
};
-class AccumulatorMergeObjects : public Accumulator {
+class AccumulatorMergeObjects : public AccumulatorState {
public:
AccumulatorMergeObjects(const boost::intrusive_ptr<ExpressionContext>& expCtx);
@@ -374,7 +376,7 @@ public:
const char* getOpName() const final;
void reset() final;
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx);
private:
diff --git a/src/mongo/db/pipeline/accumulator_add_to_set.cpp b/src/mongo/db/pipeline/accumulator_add_to_set.cpp
index 3a6a7b22944..7af05eda870 100644
--- a/src/mongo/db/pipeline/accumulator_add_to_set.cpp
+++ b/src/mongo/db/pipeline/accumulator_add_to_set.cpp
@@ -81,7 +81,7 @@ Value AccumulatorAddToSet::getValue(bool toBeMerged) {
AccumulatorAddToSet::AccumulatorAddToSet(const boost::intrusive_ptr<ExpressionContext>& expCtx,
boost::optional<int> maxMemoryUsageBytes)
- : Accumulator(expCtx),
+ : AccumulatorState(expCtx),
_set(expCtx->getValueComparator().makeUnorderedValueSet()),
_maxMemUsageBytes(maxMemoryUsageBytes.value_or(internalQueryMaxAddToSetBytes.load())) {
_memUsageBytes = sizeof(*this);
@@ -92,7 +92,7 @@ void AccumulatorAddToSet::reset() {
_memUsageBytes = sizeof(*this);
}
-intrusive_ptr<Accumulator> AccumulatorAddToSet::create(
+intrusive_ptr<AccumulatorState> AccumulatorAddToSet::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorAddToSet(expCtx, boost::none);
}
diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp
index 2efc2cee191..1e7f55f5ab3 100644
--- a/src/mongo/db/pipeline/accumulator_avg.cpp
+++ b/src/mongo/db/pipeline/accumulator_avg.cpp
@@ -93,7 +93,7 @@ void AccumulatorAvg::processInternal(const Value& input, bool merging) {
_count++;
}
-intrusive_ptr<Accumulator> AccumulatorAvg::create(
+intrusive_ptr<AccumulatorState> AccumulatorAvg::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorAvg(expCtx);
}
@@ -123,8 +123,8 @@ Value AccumulatorAvg::getValue(bool toBeMerged) {
}
AccumulatorAvg::AccumulatorAvg(const boost::intrusive_ptr<ExpressionContext>& expCtx)
- : Accumulator(expCtx), _isDecimal(false), _count(0) {
- // This is a fixed size Accumulator so we never need to update this
+ : AccumulatorState(expCtx), _isDecimal(false), _count(0) {
+ // This is a fixed size AccumulatorState so we never need to update this
_memUsageBytes = sizeof(*this);
}
diff --git a/src/mongo/db/pipeline/accumulator_first.cpp b/src/mongo/db/pipeline/accumulator_first.cpp
index aed285cf5db..b48b9e2330e 100644
--- a/src/mongo/db/pipeline/accumulator_first.cpp
+++ b/src/mongo/db/pipeline/accumulator_first.cpp
@@ -59,7 +59,7 @@ Value AccumulatorFirst::getValue(bool toBeMerged) {
}
AccumulatorFirst::AccumulatorFirst(const boost::intrusive_ptr<ExpressionContext>& expCtx)
- : Accumulator(expCtx), _haveFirst(false) {
+ : AccumulatorState(expCtx), _haveFirst(false) {
_memUsageBytes = sizeof(*this);
}
@@ -70,7 +70,7 @@ void AccumulatorFirst::reset() {
}
-intrusive_ptr<Accumulator> AccumulatorFirst::create(
+intrusive_ptr<AccumulatorState> AccumulatorFirst::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorFirst(expCtx);
}
diff --git a/src/mongo/db/pipeline/accumulator_js_reduce.cpp b/src/mongo/db/pipeline/accumulator_js_reduce.cpp
index 877ac1ca41e..41183c7ac40 100644
--- a/src/mongo/db/pipeline/accumulator_js_reduce.cpp
+++ b/src/mongo/db/pipeline/accumulator_js_reduce.cpp
@@ -37,10 +37,8 @@ namespace mongo {
REGISTER_ACCUMULATOR(_internalJsReduce, AccumulatorInternalJsReduce::parseInternalJsReduce);
-std::pair<boost::intrusive_ptr<Expression>, Accumulator::Factory>
-AccumulatorInternalJsReduce::parseInternalJsReduce(boost::intrusive_ptr<ExpressionContext> expCtx,
- BSONElement elem,
- VariablesParseState vps) {
+AccumulationExpression AccumulatorInternalJsReduce::parseInternalJsReduce(
+ boost::intrusive_ptr<ExpressionContext> expCtx, BSONElement elem, VariablesParseState vps) {
uassert(31326,
str::stream() << kAccumulatorName << " requires a document argument, but found "
<< elem.type(),
@@ -48,13 +46,13 @@ AccumulatorInternalJsReduce::parseInternalJsReduce(boost::intrusive_ptr<Expressi
BSONObj obj = elem.embeddedObject();
std::string funcSource;
- boost::intrusive_ptr<Expression> dataExpr;
+ boost::intrusive_ptr<Expression> argument;
for (auto&& element : obj) {
if (element.fieldNameStringData() == "eval") {
funcSource = parseReduceFunction(element);
} else if (element.fieldNameStringData() == "data") {
- dataExpr = Expression::parseOperand(expCtx, element, vps);
+ argument = Expression::parseOperand(expCtx, element, vps);
} else {
uasserted(31243,
str::stream() << "Invalid argument specified to " << kAccumulatorName << ": "
@@ -68,13 +66,14 @@ AccumulatorInternalJsReduce::parseInternalJsReduce(boost::intrusive_ptr<Expressi
uassert(31349,
str::stream() << kAccumulatorName
<< " requires 'data' argument, recieved input: " << obj.toString(false),
- dataExpr);
+ argument);
auto factory = [expCtx, funcSource = funcSource]() {
return AccumulatorInternalJsReduce::create(expCtx, funcSource);
};
- return {std::move(dataExpr), std::move(factory)};
+ auto initializer = ExpressionConstant::create(expCtx, Value(BSONNULL));
+ return {std::move(initializer), std::move(argument), std::move(factory)};
}
std::string AccumulatorInternalJsReduce::parseReduceFunction(BSONElement func) {
@@ -168,7 +167,7 @@ Value AccumulatorInternalJsReduce::getValue(bool toBeMerged) {
}
}
-boost::intrusive_ptr<Accumulator> AccumulatorInternalJsReduce::create(
+boost::intrusive_ptr<AccumulatorState> AccumulatorInternalJsReduce::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData funcSource) {
return make_intrusive<AccumulatorInternalJsReduce>(expCtx, funcSource);
@@ -181,9 +180,233 @@ void AccumulatorInternalJsReduce::reset() {
}
// Returns this accumulator serialized as a Value along with the reduce function.
-Document AccumulatorInternalJsReduce::serialize(boost::intrusive_ptr<Expression> expression,
+Document AccumulatorInternalJsReduce::serialize(boost::intrusive_ptr<Expression> initializer,
+ boost::intrusive_ptr<Expression> argument,
bool explain) const {
- return DOC(
- getOpName() << DOC("data" << expression->serialize(explain) << "eval" << _funcSource));
+ return DOC(getOpName() << DOC("data" << argument->serialize(explain) << "eval" << _funcSource));
}
+
+REGISTER_ACCUMULATOR(accumulator, AccumulatorJs::parse);
+
+boost::intrusive_ptr<AccumulatorState> AccumulatorJs::create(
+ const boost::intrusive_ptr<ExpressionContext>& expCtx,
+ std::string init,
+ std::string accumulate,
+ std::string merge,
+ std::string finalize) {
+ return new AccumulatorJs(
+ expCtx, std::move(init), std::move(accumulate), std::move(merge), std::move(finalize));
+}
+
+namespace {
+// Parses a constant expression of type String or Code.
+std::string parseFunction(StringData fieldName,
+ boost::intrusive_ptr<ExpressionContext> expCtx,
+ BSONElement elem,
+ VariablesParseState vps) {
+ boost::intrusive_ptr<Expression> expr = Expression::parseOperand(expCtx, elem, vps);
+ expr = expr->optimize();
+ ExpressionConstant* ec = dynamic_cast<ExpressionConstant*>(expr.get());
+ uassert(4544701,
+ str::stream() << "$accumulator '" << fieldName << "' must be a constant expression",
+ ec);
+ Value v = ec->getValue();
+ uassert(4544702,
+ str::stream() << "$accumulator '" << fieldName << "' must be a String or Code",
+ v.getType() == BSONType::String || v.getType() == BSONType::Code);
+ return v.coerceToString();
+}
+} // namespace
+
+
+Document AccumulatorJs::serialize(boost::intrusive_ptr<Expression> initializer,
+ boost::intrusive_ptr<Expression> argument,
+ bool explain) const {
+ MutableDocument args;
+ args.addField("init", Value(_init));
+ args.addField("initArgs", Value(initializer->serialize(explain)));
+ args.addField("accumulate", Value(_accumulate));
+ args.addField("accumulateArgs", Value(argument->serialize(explain)));
+ args.addField("merge", Value(_merge));
+ args.addField("finalize", Value(_finalize));
+ args.addField("lang", Value("js"_sd));
+ return DOC(getOpName() << args.freeze());
+}
+
+AccumulationExpression AccumulatorJs::parse(boost::intrusive_ptr<ExpressionContext> expCtx,
+ BSONElement elem,
+ VariablesParseState vps) {
+ /*
+ * {$accumulator: {
+ * init: <code>,
+ * accumulate: <code>,
+ * merge: <code>,
+ * finalize: <code>,
+ *
+ * accumulateArgs: <expr>, // evaluated once per document
+ *
+ * initArgs: <expr>, // evaluated once per group
+ *
+ * lang: 'js',
+ * }}
+ */
+ uassert(4544703,
+ str::stream() << "$accumulator expects an object as an argument; found: "
+ << typeName(elem.type()),
+ elem.type() == BSONType::Object);
+ BSONObj obj = elem.embeddedObject();
+
+ std::string init, accumulate, merge, finalize;
+ boost::intrusive_ptr<Expression> initArgs, accumulateArgs;
+
+ for (auto&& element : obj) {
+ auto name = element.fieldNameStringData();
+ if (name == "init") {
+ init = parseFunction("init", expCtx, element, vps);
+ } else if (name == "accumulate") {
+ accumulate = parseFunction("accumulate", expCtx, element, vps);
+ } else if (name == "merge") {
+ merge = parseFunction("merge", expCtx, element, vps);
+ } else if (name == "finalize") {
+ finalize = parseFunction("finalize", expCtx, element, vps);
+ } else if (name == "initArgs") {
+ initArgs = Expression::parseOperand(expCtx, element, vps);
+ } else if (name == "accumulateArgs") {
+ accumulateArgs = Expression::parseOperand(expCtx, element, vps);
+ } else if (name == "lang") {
+ uassert(4544704,
+ str::stream() << "$accumulator lang must be a string; found: "
+ << element.type(),
+ element.type() == BSONType::String);
+ uassert(4544705,
+ "$accumulator only supports lang: 'js'",
+ element.valueStringData() == "js");
+ } else {
+ // unexpected field
+ uassert(
+ 4544706, str::stream() << "$accumulator got an unexpected field: " << name, false);
+ }
+ }
+ uassert(4544707, "$accumulator missing required argument 'init'", !init.empty());
+ uassert(4544708, "$accumulator missing required argument 'accumulate'", !accumulate.empty());
+ uassert(4544709, "$accumulator missing required argument 'merge'", !merge.empty());
+ if (finalize.empty()) {
+ // finalize is optional because many custom accumulators will return the final state
+ // unchanged.
+ finalize = "function(state) { return state; }";
+ }
+ if (!initArgs) {
+ // initArgs is optional because most custom accumulators don't need the state to depend on
+ // the group key.
+ initArgs = ExpressionConstant::create(expCtx, Value(BSONArray()));
+ }
+ // accumulateArgs is required because it's the only way to communicate a value from the input
+ // stream into the accumulator state.
+ uassert(4544710, "$accumulator missing required argument 'accumulateArgs'", accumulateArgs);
+
+ auto factory = [expCtx = expCtx,
+ init = std::move(init),
+ accumulate = std::move(accumulate),
+ merge = std::move(merge),
+ finalize = std::move(finalize)]() {
+ return new AccumulatorJs(expCtx, init, accumulate, merge, finalize);
+ };
+ return {std::move(initArgs), std::move(accumulateArgs), std::move(factory)};
+}
+
+Value AccumulatorJs::getValue(bool toBeMerged) {
+ // _state is initialized when we encounter the first document in each group. We never create
+ // empty groups: even in a {$group: {_id: 1, ...}}, we will return zero groups rather than one
+ // empty group.
+ invariant(_state);
+
+ // If toBeMerged then we return the current state, to be fed back in to accumulate / merge /
+ // finalize later. If not toBeMerged then we return the final value, by calling finalize.
+ if (toBeMerged) {
+ return *_state;
+ }
+
+ // Get the final value given the current accumulator state.
+
+ auto& expCtx = getExpressionContext();
+ auto jsExec = expCtx->getJsExecWithScope();
+ auto func = makeJsFunc(expCtx, _finalize);
+
+ return jsExec->callFunction(func, BSON_ARRAY(*_state), {});
+}
+
+void AccumulatorJs::startNewGroup(Value const& input) {
+ // Between groups the _state should be empty: we initialize it to be empty it in the
+ // constructor, and we clear it at the end of each group (in .reset()).
+ invariant(!_state);
+
+ auto& expCtx = getExpressionContext();
+ auto jsExec = expCtx->getJsExecWithScope();
+ auto func = makeJsFunc(expCtx, _init);
+
+ // input is a value produced by our AccumulationExpression::initializer.
+ uassert(4544711,
+ str::stream() << "$accumulator initArgs must evaluate to an array: "
+ << input.toString(),
+ input.getType() == BSONType::Array);
+
+ size_t index = 0;
+ BSONArrayBuilder bob;
+ for (auto&& arg : input.getArray()) {
+ arg.addToBsonArray(&bob, index++);
+ }
+
+ _state = jsExec->callFunction(func, bob.arr(), {});
+
+ recomputeMemUsageBytes();
+}
+
+void AccumulatorJs::reset() {
+ _state = std::nullopt;
+ recomputeMemUsageBytes();
+}
+
+void AccumulatorJs::processInternal(const Value& input, bool merging) {
+ // _state should be nonempty because we populate it in startNewGroup.
+ invariant(_state);
+
+ auto& expCtx = getExpressionContext();
+ auto jsExec = expCtx->getJsExecWithScope();
+
+ if (merging) {
+ // input is an intermediate state from another instance of this kind of accumulator. Call
+ // the user's merge function.
+ auto func = makeJsFunc(expCtx, _merge);
+ _state = jsExec->callFunction(func, BSON_ARRAY(*_state << input), {});
+ recomputeMemUsageBytes();
+ } else {
+ // input is a value produced by our AccumulationExpression::argument. Call the user's
+ // accumulate function.
+ auto func = makeJsFunc(expCtx, _accumulate);
+ uassert(4544712,
+ str::stream() << "$accumulator accumulateArgs must evaluate to an array: "
+ << input.toString(),
+ input.getType() == BSONType::Array);
+
+ size_t index = 0;
+ BSONArrayBuilder bob;
+ _state->addToBsonArray(&bob, index++);
+ for (auto&& arg : input.getArray()) {
+ arg.addToBsonArray(&bob, index++);
+ }
+
+ _state = jsExec->callFunction(func, bob.done(), {});
+ recomputeMemUsageBytes();
+ }
+}
+
+void AccumulatorJs::recomputeMemUsageBytes() {
+ auto stateSize = _state.value_or(Value{}).getApproximateSize();
+ uassert(4544713,
+ str::stream() << "$accumulator state exceeded max BSON size: " << stateSize,
+ stateSize <= BSONObjMaxUserSize);
+ _memUsageBytes = sizeof(*this) + stateSize + _init.capacity() + _accumulate.capacity() +
+ _merge.capacity() + _finalize.capacity();
+}
+
} // namespace mongo
diff --git a/src/mongo/db/pipeline/accumulator_js_reduce.h b/src/mongo/db/pipeline/accumulator_js_reduce.h
index bc2725a5567..1bb15f948e8 100644
--- a/src/mongo/db/pipeline/accumulator_js_reduce.h
+++ b/src/mongo/db/pipeline/accumulator_js_reduce.h
@@ -38,19 +38,19 @@
namespace mongo {
-class AccumulatorInternalJsReduce final : public Accumulator {
+class AccumulatorInternalJsReduce final : public AccumulatorState {
public:
static constexpr auto kAccumulatorName = "$_internalJsReduce"_sd;
- static boost::intrusive_ptr<Accumulator> create(
+ static boost::intrusive_ptr<AccumulatorState> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData funcSource);
- static std::pair<boost::intrusive_ptr<Expression>, Accumulator::Factory> parseInternalJsReduce(
+ static AccumulationExpression parseInternalJsReduce(
boost::intrusive_ptr<ExpressionContext> expCtx, BSONElement elem, VariablesParseState vps);
AccumulatorInternalJsReduce(const boost::intrusive_ptr<ExpressionContext>& expCtx,
StringData funcSource)
- : Accumulator(expCtx), _funcSource(funcSource) {
+ : AccumulatorState(expCtx), _funcSource(funcSource) {
_memUsageBytes = sizeof(*this);
}
@@ -64,7 +64,8 @@ public:
void reset() final;
- virtual Document serialize(boost::intrusive_ptr<Expression> expression,
+ virtual Document serialize(boost::intrusive_ptr<Expression> initializer,
+ boost::intrusive_ptr<Expression> argument,
bool explain) const override;
private:
@@ -75,4 +76,59 @@ private:
Value _key;
};
+class AccumulatorJs final : public AccumulatorState {
+public:
+ static constexpr auto kAccumulatorName = "$accumulator"_sd;
+ const char* getOpName() const final {
+ return kAccumulatorName.rawData();
+ }
+
+ // An AccumulatorState instance only owns its "static" arguments: those that don't need to be
+ // evaluated per input document.
+ static boost::intrusive_ptr<AccumulatorState> create(
+ const boost::intrusive_ptr<ExpressionContext>& expCtx,
+ std::string init,
+ std::string accumulate,
+ std::string merge,
+ std::string finalize);
+
+ static AccumulationExpression parse(boost::intrusive_ptr<ExpressionContext> expCtx,
+ BSONElement elem,
+ VariablesParseState vps);
+
+ Value getValue(bool toBeMerged) final;
+ void reset() final;
+ void processInternal(const Value& input, bool merging) final;
+
+ Document serialize(boost::intrusive_ptr<Expression> initializer,
+ boost::intrusive_ptr<Expression> argument,
+ bool explain) const final;
+ void startNewGroup(Value const& input) final;
+
+private:
+ AccumulatorJs(const boost::intrusive_ptr<ExpressionContext>& expCtx,
+ std::string init,
+ std::string accumulate,
+ std::string merge,
+ std::string finalize)
+ : AccumulatorState(expCtx),
+ _init(init),
+ _accumulate(accumulate),
+ _merge(merge),
+ _finalize(finalize) {
+ recomputeMemUsageBytes();
+ }
+ void recomputeMemUsageBytes();
+
+ // static arguments
+ std::string _init, _accumulate, _merge, _finalize;
+
+ // accumulator state during execution
+ // - When the accumulator is first created, _state is empty.
+ // - When the accumulator is fed its first input Value, it runs the user init and accumulate
+ // functions, and _state gets a Value.
+ // - When the accumulator is reset, _state becomes empty again.
+ std::optional<Value> _state;
+};
+
} // namespace mongo
diff --git a/src/mongo/db/pipeline/accumulator_last.cpp b/src/mongo/db/pipeline/accumulator_last.cpp
index 150360d4fdd..14362e42cab 100644
--- a/src/mongo/db/pipeline/accumulator_last.cpp
+++ b/src/mongo/db/pipeline/accumulator_last.cpp
@@ -55,7 +55,7 @@ Value AccumulatorLast::getValue(bool toBeMerged) {
}
AccumulatorLast::AccumulatorLast(const boost::intrusive_ptr<ExpressionContext>& expCtx)
- : Accumulator(expCtx) {
+ : AccumulatorState(expCtx) {
_memUsageBytes = sizeof(*this);
}
@@ -64,7 +64,7 @@ void AccumulatorLast::reset() {
_last = Value();
}
-intrusive_ptr<Accumulator> AccumulatorLast::create(
+intrusive_ptr<AccumulatorState> AccumulatorLast::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorLast(expCtx);
}
diff --git a/src/mongo/db/pipeline/accumulator_merge_objects.cpp b/src/mongo/db/pipeline/accumulator_merge_objects.cpp
index d1e6310ea23..6b23ae528a1 100644
--- a/src/mongo/db/pipeline/accumulator_merge_objects.cpp
+++ b/src/mongo/db/pipeline/accumulator_merge_objects.cpp
@@ -49,14 +49,14 @@ const char* AccumulatorMergeObjects::getOpName() const {
return "$mergeObjects";
}
-intrusive_ptr<Accumulator> AccumulatorMergeObjects::create(
+intrusive_ptr<AccumulatorState> AccumulatorMergeObjects::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorMergeObjects(expCtx);
}
AccumulatorMergeObjects::AccumulatorMergeObjects(
const boost::intrusive_ptr<ExpressionContext>& expCtx)
- : Accumulator(expCtx) {
+ : AccumulatorState(expCtx) {
_memUsageBytes = sizeof(*this);
}
diff --git a/src/mongo/db/pipeline/accumulator_min_max.cpp b/src/mongo/db/pipeline/accumulator_min_max.cpp
index c1af0ff8226..265a84b7075 100644
--- a/src/mongo/db/pipeline/accumulator_min_max.cpp
+++ b/src/mongo/db/pipeline/accumulator_min_max.cpp
@@ -71,7 +71,7 @@ Value AccumulatorMinMax::getValue(bool toBeMerged) {
AccumulatorMinMax::AccumulatorMinMax(const boost::intrusive_ptr<ExpressionContext>& expCtx,
Sense sense)
- : Accumulator(expCtx), _sense(sense) {
+ : AccumulatorState(expCtx), _sense(sense) {
_memUsageBytes = sizeof(*this);
}
@@ -80,12 +80,12 @@ void AccumulatorMinMax::reset() {
_memUsageBytes = sizeof(*this);
}
-intrusive_ptr<Accumulator> AccumulatorMin::create(
+intrusive_ptr<AccumulatorState> AccumulatorMin::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorMin(expCtx);
}
-intrusive_ptr<Accumulator> AccumulatorMax::create(
+intrusive_ptr<AccumulatorState> AccumulatorMax::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorMax(expCtx);
}
diff --git a/src/mongo/db/pipeline/accumulator_push.cpp b/src/mongo/db/pipeline/accumulator_push.cpp
index d13a942ef62..3a004dd9731 100644
--- a/src/mongo/db/pipeline/accumulator_push.cpp
+++ b/src/mongo/db/pipeline/accumulator_push.cpp
@@ -83,7 +83,7 @@ Value AccumulatorPush::getValue(bool toBeMerged) {
AccumulatorPush::AccumulatorPush(const boost::intrusive_ptr<ExpressionContext>& expCtx,
boost::optional<int> maxMemoryUsageBytes)
- : Accumulator(expCtx),
+ : AccumulatorState(expCtx),
_maxMemUsageBytes(maxMemoryUsageBytes.value_or(internalQueryMaxPushBytes.load())) {
_memUsageBytes = sizeof(*this);
}
@@ -93,7 +93,7 @@ void AccumulatorPush::reset() {
_memUsageBytes = sizeof(*this);
}
-intrusive_ptr<Accumulator> AccumulatorPush::create(
+intrusive_ptr<AccumulatorState> AccumulatorPush::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorPush(expCtx, boost::none);
}
diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp
index 55367d766be..cdd31b2c897 100644
--- a/src/mongo/db/pipeline/accumulator_std_dev.cpp
+++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp
@@ -96,20 +96,20 @@ Value AccumulatorStdDev::getValue(bool toBeMerged) {
}
}
-intrusive_ptr<Accumulator> AccumulatorStdDevSamp::create(
+intrusive_ptr<AccumulatorState> AccumulatorStdDevSamp::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorStdDevSamp(expCtx);
}
-intrusive_ptr<Accumulator> AccumulatorStdDevPop::create(
+intrusive_ptr<AccumulatorState> AccumulatorStdDevPop::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorStdDevPop(expCtx);
}
AccumulatorStdDev::AccumulatorStdDev(const boost::intrusive_ptr<ExpressionContext>& expCtx,
bool isSamp)
- : Accumulator(expCtx), _isSamp(isSamp), _count(0), _mean(0), _m2(0) {
- // This is a fixed size Accumulator so we never need to update this
+ : AccumulatorState(expCtx), _isSamp(isSamp), _count(0), _mean(0), _m2(0) {
+ // This is a fixed size AccumulatorState so we never need to update this
_memUsageBytes = sizeof(*this);
}
diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp
index 6cd34d8c76f..182f592ce3f 100644
--- a/src/mongo/db/pipeline/accumulator_sum.cpp
+++ b/src/mongo/db/pipeline/accumulator_sum.cpp
@@ -85,7 +85,7 @@ void AccumulatorSum::processInternal(const Value& input, bool merging) {
}
}
-intrusive_ptr<Accumulator> AccumulatorSum::create(
+intrusive_ptr<AccumulatorState> AccumulatorSum::create(
const boost::intrusive_ptr<ExpressionContext>& expCtx) {
return new AccumulatorSum(expCtx);
}
@@ -128,8 +128,8 @@ Value AccumulatorSum::getValue(bool toBeMerged) {
}
AccumulatorSum::AccumulatorSum(const boost::intrusive_ptr<ExpressionContext>& expCtx)
- : Accumulator(expCtx) {
- // This is a fixed size Accumulator so we never need to update this.
+ : AccumulatorState(expCtx) {
+ // This is a fixed size AccumulatorState so we never need to update this.
_memUsageBytes = sizeof(*this);
}
diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp
index 373a24a4155..724e6a6838a 100644
--- a/src/mongo/db/pipeline/accumulator_test.cpp
+++ b/src/mongo/db/pipeline/accumulator_test.cpp
@@ -46,9 +46,9 @@ using std::numeric_limits;
using std::string;
/**
- * Takes the name of an Accumulator as its template argument and a list of pairs of arguments and
- * expected results as its second argument, and asserts that for the given Accumulator the arguments
- * evaluate to the expected results.
+ * Takes the name of an AccumulatorState as its template argument and a list of pairs of arguments
+ * and expected results as its second argument, and asserts that for the given AccumulatorState the
+ * arguments evaluate to the expected results.
*/
template <typename AccName>
static void assertExpectedResults(
diff --git a/src/mongo/db/pipeline/document_source_bucket_auto.cpp b/src/mongo/db/pipeline/document_source_bucket_auto.cpp
index eba38c5c58b..c6decbca69c 100644
--- a/src/mongo/db/pipeline/document_source_bucket_auto.cpp
+++ b/src/mongo/db/pipeline/document_source_bucket_auto.cpp
@@ -109,13 +109,25 @@ DocumentSource::GetNextResult DocumentSourceBucketAuto::doGetNext() {
return makeDocument(*(_bucketsIterator++));
}
+boost::intrusive_ptr<DocumentSource> DocumentSourceBucketAuto::optimize() {
+ _groupByExpression = _groupByExpression->optimize();
+ for (auto&& accumulatedField : _accumulatedFields) {
+ accumulatedField.expr.argument = accumulatedField.expr.argument->optimize();
+ accumulatedField.expr.initializer = accumulatedField.expr.initializer->optimize();
+ }
+ return this;
+}
+
DepsTracker::State DocumentSourceBucketAuto::getDependencies(DepsTracker* deps) const {
// Add the 'groupBy' expression.
_groupByExpression->addDependencies(deps);
// Add the 'output' fields.
for (auto&& accumulatedField : _accumulatedFields) {
- accumulatedField.expression->addDependencies(deps);
+ // Anything the per-doc expression depends on, the whole stage depends on.
+ accumulatedField.expr.argument->addDependencies(deps);
+ // The initializer should be an ExpressionConstant, or something that optimizes to one.
+ // ExpressionConstant doesn't have dependencies.
}
// We know exactly which fields will be present in the output document. Future stages cannot
@@ -189,7 +201,8 @@ void DocumentSourceBucketAuto::addDocumentToBucket(const pair<Value, Document>&
const size_t numAccumulators = _accumulatedFields.size();
for (size_t k = 0; k < numAccumulators; k++) {
bucket._accums[k]->process(
- _accumulatedFields[k].expression->evaluate(entry.second, &pExpCtx->variables), false);
+ _accumulatedFields[k].expr.argument->evaluate(entry.second, &pExpCtx->variables),
+ false);
}
}
@@ -234,6 +247,16 @@ void DocumentSourceBucketAuto::populateBuckets() {
// Initialize the current bucket.
Bucket currentBucket(pExpCtx, currentValue.first, currentValue.first, _accumulatedFields);
+ // Evaluate each initializer against an empty document. Normally the
+ // initializer can refer to the group key, but in $bucketAuto there is no single
+ // group key per bucket.
+ Document emptyDoc;
+ for (size_t k = 0; k < _accumulatedFields.size(); ++k) {
+ Value initializerValue =
+ _accumulatedFields[k].expr.initializer->evaluate(emptyDoc, &pExpCtx->variables);
+ currentBucket._accums[k]->startNewGroup(initializerValue);
+ }
+
// Add the first value into the current bucket.
addDocumentToBucket(currentValue, currentBucket);
@@ -382,10 +405,11 @@ Value DocumentSourceBucketAuto::serialize(
MutableDocument outputSpec(_accumulatedFields.size());
for (auto&& accumulatedField : _accumulatedFields) {
- intrusive_ptr<Accumulator> accum = accumulatedField.makeAccumulator();
+ intrusive_ptr<AccumulatorState> accum = accumulatedField.makeAccumulator();
outputSpec[accumulatedField.fieldName] =
- Value{Document{{accum->getOpName(),
- accumulatedField.expression->serialize(static_cast<bool>(explain))}}};
+ Value(accum->serialize(accumulatedField.expr.initializer,
+ accumulatedField.expr.argument,
+ static_cast<bool>(explain)));
}
insides["output"] = outputSpec.freezeToValue();
@@ -405,9 +429,11 @@ intrusive_ptr<DocumentSourceBucketAuto> DocumentSourceBucketAuto::create(
numBuckets > 0);
// If there is no output field specified, then add the default one.
if (accumulationStatements.empty()) {
- accumulationStatements.emplace_back("count",
- ExpressionConstant::create(pExpCtx, Value(1)),
- [pExpCtx] { return AccumulatorSum::create(pExpCtx); });
+ accumulationStatements.emplace_back(
+ "count",
+ AccumulationExpression(ExpressionConstant::create(pExpCtx, Value(BSONNULL)),
+ ExpressionConstant::create(pExpCtx, Value(1)),
+ [pExpCtx] { return AccumulatorSum::create(pExpCtx); }));
}
return new DocumentSourceBucketAuto(pExpCtx,
groupByExpression,
@@ -486,8 +512,13 @@ intrusive_ptr<DocumentSource> DocumentSourceBucketAuto::createFromBson(
argument.type() == BSONType::Object);
for (auto&& outputField : argument.embeddedObject()) {
- accumulationStatements.push_back(
- AccumulationStatement::parseAccumulationStatement(pExpCtx, outputField, vps));
+ auto stmt =
+ AccumulationStatement::parseAccumulationStatement(pExpCtx, outputField, vps);
+ stmt.expr.initializer = stmt.expr.initializer->optimize();
+ uassert(4544714,
+ "Can't refer to the group key in $bucketAuto",
+ ExpressionConstant::isNullOrConstant(stmt.expr.initializer));
+ accumulationStatements.push_back(std::move(stmt));
}
} else if ("granularity" == argName) {
uassert(40261,
diff --git a/src/mongo/db/pipeline/document_source_bucket_auto.h b/src/mongo/db/pipeline/document_source_bucket_auto.h
index 8804b0df6c1..fcc35a130d3 100644
--- a/src/mongo/db/pipeline/document_source_bucket_auto.h
+++ b/src/mongo/db/pipeline/document_source_bucket_auto.h
@@ -49,6 +49,7 @@ public:
DepsTracker::State getDependencies(DepsTracker* deps) const final;
const char* getSourceName() const final;
+ boost::intrusive_ptr<DocumentSource> optimize() final;
StageConstraints constraints(Pipeline::SplitState pipeState) const final {
return {StreamType::kBlocking,
@@ -114,7 +115,7 @@ private:
const std::vector<AccumulationStatement>& accumulationStatements);
Value _min;
Value _max;
- std::vector<boost::intrusive_ptr<Accumulator>> _accums;
+ std::vector<boost::intrusive_ptr<AccumulatorState>> _accums;
};
/**
diff --git a/src/mongo/db/pipeline/document_source_group.cpp b/src/mongo/db/pipeline/document_source_group.cpp
index 0f961f8c363..8268f0a1dd1 100644
--- a/src/mongo/db/pipeline/document_source_group.cpp
+++ b/src/mongo/db/pipeline/document_source_group.cpp
@@ -158,6 +158,17 @@ DocumentSource::GetNextResult DocumentSourceGroup::getNextSpilled() {
_currentId = _firstPartOfNextGroup.first;
const size_t numAccumulators = _accumulatedFields.size();
+
+ // Call startNewGroup on every accumulator.
+ Value expandedId = expandId(_currentId);
+ Document idDoc =
+ expandedId.getType() == BSONType::Object ? expandedId.getDocument() : Document();
+ for (size_t i = 0; i < numAccumulators; ++i) {
+ Value initializerValue =
+ _accumulatedFields[i].expr.initializer->evaluate(idDoc, &pExpCtx->variables);
+ _currentAccumulators[i]->startNewGroup(initializerValue);
+ }
+
while (pExpCtx->getValueComparator().evaluate(_currentId == _firstPartOfNextGroup.first)) {
// Inside of this loop, _firstPartOfNextGroup is the current data being processed.
// At loop exit, it is the first value to be processed in the next group.
@@ -216,7 +227,8 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::optimize() {
}
for (auto&& accumulatedField : _accumulatedFields) {
- accumulatedField.expression = accumulatedField.expression->optimize();
+ accumulatedField.expr.initializer = accumulatedField.expr.initializer->optimize();
+ accumulatedField.expr.argument = accumulatedField.expr.argument->optimize();
}
return this;
@@ -241,9 +253,11 @@ Value DocumentSourceGroup::serialize(boost::optional<ExplainOptions::Verbosity>
// Add the remaining fields.
for (auto&& accumulatedField : _accumulatedFields) {
- intrusive_ptr<Accumulator> accum = accumulatedField.makeAccumulator();
+ intrusive_ptr<AccumulatorState> accum = accumulatedField.makeAccumulator();
insides[accumulatedField.fieldName] =
- Value(accum->serialize(accumulatedField.expression, static_cast<bool>(explain)));
+ Value(accum->serialize(accumulatedField.expr.initializer,
+ accumulatedField.expr.argument,
+ static_cast<bool>(explain)));
}
if (_doingMerge) {
@@ -263,7 +277,8 @@ DepsTracker::State DocumentSourceGroup::getDependencies(DepsTracker* deps) const
// add the rest
for (auto&& accumulatedField : _accumulatedFields) {
- accumulatedField.expression->addDependencies(deps);
+ accumulatedField.expr.argument->addDependencies(deps);
+ // Don't add initializer, because it doesn't refer to docs from the input stream.
}
return DepsTracker::State::EXHAUSTIVE_ALL;
@@ -485,16 +500,23 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() {
// accumulator. This is done in a somewhat odd way in order to avoid hashing 'id' and
// looking it up in '_groups' multiple times.
const size_t oldSize = _groups->size();
- vector<intrusive_ptr<Accumulator>>& group = (*_groups)[id];
+ vector<intrusive_ptr<AccumulatorState>>& group = (*_groups)[id];
const bool inserted = _groups->size() != oldSize;
if (inserted) {
_memoryUsageBytes += id.getApproximateSize();
- // Add the accumulators
+ // Initialize and add the accumulators
+ Value expandedId = expandId(id);
+ Document idDoc =
+ expandedId.getType() == BSONType::Object ? expandedId.getDocument() : Document();
group.reserve(numAccumulators);
for (auto&& accumulatedField : _accumulatedFields) {
- group.push_back(accumulatedField.makeAccumulator());
+ auto accum = accumulatedField.makeAccumulator();
+ Value initializerValue =
+ accumulatedField.expr.initializer->evaluate(idDoc, &pExpCtx->variables);
+ accum->startNewGroup(initializerValue);
+ group.push_back(accum);
}
} else {
for (auto&& groupObj : group) {
@@ -508,7 +530,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() {
for (size_t i = 0; i < numAccumulators; i++) {
group[i]->process(
- _accumulatedFields[i].expression->evaluate(rootDocument, &pExpCtx->variables),
+ _accumulatedFields[i].expr.argument->evaluate(rootDocument, &pExpCtx->variables),
_doingMerge);
_memoryUsageBytes += group[i]->memUsageForSorter();
@@ -693,7 +715,7 @@ boost::optional<DocumentSource::DistributedPlanLogic> DocumentSourceGroup::distr
// original accumulator may be collecting an expression based on a field expression or
// constant. Here, we accumulate the output of the same name from the prior group.
auto copiedAccumulatedField = accumulatedField;
- copiedAccumulatedField.expression =
+ copiedAccumulatedField.expr.argument =
ExpressionFieldPath::parse(pExpCtx, "$$ROOT." + copiedAccumulatedField.fieldName, vps);
mergingGroup->addAccumulator(copiedAccumulatedField);
}
@@ -775,7 +797,10 @@ DocumentSourceGroup::rewriteGroupAsTransformOnFirstDocument() const {
fields.push_back(std::make_pair("_id", ExpressionFieldPath::create(pExpCtx, groupId)));
for (auto&& accumulator : _accumulatedFields) {
- fields.push_back(std::make_pair(accumulator.fieldName, accumulator.expression));
+ fields.push_back(std::make_pair(accumulator.fieldName, accumulator.expr.argument));
+
+ // Since we don't attempt this transformation for non-$first accumulators,
+ // the initializer should always be trivial.
}
return GroupFromFirstDocumentTransformation::create(pExpCtx, groupId, std::move(fields));
diff --git a/src/mongo/db/pipeline/document_source_group.h b/src/mongo/db/pipeline/document_source_group.h
index 88b0a669b6e..5d885e84a1f 100644
--- a/src/mongo/db/pipeline/document_source_group.h
+++ b/src/mongo/db/pipeline/document_source_group.h
@@ -88,7 +88,7 @@ private:
class DocumentSourceGroup final : public DocumentSource {
public:
- using Accumulators = std::vector<boost::intrusive_ptr<Accumulator>>;
+ using Accumulators = std::vector<boost::intrusive_ptr<AccumulatorState>>;
using GroupsMap = ValueUnorderedMap<Accumulators>;
static constexpr StringData kStageName = "$group"_sd;
diff --git a/src/mongo/db/pipeline/document_source_group_test.cpp b/src/mongo/db/pipeline/document_source_group_test.cpp
index a2ed4358248..c3c764c4f6a 100644
--- a/src/mongo/db/pipeline/document_source_group_test.cpp
+++ b/src/mongo/db/pipeline/document_source_group_test.cpp
@@ -75,9 +75,8 @@ TEST_F(DocumentSourceGroupTest, ShouldBeAbleToPauseLoading) {
// This is the only way to do this in a debug build.
auto&& parser = AccumulationStatement::getParser("$sum");
auto accumulatorArg = BSON("" << 1);
- auto [expression, factory] =
- parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState);
- AccumulationStatement countStatement{"count", expression, factory};
+ auto accExpr = parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState);
+ AccumulationStatement countStatement{"count", accExpr};
auto group = DocumentSourceGroup::create(
expCtx, ExpressionConstant::create(expCtx, Value(BSONNULL)), {countStatement});
auto mock =
@@ -113,9 +112,8 @@ TEST_F(DocumentSourceGroupTest, ShouldBeAbleToPauseLoadingWhileSpilled) {
auto&& parser = AccumulationStatement::getParser("$push");
auto accumulatorArg = BSON(""
<< "$largeStr");
- auto [expression, factory] =
- parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState);
- AccumulationStatement pushStatement{"spaceHog", expression, factory};
+ auto accExpr = parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState);
+ AccumulationStatement pushStatement{"spaceHog", accExpr};
auto groupByExpression =
ExpressionFieldPath::parse(expCtx, "$_id", expCtx->variablesParseState);
auto group = DocumentSourceGroup::create(
@@ -156,9 +154,8 @@ TEST_F(DocumentSourceGroupTest, ShouldErrorIfNotAllowedToSpillToDiskAndResultSet
auto&& parser = AccumulationStatement::getParser("$push");
auto accumulatorArg = BSON(""
<< "$largeStr");
- auto [expression, factory] =
- parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState);
- AccumulationStatement pushStatement{"spaceHog", expression, factory};
+ auto accExpr = parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState);
+ AccumulationStatement pushStatement{"spaceHog", accExpr};
auto groupByExpression =
ExpressionFieldPath::parse(expCtx, "$_id", expCtx->variablesParseState);
auto group = DocumentSourceGroup::create(
@@ -181,9 +178,8 @@ TEST_F(DocumentSourceGroupTest, ShouldCorrectlyTrackMemoryUsageBetweenPauses) {
auto&& parser = AccumulationStatement::getParser("$push");
auto accumulatorArg = BSON(""
<< "$largeStr");
- auto [expression, factory] =
- parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState);
- AccumulationStatement pushStatement{"spaceHog", expression, factory};
+ auto accExpr = parser(expCtx, accumulatorArg.firstElement(), expCtx->variablesParseState);
+ AccumulationStatement pushStatement{"spaceHog", accExpr};
auto groupByExpression =
ExpressionFieldPath::parse(expCtx, "$_id", expCtx->variablesParseState);
auto group = DocumentSourceGroup::create(
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index b05b4dcd918..4f772ca2363 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -399,15 +399,15 @@ public:
* Used to make Accumulators available as Expressions, e.g., to make $sum available as an Expression
* use "REGISTER_EXPRESSION(sum, ExpressionAccumulator<AccumulatorSum>::parse);".
*/
-template <typename Accumulator>
+template <typename AccumulatorState>
class ExpressionFromAccumulator
- : public ExpressionVariadic<ExpressionFromAccumulator<Accumulator>> {
+ : public ExpressionVariadic<ExpressionFromAccumulator<AccumulatorState>> {
public:
explicit ExpressionFromAccumulator(const boost::intrusive_ptr<ExpressionContext>& expCtx)
- : ExpressionVariadic<ExpressionFromAccumulator<Accumulator>>(expCtx) {}
+ : ExpressionVariadic<ExpressionFromAccumulator<AccumulatorState>>(expCtx) {}
Value evaluate(const Document& root, Variables* variables) const final {
- Accumulator accum(this->getExpressionContext());
+ AccumulatorState accum(this->getExpressionContext());
const auto n = this->_children.size();
// If a single array arg is given, loop through it passing each member to the accumulator.
// If a single, non-array arg is given, pass it directly to the accumulator.
@@ -435,15 +435,15 @@ public:
if (this->_children.size() == 1) {
return false;
}
- return Accumulator(this->getExpressionContext()).isAssociative();
+ return AccumulatorState(this->getExpressionContext()).isAssociative();
}
bool isCommutative() const final {
- return Accumulator(this->getExpressionContext()).isCommutative();
+ return AccumulatorState(this->getExpressionContext()).isCommutative();
}
const char* getOpName() const final {
- return Accumulator(this->getExpressionContext()).getOpName();
+ return AccumulatorState(this->getExpressionContext()).getOpName();
}
void acceptVisitor(ExpressionVisitor* visitor) final {
diff --git a/src/mongo/db/pipeline/expression_visitor.h b/src/mongo/db/pipeline/expression_visitor.h
index 326c8e2e6ce..393778e5f21 100644
--- a/src/mongo/db/pipeline/expression_visitor.h
+++ b/src/mongo/db/pipeline/expression_visitor.h
@@ -158,7 +158,7 @@ class AccumulatorStdDevPop;
class AccumulatorStdDevSamp;
class AccumulatorSum;
class AccumulatorMergeObjects;
-template <typename Accumulator>
+template <typename AccumulatorState>
class ExpressionFromAccumulator;
/**