diff options
Diffstat (limited to 'src/mongo')
-rw-r--r-- | src/mongo/db/pipeline/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator.h | 48 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator_avg.cpp | 2 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator_min_max.cpp | 18 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator_std_dev.cpp | 15 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator_sum.cpp | 2 | ||||
-rw-r--r-- | src/mongo/db/pipeline/accumulator_test.cpp | 16 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 10 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.h | 44 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression_test.cpp | 85 |
10 files changed, 210 insertions, 32 deletions
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index 555aaaeceb2..f849620a5a9 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -87,6 +87,7 @@ env.Library( ], LIBDEPS=[ 'document_value', + 'expression', 'field_path', ] ) @@ -141,6 +142,7 @@ env.CppUnitTest( target='agg_expression_test', source='expression_test.cpp', LIBDEPS=[ + 'accumulator', 'expression', ], ) diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h index 64e8aed036c..77c3ecf63fc 100644 --- a/src/mongo/db/pipeline/accumulator.h +++ b/src/mongo/db/pipeline/accumulator.h @@ -99,6 +99,10 @@ public: */ static Factory getFactory(StringData name); + virtual bool isAssociativeAndCommutative() const { + return false; + } + protected: /// Update subclass's internal state based on input virtual void processInternal(const Value& input, bool merging) = 0; @@ -119,6 +123,10 @@ public: static boost::intrusive_ptr<Accumulator> create(); + bool isAssociativeAndCommutative() const final { + return true; + } + private: typedef std::unordered_set<Value, Value::Hash> SetType; SetType set; @@ -169,6 +177,10 @@ public: static boost::intrusive_ptr<Accumulator> create(); + bool isAssociativeAndCommutative() const final { + return true; + } + private: BSONType totalType; long long longTotal; @@ -176,7 +188,7 @@ private: }; -class AccumulatorMinMax final : public Accumulator { +class AccumulatorMinMax : public Accumulator { public: enum Sense : int { MIN = 1, @@ -190,14 +202,27 @@ public: const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> createMin(); - static boost::intrusive_ptr<Accumulator> createMax(); + bool isAssociativeAndCommutative() const final { + return true; + } private: Value _val; const Sense _sense; }; +class AccumulatorMax final : public AccumulatorMinMax { +public: + AccumulatorMax() : AccumulatorMinMax(MAX) {} + static boost::intrusive_ptr<Accumulator> create(); +}; + +class AccumulatorMin final : public AccumulatorMinMax { +public: + AccumulatorMin() : AccumulatorMinMax(MIN) {} + static boost::intrusive_ptr<Accumulator> create(); +}; + class AccumulatorPush final : public Accumulator { public: @@ -232,7 +257,7 @@ private: }; -class AccumulatorStdDev final : public Accumulator { +class AccumulatorStdDev : public Accumulator { public: explicit AccumulatorStdDev(bool isSamp); @@ -241,13 +266,22 @@ public: const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> createSamp(); - static boost::intrusive_ptr<Accumulator> createPop(); - private: const bool _isSamp; long long _count; double _mean; double _m2; // Running sum of squares of delta from mean. Named to match algorithm. }; + +class AccumulatorStdDevPop final : public AccumulatorStdDev { +public: + AccumulatorStdDevPop() : AccumulatorStdDev(false) {} + static boost::intrusive_ptr<Accumulator> create(); +}; + +class AccumulatorStdDevSamp final : public AccumulatorStdDev { +public: + AccumulatorStdDevSamp() : AccumulatorStdDev(true) {} + static boost::intrusive_ptr<Accumulator> create(); +}; } diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp index 06dd585bdf7..ed11d81ecc0 100644 --- a/src/mongo/db/pipeline/accumulator_avg.cpp +++ b/src/mongo/db/pipeline/accumulator_avg.cpp @@ -30,6 +30,7 @@ #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" @@ -38,6 +39,7 @@ namespace mongo { using boost::intrusive_ptr; REGISTER_ACCUMULATOR(avg, AccumulatorAvg::create); +REGISTER_EXPRESSION(avg, ExpressionFromAccumulator<AccumulatorAvg>::parse); const char* AccumulatorAvg::getOpName() const { return "$avg"; diff --git a/src/mongo/db/pipeline/accumulator_min_max.cpp b/src/mongo/db/pipeline/accumulator_min_max.cpp index 9bfbf7b380b..fc854ca1ddb 100644 --- a/src/mongo/db/pipeline/accumulator_min_max.cpp +++ b/src/mongo/db/pipeline/accumulator_min_max.cpp @@ -29,14 +29,17 @@ #include "mongo/platform/basic.h" #include "mongo/db/pipeline/accumulator.h" +#include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/value.h" namespace mongo { using boost::intrusive_ptr; -REGISTER_ACCUMULATOR(max, AccumulatorMinMax::createMax); -REGISTER_ACCUMULATOR(min, AccumulatorMinMax::createMin); +REGISTER_ACCUMULATOR(max, AccumulatorMax::create); +REGISTER_ACCUMULATOR(min, AccumulatorMin::create); +REGISTER_EXPRESSION(max, ExpressionFromAccumulator<AccumulatorMax>::parse); +REGISTER_EXPRESSION(min, ExpressionFromAccumulator<AccumulatorMin>::parse); const char* AccumulatorMinMax::getOpName() const { if (_sense == 1) @@ -57,6 +60,9 @@ void AccumulatorMinMax::processInternal(const Value& input, bool merging) { } Value AccumulatorMinMax::getValue(bool toBeMerged) const { + if (_val.missing()) { + return Value(BSONNULL); + } return _val; } @@ -69,11 +75,11 @@ void AccumulatorMinMax::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr<Accumulator> AccumulatorMinMax::createMin() { - return new AccumulatorMinMax(Sense::MIN); +intrusive_ptr<Accumulator> AccumulatorMin::create() { + return new AccumulatorMin(); } -intrusive_ptr<Accumulator> AccumulatorMinMax::createMax() { - return new AccumulatorMinMax(Sense::MAX); +intrusive_ptr<Accumulator> AccumulatorMax::create() { + return new AccumulatorMax(); } } diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp index 00922345ed9..6b7e757cac7 100644 --- a/src/mongo/db/pipeline/accumulator_std_dev.cpp +++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp @@ -30,14 +30,17 @@ #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" namespace mongo { using boost::intrusive_ptr; -REGISTER_ACCUMULATOR(stdDevPop, AccumulatorStdDev::createPop); -REGISTER_ACCUMULATOR(stdDevSamp, AccumulatorStdDev::createSamp); +REGISTER_ACCUMULATOR(stdDevPop, AccumulatorStdDevPop::create); +REGISTER_ACCUMULATOR(stdDevSamp, AccumulatorStdDevSamp::create); +REGISTER_EXPRESSION(stdDevPop, ExpressionFromAccumulator<AccumulatorStdDevPop>::parse); +REGISTER_EXPRESSION(stdDevSamp, ExpressionFromAccumulator<AccumulatorStdDevSamp>::parse); const char* AccumulatorStdDev::getOpName() const { return (_isSamp ? "$stdDevSamp" : "$stdDevPop"); @@ -90,12 +93,12 @@ Value AccumulatorStdDev::getValue(bool toBeMerged) const { } } -intrusive_ptr<Accumulator> AccumulatorStdDev::createSamp() { - return new AccumulatorStdDev(true); +intrusive_ptr<Accumulator> AccumulatorStdDevSamp::create() { + return new AccumulatorStdDevSamp(); } -intrusive_ptr<Accumulator> AccumulatorStdDev::createPop() { - return new AccumulatorStdDev(false); +intrusive_ptr<Accumulator> AccumulatorStdDevPop::create() { + return new AccumulatorStdDevPop(); } AccumulatorStdDev::AccumulatorStdDev(bool isSamp) : _isSamp(isSamp), _count(0), _mean(0), _m2(0) { diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp index 5b105753625..c064fe52f04 100644 --- a/src/mongo/db/pipeline/accumulator_sum.cpp +++ b/src/mongo/db/pipeline/accumulator_sum.cpp @@ -29,6 +29,7 @@ #include "mongo/platform/basic.h" #include "mongo/db/pipeline/accumulator.h" +#include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/value.h" namespace mongo { @@ -36,6 +37,7 @@ namespace mongo { using boost::intrusive_ptr; REGISTER_ACCUMULATOR(sum, AccumulatorSum::create); +REGISTER_EXPRESSION(sum, ExpressionFromAccumulator<AccumulatorSum>::parse); const char* AccumulatorSum::getOpName() const { return "$sum"; diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp index 9bef088316d..26188490239 100644 --- a/src/mongo/db/pipeline/accumulator_test.cpp +++ b/src/mongo/db/pipeline/accumulator_test.cpp @@ -469,7 +469,7 @@ namespace Min { class Base : public AccumulatorTests::Base { protected: void createAccumulator() { - _accumulator = AccumulatorMinMax::createMin(); + _accumulator = AccumulatorMin::create(); ASSERT_EQUALS(string("$min"), _accumulator->getOpName()); } Accumulator* accumulator() { @@ -485,8 +485,8 @@ class None : public Base { public: void run() { createAccumulator(); - // The accumulator returns no value in this case. - ASSERT(accumulator()->getValue(false).missing()); + // The accumulator returns null in this case. + ASSERT_EQUALS(Value(BSONNULL), accumulator()->getValue(false)); } }; @@ -506,7 +506,7 @@ public: void run() { createAccumulator(); accumulator()->process(Value(), false); - ASSERT_EQUALS(EOO, accumulator()->getValue(false).getType()); + ASSERT_EQUALS(Value(BSONNULL), accumulator()->getValue(false)); } }; @@ -539,7 +539,7 @@ namespace Max { class Base : public AccumulatorTests::Base { protected: void createAccumulator() { - _accumulator = AccumulatorMinMax::createMax(); + _accumulator = AccumulatorMax::create(); ASSERT_EQUALS(string("$max"), _accumulator->getOpName()); } Accumulator* accumulator() { @@ -555,8 +555,8 @@ class None : public Base { public: void run() { createAccumulator(); - // The accumulator returns no value in this case. - ASSERT(accumulator()->getValue(false).missing()); + // The accumulator returns null in this case. + ASSERT_EQUALS(Value(BSONNULL), accumulator()->getValue(false)); } }; @@ -576,7 +576,7 @@ public: void run() { createAccumulator(); accumulator()->process(Value(), false); - ASSERT_EQUALS(EOO, accumulator()->getValue(false).getType()); + ASSERT_EQUALS(Value(BSONNULL), accumulator()->getValue(false)); } }; diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 495410488b6..a91f475cd1f 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -2184,12 +2184,12 @@ intrusive_ptr<Expression> ExpressionNary::optimize() { if (dynamic_cast<ExpressionConstant*>(expr.get())) { constExprs.push_back(expr); } else { - // If the child operand is the same type as this, then we can - // extract its operands and inline them here because we know - // this is commutative and associative. We detect sameness of - // the child operator by checking for equality of the opNames + // If the child operand is the same type as this and is also associative and + // commutative, then we can extract its operands and inline them here. We detect + // sameness of the child operator by checking for equality of the opNames ExpressionNary* nary = dynamic_cast<ExpressionNary*>(expr.get()); - if (!nary || !str::equals(nary->getOpName(), getOpName())) { + if (!nary || !str::equals(nary->getOpName(), getOpName()) || + !nary->isAssociativeAndCommutative()) { nonConstExprs.push_back(expr); } else { // same expression, so flatten by adding to vpOperand which diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 00dc0dcae02..571e7441240 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -412,6 +412,50 @@ public: } }; +/** + * Used to make Accumulators available as Expressions, e.g., to make $sum available as an Expression + * use "REGISTER_EXPRESSION(sum, ExpressionAccumulator<AccumulatorSum>::parse);". + */ +template <typename Accumulator> +class ExpressionFromAccumulator + : public ExpressionVariadic<ExpressionFromAccumulator<Accumulator>> { +public: + Value evaluateInternal(Variables* vars) const final { + Accumulator accum; + const size_t n = this->vpOperand.size(); + // If a single array arg is given, loop through it passing each member to the accumulator. + // If a single, non-array arg is given, pass it directly to the accumulator. + if (n == 1) { + Value singleVal = this->vpOperand[0]->evaluateInternal(vars); + if (singleVal.getType() == Array) { + for (const Value& val : singleVal.getArray()) { + accum.process(val, false); + } + } else { + accum.process(singleVal, false); + } + } else { + // If multiple arguments are given, pass all arguments to the accumulator. + for (auto&& argument : this->vpOperand) { + accum.process(argument->evaluateInternal(vars), false); + } + } + return accum.getValue(false); + } + + bool isAssociativeAndCommutative() const final { + // Return false if a single argument is given to avoid a single array argument being treated + // as an array instead of as a list of arguments. + if (this->vpOperand.size() == 1) { + return false; + } + return Accumulator().isAssociativeAndCommutative(); + } + + const char* getOpName() const final { + return Accumulator().getOpName(); + } +}; /** * Inherit from this class if your expression takes exactly one numeric argument. diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index de278470c06..cfe9361ecc5 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -28,6 +28,7 @@ #include "mongo/platform/basic.h" +#include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/expression.h" #include "mongo/dbtests/dbtests.h" @@ -41,6 +42,24 @@ using std::set; using std::string; using std::vector; +/** + * Takes the name of an expression as its first argument and a list of pairs of arguments and + * expected results as its second argument, and asserts that for the given expression the arguments + * evaluate to the expected results. + */ +static void assertExpectedResults( + std::string expression, + std::initializer_list<std::pair<std::vector<Value>, Value>> operations) { + for (auto&& op : operations) { + VariablesIdGenerator idGenerator; + VariablesParseState vps(&idGenerator); + const BSONObj obj = BSON(expression << Value(op.first)); + Value result = Expression::parseExpression(obj.firstElement(), vps)->evaluate(Document()); + ASSERT_EQUALS(op.second, result); + ASSERT_EQUALS(op.second.getType(), result.getType()); + } +} + /** Convert BSONObj to a BSONObj with our $const wrappings. */ static BSONObj constify(const BSONObj& obj, bool parentIsArray = false) { BSONObjBuilder bob; @@ -1388,6 +1407,72 @@ private: } // namespace Constant +TEST(ExpressionFromAccumulators, Avg) { + assertExpectedResults("$avg", + {// $avg ignores non-numeric inputs. + {{Value("string"), Value(BSONNULL), Value(), Value(3)}, Value(3.0)}, + // $avg always returns a double. + {{Value(10LL), Value(20LL)}, Value(15.0)}, + // $avg returns null when no arguments are provided. + {{}, Value(BSONNULL)}}); +} + +TEST(ExpressionFromAccumulators, Max) { + assertExpectedResults("$max", + {// $max treats non-numeric inputs as valid arguments. + {{Value(1), Value(BSONNULL), Value(), Value("a")}, Value("a")}, + {{Value("a"), Value("b")}, Value("b")}, + // $max always preserves the type of the result. + {{Value(10LL), Value(0.0), Value(5)}, Value(10LL)}, + // $max returns null when no arguments are provided. + {{}, Value(BSONNULL)}}); +} + +TEST(ExpressionFromAccumulators, Min) { + assertExpectedResults("$min", + {// $min treats non-numeric inputs as valid arguments. + {{Value("string")}, Value("string")}, + {{Value(1), Value(BSONNULL), Value(), Value("a")}, Value(1)}, + {{Value("a"), Value("b")}, Value("a")}, + // $min always preserves the type of the result. + {{Value(0LL), Value(20.0), Value(10)}, Value(0LL)}, + // $min returns null when no arguments are provided. + {{}, Value(BSONNULL)}}); +} + +TEST(ExpressionFromAccumulators, Sum) { + assertExpectedResults( + "$sum", + {// $sum ignores non-numeric inputs. + {{Value("string"), Value(BSONNULL), Value(), Value(3)}, Value(3)}, + // If any argument is a double, $sum returns a double + {{Value(10LL), Value(10.0)}, Value(20.0)}, + // If no arguments are doubles and an argument is a long, $sum returns a long + {{Value(10LL), Value(10)}, Value(20LL)}, + // $sum returns 0 when no arguments are provided. + {{}, Value(0)}}); +} + +TEST(ExpressionFromAccumulators, StdDevPop) { + assertExpectedResults("$stdDevPop", + {// $stdDevPop ignores non-numeric inputs. + {{Value("string"), Value(BSONNULL), Value(), Value(3)}, Value(0.0)}, + // $stdDevPop always returns a double. + {{Value(1LL), Value(3LL)}, Value(1.0)}, + // $stdDevPop returns null when no arguments are provided. + {{}, Value(BSONNULL)}}); +} + +TEST(ExpressionFromAccumulators, StdDevSamp) { + assertExpectedResults("$stdDevSamp", + {// $stdDevSamp ignores non-numeric inputs. + {{Value("string"), Value(BSONNULL), Value(), Value(3)}, Value(BSONNULL)}, + // $stdDevSamp always returns a double. + {{Value(1LL), Value(2LL), Value(3LL)}, Value(1.0)}, + // $stdDevSamp returns null when no arguments are provided. + {{}, Value(BSONNULL)}}); +} + namespace FieldPath { /** The provided field path does not pass validation. */ |