summaryrefslogtreecommitdiff
path: root/src/mongo
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo')
-rw-r--r--src/mongo/db/pipeline/SConscript2
-rw-r--r--src/mongo/db/pipeline/accumulator.h48
-rw-r--r--src/mongo/db/pipeline/accumulator_avg.cpp2
-rw-r--r--src/mongo/db/pipeline/accumulator_min_max.cpp18
-rw-r--r--src/mongo/db/pipeline/accumulator_std_dev.cpp15
-rw-r--r--src/mongo/db/pipeline/accumulator_sum.cpp2
-rw-r--r--src/mongo/db/pipeline/accumulator_test.cpp16
-rw-r--r--src/mongo/db/pipeline/expression.cpp10
-rw-r--r--src/mongo/db/pipeline/expression.h44
-rw-r--r--src/mongo/db/pipeline/expression_test.cpp85
10 files changed, 210 insertions, 32 deletions
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index 555aaaeceb2..f849620a5a9 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -87,6 +87,7 @@ env.Library(
],
LIBDEPS=[
'document_value',
+ 'expression',
'field_path',
]
)
@@ -141,6 +142,7 @@ env.CppUnitTest(
target='agg_expression_test',
source='expression_test.cpp',
LIBDEPS=[
+ 'accumulator',
'expression',
],
)
diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h
index 64e8aed036c..77c3ecf63fc 100644
--- a/src/mongo/db/pipeline/accumulator.h
+++ b/src/mongo/db/pipeline/accumulator.h
@@ -99,6 +99,10 @@ public:
*/
static Factory getFactory(StringData name);
+ virtual bool isAssociativeAndCommutative() const {
+ return false;
+ }
+
protected:
/// Update subclass's internal state based on input
virtual void processInternal(const Value& input, bool merging) = 0;
@@ -119,6 +123,10 @@ public:
static boost::intrusive_ptr<Accumulator> create();
+ bool isAssociativeAndCommutative() const final {
+ return true;
+ }
+
private:
typedef std::unordered_set<Value, Value::Hash> SetType;
SetType set;
@@ -169,6 +177,10 @@ public:
static boost::intrusive_ptr<Accumulator> create();
+ bool isAssociativeAndCommutative() const final {
+ return true;
+ }
+
private:
BSONType totalType;
long long longTotal;
@@ -176,7 +188,7 @@ private:
};
-class AccumulatorMinMax final : public Accumulator {
+class AccumulatorMinMax : public Accumulator {
public:
enum Sense : int {
MIN = 1,
@@ -190,14 +202,27 @@ public:
const char* getOpName() const final;
void reset() final;
- static boost::intrusive_ptr<Accumulator> createMin();
- static boost::intrusive_ptr<Accumulator> createMax();
+ bool isAssociativeAndCommutative() const final {
+ return true;
+ }
private:
Value _val;
const Sense _sense;
};
+class AccumulatorMax final : public AccumulatorMinMax {
+public:
+ AccumulatorMax() : AccumulatorMinMax(MAX) {}
+ static boost::intrusive_ptr<Accumulator> create();
+};
+
+class AccumulatorMin final : public AccumulatorMinMax {
+public:
+ AccumulatorMin() : AccumulatorMinMax(MIN) {}
+ static boost::intrusive_ptr<Accumulator> create();
+};
+
class AccumulatorPush final : public Accumulator {
public:
@@ -232,7 +257,7 @@ private:
};
-class AccumulatorStdDev final : public Accumulator {
+class AccumulatorStdDev : public Accumulator {
public:
explicit AccumulatorStdDev(bool isSamp);
@@ -241,13 +266,22 @@ public:
const char* getOpName() const final;
void reset() final;
- static boost::intrusive_ptr<Accumulator> createSamp();
- static boost::intrusive_ptr<Accumulator> createPop();
-
private:
const bool _isSamp;
long long _count;
double _mean;
double _m2; // Running sum of squares of delta from mean. Named to match algorithm.
};
+
+class AccumulatorStdDevPop final : public AccumulatorStdDev {
+public:
+ AccumulatorStdDevPop() : AccumulatorStdDev(false) {}
+ static boost::intrusive_ptr<Accumulator> create();
+};
+
+class AccumulatorStdDevSamp final : public AccumulatorStdDev {
+public:
+ AccumulatorStdDevSamp() : AccumulatorStdDev(true) {}
+ static boost::intrusive_ptr<Accumulator> create();
+};
}
diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp
index 06dd585bdf7..ed11d81ecc0 100644
--- a/src/mongo/db/pipeline/accumulator_avg.cpp
+++ b/src/mongo/db/pipeline/accumulator_avg.cpp
@@ -30,6 +30,7 @@
#include "mongo/db/pipeline/accumulator.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/expression.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/value.h"
@@ -38,6 +39,7 @@ namespace mongo {
using boost::intrusive_ptr;
REGISTER_ACCUMULATOR(avg, AccumulatorAvg::create);
+REGISTER_EXPRESSION(avg, ExpressionFromAccumulator<AccumulatorAvg>::parse);
const char* AccumulatorAvg::getOpName() const {
return "$avg";
diff --git a/src/mongo/db/pipeline/accumulator_min_max.cpp b/src/mongo/db/pipeline/accumulator_min_max.cpp
index 9bfbf7b380b..fc854ca1ddb 100644
--- a/src/mongo/db/pipeline/accumulator_min_max.cpp
+++ b/src/mongo/db/pipeline/accumulator_min_max.cpp
@@ -29,14 +29,17 @@
#include "mongo/platform/basic.h"
#include "mongo/db/pipeline/accumulator.h"
+#include "mongo/db/pipeline/expression.h"
#include "mongo/db/pipeline/value.h"
namespace mongo {
using boost::intrusive_ptr;
-REGISTER_ACCUMULATOR(max, AccumulatorMinMax::createMax);
-REGISTER_ACCUMULATOR(min, AccumulatorMinMax::createMin);
+REGISTER_ACCUMULATOR(max, AccumulatorMax::create);
+REGISTER_ACCUMULATOR(min, AccumulatorMin::create);
+REGISTER_EXPRESSION(max, ExpressionFromAccumulator<AccumulatorMax>::parse);
+REGISTER_EXPRESSION(min, ExpressionFromAccumulator<AccumulatorMin>::parse);
const char* AccumulatorMinMax::getOpName() const {
if (_sense == 1)
@@ -57,6 +60,9 @@ void AccumulatorMinMax::processInternal(const Value& input, bool merging) {
}
Value AccumulatorMinMax::getValue(bool toBeMerged) const {
+ if (_val.missing()) {
+ return Value(BSONNULL);
+ }
return _val;
}
@@ -69,11 +75,11 @@ void AccumulatorMinMax::reset() {
_memUsageBytes = sizeof(*this);
}
-intrusive_ptr<Accumulator> AccumulatorMinMax::createMin() {
- return new AccumulatorMinMax(Sense::MIN);
+intrusive_ptr<Accumulator> AccumulatorMin::create() {
+ return new AccumulatorMin();
}
-intrusive_ptr<Accumulator> AccumulatorMinMax::createMax() {
- return new AccumulatorMinMax(Sense::MAX);
+intrusive_ptr<Accumulator> AccumulatorMax::create() {
+ return new AccumulatorMax();
}
}
diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp
index 00922345ed9..6b7e757cac7 100644
--- a/src/mongo/db/pipeline/accumulator_std_dev.cpp
+++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp
@@ -30,14 +30,17 @@
#include "mongo/db/pipeline/accumulator.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/expression.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/value.h"
namespace mongo {
using boost::intrusive_ptr;
-REGISTER_ACCUMULATOR(stdDevPop, AccumulatorStdDev::createPop);
-REGISTER_ACCUMULATOR(stdDevSamp, AccumulatorStdDev::createSamp);
+REGISTER_ACCUMULATOR(stdDevPop, AccumulatorStdDevPop::create);
+REGISTER_ACCUMULATOR(stdDevSamp, AccumulatorStdDevSamp::create);
+REGISTER_EXPRESSION(stdDevPop, ExpressionFromAccumulator<AccumulatorStdDevPop>::parse);
+REGISTER_EXPRESSION(stdDevSamp, ExpressionFromAccumulator<AccumulatorStdDevSamp>::parse);
const char* AccumulatorStdDev::getOpName() const {
return (_isSamp ? "$stdDevSamp" : "$stdDevPop");
@@ -90,12 +93,12 @@ Value AccumulatorStdDev::getValue(bool toBeMerged) const {
}
}
-intrusive_ptr<Accumulator> AccumulatorStdDev::createSamp() {
- return new AccumulatorStdDev(true);
+intrusive_ptr<Accumulator> AccumulatorStdDevSamp::create() {
+ return new AccumulatorStdDevSamp();
}
-intrusive_ptr<Accumulator> AccumulatorStdDev::createPop() {
- return new AccumulatorStdDev(false);
+intrusive_ptr<Accumulator> AccumulatorStdDevPop::create() {
+ return new AccumulatorStdDevPop();
}
AccumulatorStdDev::AccumulatorStdDev(bool isSamp) : _isSamp(isSamp), _count(0), _mean(0), _m2(0) {
diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp
index 5b105753625..c064fe52f04 100644
--- a/src/mongo/db/pipeline/accumulator_sum.cpp
+++ b/src/mongo/db/pipeline/accumulator_sum.cpp
@@ -29,6 +29,7 @@
#include "mongo/platform/basic.h"
#include "mongo/db/pipeline/accumulator.h"
+#include "mongo/db/pipeline/expression.h"
#include "mongo/db/pipeline/value.h"
namespace mongo {
@@ -36,6 +37,7 @@ namespace mongo {
using boost::intrusive_ptr;
REGISTER_ACCUMULATOR(sum, AccumulatorSum::create);
+REGISTER_EXPRESSION(sum, ExpressionFromAccumulator<AccumulatorSum>::parse);
const char* AccumulatorSum::getOpName() const {
return "$sum";
diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp
index 9bef088316d..26188490239 100644
--- a/src/mongo/db/pipeline/accumulator_test.cpp
+++ b/src/mongo/db/pipeline/accumulator_test.cpp
@@ -469,7 +469,7 @@ namespace Min {
class Base : public AccumulatorTests::Base {
protected:
void createAccumulator() {
- _accumulator = AccumulatorMinMax::createMin();
+ _accumulator = AccumulatorMin::create();
ASSERT_EQUALS(string("$min"), _accumulator->getOpName());
}
Accumulator* accumulator() {
@@ -485,8 +485,8 @@ class None : public Base {
public:
void run() {
createAccumulator();
- // The accumulator returns no value in this case.
- ASSERT(accumulator()->getValue(false).missing());
+ // The accumulator returns null in this case.
+ ASSERT_EQUALS(Value(BSONNULL), accumulator()->getValue(false));
}
};
@@ -506,7 +506,7 @@ public:
void run() {
createAccumulator();
accumulator()->process(Value(), false);
- ASSERT_EQUALS(EOO, accumulator()->getValue(false).getType());
+ ASSERT_EQUALS(Value(BSONNULL), accumulator()->getValue(false));
}
};
@@ -539,7 +539,7 @@ namespace Max {
class Base : public AccumulatorTests::Base {
protected:
void createAccumulator() {
- _accumulator = AccumulatorMinMax::createMax();
+ _accumulator = AccumulatorMax::create();
ASSERT_EQUALS(string("$max"), _accumulator->getOpName());
}
Accumulator* accumulator() {
@@ -555,8 +555,8 @@ class None : public Base {
public:
void run() {
createAccumulator();
- // The accumulator returns no value in this case.
- ASSERT(accumulator()->getValue(false).missing());
+ // The accumulator returns null in this case.
+ ASSERT_EQUALS(Value(BSONNULL), accumulator()->getValue(false));
}
};
@@ -576,7 +576,7 @@ public:
void run() {
createAccumulator();
accumulator()->process(Value(), false);
- ASSERT_EQUALS(EOO, accumulator()->getValue(false).getType());
+ ASSERT_EQUALS(Value(BSONNULL), accumulator()->getValue(false));
}
};
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 495410488b6..a91f475cd1f 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -2184,12 +2184,12 @@ intrusive_ptr<Expression> ExpressionNary::optimize() {
if (dynamic_cast<ExpressionConstant*>(expr.get())) {
constExprs.push_back(expr);
} else {
- // If the child operand is the same type as this, then we can
- // extract its operands and inline them here because we know
- // this is commutative and associative. We detect sameness of
- // the child operator by checking for equality of the opNames
+ // If the child operand is the same type as this and is also associative and
+ // commutative, then we can extract its operands and inline them here. We detect
+ // sameness of the child operator by checking for equality of the opNames
ExpressionNary* nary = dynamic_cast<ExpressionNary*>(expr.get());
- if (!nary || !str::equals(nary->getOpName(), getOpName())) {
+ if (!nary || !str::equals(nary->getOpName(), getOpName()) ||
+ !nary->isAssociativeAndCommutative()) {
nonConstExprs.push_back(expr);
} else {
// same expression, so flatten by adding to vpOperand which
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index 00dc0dcae02..571e7441240 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -412,6 +412,50 @@ public:
}
};
+/**
+ * Used to make Accumulators available as Expressions, e.g., to make $sum available as an Expression
+ * use "REGISTER_EXPRESSION(sum, ExpressionAccumulator<AccumulatorSum>::parse);".
+ */
+template <typename Accumulator>
+class ExpressionFromAccumulator
+ : public ExpressionVariadic<ExpressionFromAccumulator<Accumulator>> {
+public:
+ Value evaluateInternal(Variables* vars) const final {
+ Accumulator accum;
+ const size_t n = this->vpOperand.size();
+ // If a single array arg is given, loop through it passing each member to the accumulator.
+ // If a single, non-array arg is given, pass it directly to the accumulator.
+ if (n == 1) {
+ Value singleVal = this->vpOperand[0]->evaluateInternal(vars);
+ if (singleVal.getType() == Array) {
+ for (const Value& val : singleVal.getArray()) {
+ accum.process(val, false);
+ }
+ } else {
+ accum.process(singleVal, false);
+ }
+ } else {
+ // If multiple arguments are given, pass all arguments to the accumulator.
+ for (auto&& argument : this->vpOperand) {
+ accum.process(argument->evaluateInternal(vars), false);
+ }
+ }
+ return accum.getValue(false);
+ }
+
+ bool isAssociativeAndCommutative() const final {
+ // Return false if a single argument is given to avoid a single array argument being treated
+ // as an array instead of as a list of arguments.
+ if (this->vpOperand.size() == 1) {
+ return false;
+ }
+ return Accumulator().isAssociativeAndCommutative();
+ }
+
+ const char* getOpName() const final {
+ return Accumulator().getOpName();
+ }
+};
/**
* Inherit from this class if your expression takes exactly one numeric argument.
diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp
index de278470c06..cfe9361ecc5 100644
--- a/src/mongo/db/pipeline/expression_test.cpp
+++ b/src/mongo/db/pipeline/expression_test.cpp
@@ -28,6 +28,7 @@
#include "mongo/platform/basic.h"
+#include "mongo/db/pipeline/accumulator.h"
#include "mongo/db/pipeline/document.h"
#include "mongo/db/pipeline/expression.h"
#include "mongo/dbtests/dbtests.h"
@@ -41,6 +42,24 @@ using std::set;
using std::string;
using std::vector;
+/**
+ * Takes the name of an expression as its first argument and a list of pairs of arguments and
+ * expected results as its second argument, and asserts that for the given expression the arguments
+ * evaluate to the expected results.
+ */
+static void assertExpectedResults(
+ std::string expression,
+ std::initializer_list<std::pair<std::vector<Value>, Value>> operations) {
+ for (auto&& op : operations) {
+ VariablesIdGenerator idGenerator;
+ VariablesParseState vps(&idGenerator);
+ const BSONObj obj = BSON(expression << Value(op.first));
+ Value result = Expression::parseExpression(obj.firstElement(), vps)->evaluate(Document());
+ ASSERT_EQUALS(op.second, result);
+ ASSERT_EQUALS(op.second.getType(), result.getType());
+ }
+}
+
/** Convert BSONObj to a BSONObj with our $const wrappings. */
static BSONObj constify(const BSONObj& obj, bool parentIsArray = false) {
BSONObjBuilder bob;
@@ -1388,6 +1407,72 @@ private:
} // namespace Constant
+TEST(ExpressionFromAccumulators, Avg) {
+ assertExpectedResults("$avg",
+ {// $avg ignores non-numeric inputs.
+ {{Value("string"), Value(BSONNULL), Value(), Value(3)}, Value(3.0)},
+ // $avg always returns a double.
+ {{Value(10LL), Value(20LL)}, Value(15.0)},
+ // $avg returns null when no arguments are provided.
+ {{}, Value(BSONNULL)}});
+}
+
+TEST(ExpressionFromAccumulators, Max) {
+ assertExpectedResults("$max",
+ {// $max treats non-numeric inputs as valid arguments.
+ {{Value(1), Value(BSONNULL), Value(), Value("a")}, Value("a")},
+ {{Value("a"), Value("b")}, Value("b")},
+ // $max always preserves the type of the result.
+ {{Value(10LL), Value(0.0), Value(5)}, Value(10LL)},
+ // $max returns null when no arguments are provided.
+ {{}, Value(BSONNULL)}});
+}
+
+TEST(ExpressionFromAccumulators, Min) {
+ assertExpectedResults("$min",
+ {// $min treats non-numeric inputs as valid arguments.
+ {{Value("string")}, Value("string")},
+ {{Value(1), Value(BSONNULL), Value(), Value("a")}, Value(1)},
+ {{Value("a"), Value("b")}, Value("a")},
+ // $min always preserves the type of the result.
+ {{Value(0LL), Value(20.0), Value(10)}, Value(0LL)},
+ // $min returns null when no arguments are provided.
+ {{}, Value(BSONNULL)}});
+}
+
+TEST(ExpressionFromAccumulators, Sum) {
+ assertExpectedResults(
+ "$sum",
+ {// $sum ignores non-numeric inputs.
+ {{Value("string"), Value(BSONNULL), Value(), Value(3)}, Value(3)},
+ // If any argument is a double, $sum returns a double
+ {{Value(10LL), Value(10.0)}, Value(20.0)},
+ // If no arguments are doubles and an argument is a long, $sum returns a long
+ {{Value(10LL), Value(10)}, Value(20LL)},
+ // $sum returns 0 when no arguments are provided.
+ {{}, Value(0)}});
+}
+
+TEST(ExpressionFromAccumulators, StdDevPop) {
+ assertExpectedResults("$stdDevPop",
+ {// $stdDevPop ignores non-numeric inputs.
+ {{Value("string"), Value(BSONNULL), Value(), Value(3)}, Value(0.0)},
+ // $stdDevPop always returns a double.
+ {{Value(1LL), Value(3LL)}, Value(1.0)},
+ // $stdDevPop returns null when no arguments are provided.
+ {{}, Value(BSONNULL)}});
+}
+
+TEST(ExpressionFromAccumulators, StdDevSamp) {
+ assertExpectedResults("$stdDevSamp",
+ {// $stdDevSamp ignores non-numeric inputs.
+ {{Value("string"), Value(BSONNULL), Value(), Value(3)}, Value(BSONNULL)},
+ // $stdDevSamp always returns a double.
+ {{Value(1LL), Value(2LL), Value(3LL)}, Value(1.0)},
+ // $stdDevSamp returns null when no arguments are provided.
+ {{}, Value(BSONNULL)}});
+}
+
namespace FieldPath {
/** The provided field path does not pass validation. */