diff options
Diffstat (limited to 'src/mongo/db')
95 files changed, 1888 insertions, 1392 deletions
diff --git a/src/mongo/db/commands/pipeline_command.cpp b/src/mongo/db/commands/pipeline_command.cpp index 7fe8d54992d..bb426950b44 100644 --- a/src/mongo/db/commands/pipeline_command.cpp +++ b/src/mongo/db/commands/pipeline_command.cpp @@ -258,7 +258,6 @@ boost::intrusive_ptr<Pipeline> reparsePipeline( fassertFailedWithStatusNoTrace(40175, reparsedPipeline.getStatus()); } - reparsedPipeline.getValue()->injectExpressionContext(expCtx); reparsedPipeline.getValue()->optimizePipeline(); return reparsedPipeline.getValue(); } @@ -350,18 +349,15 @@ public: // For operations on views, this will be the underlying namespace. const NamespaceString& nss = request.getNamespaceString(); - // Set up the ExpressionContext. - intrusive_ptr<ExpressionContext> expCtx = new ExpressionContext(txn, request); - expCtx->tempDir = storageGlobalParams.dbpath + "/_tmp"; - - auto resolvedNamespaces = resolveInvolvedNamespaces(txn, request); - if (!resolvedNamespaces.isOK()) { - return appendCommandStatus(result, resolvedNamespaces.getStatus()); - } - expCtx->resolvedNamespaces = std::move(resolvedNamespaces.getValue()); + // Parse the user-specified collation, if any. + std::unique_ptr<CollatorInterface> userSpecifiedCollator = request.getCollation().isEmpty() + ? nullptr + : uassertStatusOK(CollatorFactoryInterface::get(txn->getServiceContext()) + ->makeFromBSON(request.getCollation())); boost::optional<ClientCursorPin> pin; // either this OR the exec will be non-null unique_ptr<PlanExecutor> exec; + boost::intrusive_ptr<ExpressionContext> expCtx; boost::intrusive_ptr<Pipeline> pipeline; auto curOp = CurOp::get(txn); { @@ -387,7 +383,7 @@ public: // means that no collation was specified. if (!request.getCollation().isEmpty()) { if (!CollatorInterface::collatorsMatch(ctx.getView()->defaultCollator(), - expCtx->getCollator())) { + userSpecifiedCollator.get())) { return appendCommandStatus(result, {ErrorCodes::OptionNotSupportedOnView, "Cannot override a view's default collation"}); @@ -440,14 +436,26 @@ public: return status; } + // Determine the appropriate collation to make the ExpressionContext. + // If the pipeline does not have a user-specified collation, set it from the collection - // default. + // default. Be careful to consult the original request BSON to check if a collation was + // specified, since a specification of {locale: "simple"} will result in a null + // collator. + auto collatorToUse = std::move(userSpecifiedCollator); if (request.getCollation().isEmpty() && collection && collection->getDefaultCollator()) { - invariant(!expCtx->getCollator()); - expCtx->setCollator(collection->getDefaultCollator()->clone()); + invariant(!collatorToUse); + collatorToUse = collection->getDefaultCollator()->clone(); } + expCtx.reset( + new ExpressionContext(txn, + request, + std::move(collatorToUse), + uassertStatusOK(resolveInvolvedNamespaces(txn, request)))); + expCtx->tempDir = storageGlobalParams.dbpath + "/_tmp"; + // Parse the pipeline. auto statusWithPipeline = Pipeline::parse(request.getPipeline(), expCtx); if (!statusWithPipeline.isOK()) { @@ -463,14 +471,6 @@ public: return appendCommandStatus(result, pipelineCollationStatus); } - // Propagate the ExpressionContext throughout all of the pipeline's stages and - // expressions. - pipeline->injectExpressionContext(expCtx); - - // The pipeline must be optimized after the correct collator has been set on it (by - // injecting the ExpressionContext containing the collator). This is necessary because - // optimization may make string comparisons, e.g. optimizing {$eq: [<str1>, <str2>]} to - // a constant. pipeline->optimizePipeline(); if (kDebugBuild && !expCtx->isExplain && !expCtx->inShard) { diff --git a/src/mongo/db/pipeline/accumulation_statement.cpp b/src/mongo/db/pipeline/accumulation_statement.cpp index b191a78fdf1..9ac394b0018 100644 --- a/src/mongo/db/pipeline/accumulation_statement.cpp +++ b/src/mongo/db/pipeline/accumulation_statement.cpp @@ -64,7 +64,9 @@ Accumulator::Factory AccumulationStatement::getFactory(StringData name) { } AccumulationStatement AccumulationStatement::parseAccumulationStatement( - const BSONElement& elem, const VariablesParseState& vps) { + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONElement& elem, + const VariablesParseState& vps) { auto fieldName = elem.fieldNameStringData(); uassert(40234, str::stream() << "The field '" << fieldName << "' must be an accumulator object", @@ -91,7 +93,7 @@ AccumulationStatement AccumulationStatement::parseAccumulationStatement( return {fieldName.toString(), AccumulationStatement::getFactory(accName), - Expression::parseOperand(specElem, vps)}; + Expression::parseOperand(expCtx, specElem, vps)}; } } // namespace mongo diff --git a/src/mongo/db/pipeline/accumulation_statement.h b/src/mongo/db/pipeline/accumulation_statement.h index 5a7c578ff65..92f7906ed9e 100644 --- a/src/mongo/db/pipeline/accumulation_statement.h +++ b/src/mongo/db/pipeline/accumulation_statement.h @@ -72,8 +72,10 @@ public: * * Throws a UserException if parsing fails. */ - static AccumulationStatement parseAccumulationStatement(const BSONElement& elem, - const VariablesParseState& vps); + static AccumulationStatement parseAccumulationStatement( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONElement& elem, + const VariablesParseState& vps); /** * Registers an Accumulator with a parsing function, so that when an accumulator with the given diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h index e0a2dbdb11a..28645e90a9b 100644 --- a/src/mongo/db/pipeline/accumulator.h +++ b/src/mongo/db/pipeline/accumulator.h @@ -48,9 +48,10 @@ namespace mongo { class Accumulator : public RefCountable { public: - using Factory = boost::intrusive_ptr<Accumulator> (*)(); + using Factory = boost::intrusive_ptr<Accumulator> (*)( + const boost::intrusive_ptr<ExpressionContext>& expCtx); - Accumulator() = default; + Accumulator(const boost::intrusive_ptr<ExpressionContext>& expCtx) : _expCtx(expCtx) {} /** Process input and update internal state. * merging should be true when processing outputs from getValue(true). @@ -84,26 +85,10 @@ public: return false; } - /** - * Injects the ExpressionContext so that it may be used during evaluation of the Accumulator. - * Construction of accumulators is done at parse time, but the ExpressionContext isn't finalized - * until later, at which point it is injected using this method. - */ - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { - _expCtx = expCtx; - doInjectExpressionContext(); - } - protected: /// Update subclass's internal state based on input virtual void processInternal(const Value& input, bool merging) = 0; - /** - * Accumulators which need to update their internal state when attaching to a new - * ExpressionContext should override this method. - */ - virtual void doInjectExpressionContext() {} - const boost::intrusive_ptr<ExpressionContext>& getExpressionContext() const { return _expCtx; } @@ -118,14 +103,15 @@ private: class AccumulatorAddToSet final : public Accumulator { public: - AccumulatorAddToSet(); + explicit AccumulatorAddToSet(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); bool isAssociative() const final { return true; @@ -135,26 +121,22 @@ public: return true; } - void doInjectExpressionContext() final; - private: - // We use boost::optional to defer initialization until the ExpressionContext containing the - // correct comparator is injected, since this set must use the comparator's definition of - // equality. - boost::optional<ValueUnorderedSet> _set; + ValueUnorderedSet _set; }; class AccumulatorFirst final : public Accumulator { public: - AccumulatorFirst(); + explicit AccumulatorFirst(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); private: bool _haveFirst; @@ -164,14 +146,15 @@ private: class AccumulatorLast final : public Accumulator { public: - AccumulatorLast(); + explicit AccumulatorLast(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); private: Value _last; @@ -180,14 +163,15 @@ private: class AccumulatorSum final : public Accumulator { public: - AccumulatorSum(); + explicit AccumulatorSum(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); bool isAssociative() const final { return true; @@ -211,7 +195,7 @@ public: MAX = -1, // Used to "scale" comparison. }; - explicit AccumulatorMinMax(Sense sense); + AccumulatorMinMax(const boost::intrusive_ptr<ExpressionContext>& expCtx, Sense sense); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; @@ -233,27 +217,32 @@ private: class AccumulatorMax final : public AccumulatorMinMax { public: - AccumulatorMax() : AccumulatorMinMax(MAX) {} - static boost::intrusive_ptr<Accumulator> create(); + explicit AccumulatorMax(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : AccumulatorMinMax(expCtx, MAX) {} + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); }; class AccumulatorMin final : public AccumulatorMinMax { public: - AccumulatorMin() : AccumulatorMinMax(MIN) {} - static boost::intrusive_ptr<Accumulator> create(); + explicit AccumulatorMin(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : AccumulatorMinMax(expCtx, MIN) {} + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); }; class AccumulatorPush final : public Accumulator { public: - AccumulatorPush(); + explicit AccumulatorPush(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); private: std::vector<Value> vpValue; @@ -262,14 +251,15 @@ private: class AccumulatorAvg final : public Accumulator { public: - AccumulatorAvg(); + explicit AccumulatorAvg(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); private: /** @@ -287,7 +277,7 @@ private: class AccumulatorStdDev : public Accumulator { public: - explicit AccumulatorStdDev(bool isSamp); + AccumulatorStdDev(const boost::intrusive_ptr<ExpressionContext>& expCtx, bool isSamp); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; @@ -303,13 +293,17 @@ private: class AccumulatorStdDevPop final : public AccumulatorStdDev { public: - AccumulatorStdDevPop() : AccumulatorStdDev(false) {} - static boost::intrusive_ptr<Accumulator> create(); + explicit AccumulatorStdDevPop(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : AccumulatorStdDev(expCtx, false) {} + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); }; class AccumulatorStdDevSamp final : public AccumulatorStdDev { public: - AccumulatorStdDevSamp() : AccumulatorStdDev(true) {} - static boost::intrusive_ptr<Accumulator> create(); + explicit AccumulatorStdDevSamp(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : AccumulatorStdDev(expCtx, true) {} + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); }; } diff --git a/src/mongo/db/pipeline/accumulator_add_to_set.cpp b/src/mongo/db/pipeline/accumulator_add_to_set.cpp index 6a54e846a85..00845b32fc8 100644 --- a/src/mongo/db/pipeline/accumulator_add_to_set.cpp +++ b/src/mongo/db/pipeline/accumulator_add_to_set.cpp @@ -48,7 +48,7 @@ const char* AccumulatorAddToSet::getOpName() const { void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { if (!merging) { if (!input.missing()) { - bool inserted = _set->insert(input).second; + bool inserted = _set.insert(input).second; if (inserted) { _memUsageBytes += input.getApproximateSize(); } @@ -62,7 +62,7 @@ void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { const vector<Value>& array = input.getArray(); for (size_t i = 0; i < array.size(); i++) { - bool inserted = _set->insert(array[i]).second; + bool inserted = _set.insert(array[i]).second; if (inserted) { _memUsageBytes += array[i].getApproximateSize(); } @@ -71,10 +71,11 @@ void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { } Value AccumulatorAddToSet::getValue(bool toBeMerged) const { - return Value(vector<Value>(_set->begin(), _set->end())); + return Value(vector<Value>(_set.begin(), _set.end())); } -AccumulatorAddToSet::AccumulatorAddToSet() { +AccumulatorAddToSet::AccumulatorAddToSet(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx), _set(expCtx->getValueComparator().makeUnorderedValueSet()) { _memUsageBytes = sizeof(*this); } @@ -83,12 +84,9 @@ void AccumulatorAddToSet::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr<Accumulator> AccumulatorAddToSet::create() { - return new AccumulatorAddToSet(); -} - -void AccumulatorAddToSet::doInjectExpressionContext() { - _set = getExpressionContext()->getValueComparator().makeUnorderedValueSet(); +intrusive_ptr<Accumulator> AccumulatorAddToSet::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorAddToSet(expCtx); } } // namespace mongo diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp index c977ba99c15..da0c08fce9c 100644 --- a/src/mongo/db/pipeline/accumulator_avg.cpp +++ b/src/mongo/db/pipeline/accumulator_avg.cpp @@ -92,8 +92,9 @@ void AccumulatorAvg::processInternal(const Value& input, bool merging) { _count++; } -intrusive_ptr<Accumulator> AccumulatorAvg::create() { - return new AccumulatorAvg(); +intrusive_ptr<Accumulator> AccumulatorAvg::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorAvg(expCtx); } Decimal128 AccumulatorAvg::_getDecimalTotal() const { @@ -120,7 +121,8 @@ Value AccumulatorAvg::getValue(bool toBeMerged) const { return Value(_nonDecimalTotal.getDouble() / static_cast<double>(_count)); } -AccumulatorAvg::AccumulatorAvg() : _isDecimal(false), _count(0) { +AccumulatorAvg::AccumulatorAvg(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx), _isDecimal(false), _count(0) { // This is a fixed size Accumulator so we never need to update this _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_first.cpp b/src/mongo/db/pipeline/accumulator_first.cpp index 68ea563ff42..3092bd0e9da 100644 --- a/src/mongo/db/pipeline/accumulator_first.cpp +++ b/src/mongo/db/pipeline/accumulator_first.cpp @@ -57,7 +57,8 @@ Value AccumulatorFirst::getValue(bool toBeMerged) const { return _first; } -AccumulatorFirst::AccumulatorFirst() : _haveFirst(false) { +AccumulatorFirst::AccumulatorFirst(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx), _haveFirst(false) { _memUsageBytes = sizeof(*this); } @@ -68,7 +69,8 @@ void AccumulatorFirst::reset() { } -intrusive_ptr<Accumulator> AccumulatorFirst::create() { - return new AccumulatorFirst(); +intrusive_ptr<Accumulator> AccumulatorFirst::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorFirst(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_last.cpp b/src/mongo/db/pipeline/accumulator_last.cpp index 64f545f7db1..5f465b1bf4e 100644 --- a/src/mongo/db/pipeline/accumulator_last.cpp +++ b/src/mongo/db/pipeline/accumulator_last.cpp @@ -53,7 +53,8 @@ Value AccumulatorLast::getValue(bool toBeMerged) const { return _last; } -AccumulatorLast::AccumulatorLast() { +AccumulatorLast::AccumulatorLast(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx) { _memUsageBytes = sizeof(*this); } @@ -62,7 +63,8 @@ void AccumulatorLast::reset() { _last = Value(); } -intrusive_ptr<Accumulator> AccumulatorLast::create() { - return new AccumulatorLast(); +intrusive_ptr<Accumulator> AccumulatorLast::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorLast(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_min_max.cpp b/src/mongo/db/pipeline/accumulator_min_max.cpp index 08db8b9e0fc..f5c4b254b88 100644 --- a/src/mongo/db/pipeline/accumulator_min_max.cpp +++ b/src/mongo/db/pipeline/accumulator_min_max.cpp @@ -68,7 +68,9 @@ Value AccumulatorMinMax::getValue(bool toBeMerged) const { return _val; } -AccumulatorMinMax::AccumulatorMinMax(Sense sense) : _sense(sense) { +AccumulatorMinMax::AccumulatorMinMax(const boost::intrusive_ptr<ExpressionContext>& expCtx, + Sense sense) + : Accumulator(expCtx), _sense(sense) { _memUsageBytes = sizeof(*this); } @@ -77,11 +79,13 @@ void AccumulatorMinMax::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr<Accumulator> AccumulatorMin::create() { - return new AccumulatorMin(); +intrusive_ptr<Accumulator> AccumulatorMin::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorMin(expCtx); } -intrusive_ptr<Accumulator> AccumulatorMax::create() { - return new AccumulatorMax(); +intrusive_ptr<Accumulator> AccumulatorMax::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorMax(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_push.cpp b/src/mongo/db/pipeline/accumulator_push.cpp index 16f722a08f3..9239d8b8a18 100644 --- a/src/mongo/db/pipeline/accumulator_push.cpp +++ b/src/mongo/db/pipeline/accumulator_push.cpp @@ -71,7 +71,8 @@ Value AccumulatorPush::getValue(bool toBeMerged) const { return Value(vpValue); } -AccumulatorPush::AccumulatorPush() { +AccumulatorPush::AccumulatorPush(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx) { _memUsageBytes = sizeof(*this); } @@ -80,7 +81,8 @@ void AccumulatorPush::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr<Accumulator> AccumulatorPush::create() { - return new AccumulatorPush(); +intrusive_ptr<Accumulator> AccumulatorPush::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorPush(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp index 28565250fe3..f45678dce27 100644 --- a/src/mongo/db/pipeline/accumulator_std_dev.cpp +++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp @@ -95,15 +95,19 @@ Value AccumulatorStdDev::getValue(bool toBeMerged) const { } } -intrusive_ptr<Accumulator> AccumulatorStdDevSamp::create() { - return new AccumulatorStdDevSamp(); +intrusive_ptr<Accumulator> AccumulatorStdDevSamp::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorStdDevSamp(expCtx); } -intrusive_ptr<Accumulator> AccumulatorStdDevPop::create() { - return new AccumulatorStdDevPop(); +intrusive_ptr<Accumulator> AccumulatorStdDevPop::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorStdDevPop(expCtx); } -AccumulatorStdDev::AccumulatorStdDev(bool isSamp) : _isSamp(isSamp), _count(0), _mean(0), _m2(0) { +AccumulatorStdDev::AccumulatorStdDev(const boost::intrusive_ptr<ExpressionContext>& expCtx, + bool isSamp) + : Accumulator(expCtx), _isSamp(isSamp), _count(0), _mean(0), _m2(0) { // This is a fixed size Accumulator so we never need to update this _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp index 13aa12fbfca..bbbf596d4a8 100644 --- a/src/mongo/db/pipeline/accumulator_sum.cpp +++ b/src/mongo/db/pipeline/accumulator_sum.cpp @@ -84,8 +84,9 @@ void AccumulatorSum::processInternal(const Value& input, bool merging) { } } -intrusive_ptr<Accumulator> AccumulatorSum::create() { - return new AccumulatorSum(); +intrusive_ptr<Accumulator> AccumulatorSum::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorSum(expCtx); } Value AccumulatorSum::getValue(bool toBeMerged) const { @@ -133,7 +134,8 @@ Value AccumulatorSum::getValue(bool toBeMerged) const { } } -AccumulatorSum::AccumulatorSum() { +AccumulatorSum::AccumulatorSum(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx) { // This is a fixed size Accumulator so we never need to update this. _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp index 29ab135ebb4..3fcb1d70ad7 100644 --- a/src/mongo/db/pipeline/accumulator_test.cpp +++ b/src/mongo/db/pipeline/accumulator_test.cpp @@ -32,7 +32,7 @@ #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/collation/collator_interface_mock.h" #include "mongo/dbtests/dbtests.h" #include "mongo/stdx/memory.h" @@ -57,8 +57,7 @@ static void assertExpectedResults( try { // Asserts that result equals expected result when not sharded. { - boost::intrusive_ptr<Accumulator> accum = factory(); - accum->injectExpressionContext(expCtx); + boost::intrusive_ptr<Accumulator> accum(factory(expCtx)); for (auto&& val : op.first) { accum->process(val, false); } @@ -69,10 +68,8 @@ static void assertExpectedResults( // Asserts that result equals expected result when all input is on one shard. { - boost::intrusive_ptr<Accumulator> accum = factory(); - accum->injectExpressionContext(expCtx); - boost::intrusive_ptr<Accumulator> shard = factory(); - shard->injectExpressionContext(expCtx); + boost::intrusive_ptr<Accumulator> accum(factory(expCtx)); + boost::intrusive_ptr<Accumulator> shard(factory(expCtx)); for (auto&& val : op.first) { shard->process(val, false); } @@ -84,11 +81,9 @@ static void assertExpectedResults( // Asserts that result equals expected result when each input is on a separate shard. { - boost::intrusive_ptr<Accumulator> accum = factory(); - accum->injectExpressionContext(expCtx); + boost::intrusive_ptr<Accumulator> accum(factory(expCtx)); for (auto&& val : op.first) { - boost::intrusive_ptr<Accumulator> shard = factory(); - shard->injectExpressionContext(expCtx); + boost::intrusive_ptr<Accumulator> shard(factory(expCtx)); shard->process(val, false); accum->process(shard->getValue(true), true); } @@ -104,7 +99,7 @@ static void assertExpectedResults( } TEST(Accumulators, Avg) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$avg", expCtx, @@ -160,7 +155,7 @@ TEST(Accumulators, Avg) { } TEST(Accumulators, First) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$first", expCtx, @@ -179,7 +174,7 @@ TEST(Accumulators, First) { } TEST(Accumulators, Last) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$last", expCtx, @@ -198,7 +193,7 @@ TEST(Accumulators, Last) { } TEST(Accumulators, Min) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$min", expCtx, @@ -217,14 +212,14 @@ TEST(Accumulators, Min) { } TEST(Accumulators, MinRespectsCollation) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); expCtx->setCollator( stdx::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kReverseString)); assertExpectedResults("$min", expCtx, {{{Value("abc"_sd), Value("cba"_sd)}, Value("cba"_sd)}}); } TEST(Accumulators, Max) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$max", expCtx, @@ -243,14 +238,14 @@ TEST(Accumulators, Max) { } TEST(Accumulators, MaxRespectsCollation) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); expCtx->setCollator( stdx::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kReverseString)); assertExpectedResults("$max", expCtx, {{{Value("abc"_sd), Value("cba"_sd)}, Value("abc"_sd)}}); } TEST(Accumulators, Sum) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$sum", expCtx, @@ -340,7 +335,7 @@ TEST(Accumulators, Sum) { } TEST(Accumulators, AddToSetRespectsCollation) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); expCtx->setCollator( stdx::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kAlwaysEqual)); assertExpectedResults("$addToSet", diff --git a/src/mongo/db/pipeline/aggregation_context_fixture.h b/src/mongo/db/pipeline/aggregation_context_fixture.h index 75353c46e7f..786f50a9b5a 100644 --- a/src/mongo/db/pipeline/aggregation_context_fixture.h +++ b/src/mongo/db/pipeline/aggregation_context_fixture.h @@ -32,7 +32,7 @@ #include <memory> #include "mongo/db/client.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/query_test_service_context.h" #include "mongo/db/service_context_noop.h" #include "mongo/stdx/memory.h" @@ -48,16 +48,16 @@ public: AggregationContextFixture() : _queryServiceContext(stdx::make_unique<QueryTestServiceContext>()), _opCtx(_queryServiceContext->makeOperationContext()), - _expCtx(new ExpressionContext( + _expCtx(new ExpressionContextForTest( _opCtx.get(), AggregationRequest(NamespaceString("unittests.pipeline_test"), {}))) {} - boost::intrusive_ptr<ExpressionContext> getExpCtx() { + boost::intrusive_ptr<ExpressionContextForTest> getExpCtx() { return _expCtx.get(); } private: std::unique_ptr<QueryTestServiceContext> _queryServiceContext; ServiceContext::UniqueOperationContext _opCtx; - boost::intrusive_ptr<ExpressionContext> _expCtx; + boost::intrusive_ptr<ExpressionContextForTest> _expCtx; }; } // namespace mongo diff --git a/src/mongo/db/pipeline/document_source.cpp b/src/mongo/db/pipeline/document_source.cpp index b83ec851ebe..440dc762175 100644 --- a/src/mongo/db/pipeline/document_source.cpp +++ b/src/mongo/db/pipeline/document_source.cpp @@ -61,7 +61,7 @@ void DocumentSource::registerParser(string name, Parser parser) { } vector<intrusive_ptr<DocumentSource>> DocumentSource::parse( - const intrusive_ptr<ExpressionContext> expCtx, BSONObj stageObj) { + const intrusive_ptr<ExpressionContext>& expCtx, BSONObj stageObj) { uassert(16435, "A pipeline stage specification object must contain exactly one field.", stageObj.nFields() == 1); diff --git a/src/mongo/db/pipeline/document_source.h b/src/mongo/db/pipeline/document_source.h index 81ba3fc61a6..bf4c65fa18b 100644 --- a/src/mongo/db/pipeline/document_source.h +++ b/src/mongo/db/pipeline/document_source.h @@ -269,22 +269,10 @@ public: virtual void reattachToOperationContext(OperationContext* opCtx) {} /** - * Injects a new ExpressionContext into this DocumentSource and propagates the ExpressionContext - * to all child expressions, accumulators, etc. - * - * Stages which require work to propagate the ExpressionContext to their private execution - * machinery should override doInjectExpressionContext(). - */ - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { - pExpCtx = expCtx; - doInjectExpressionContext(); - } - - /** * Create a DocumentSource pipeline stage from 'stageObj'. */ static std::vector<boost::intrusive_ptr<DocumentSource>> parse( - const boost::intrusive_ptr<ExpressionContext> expCtx, BSONObj stageObj); + const boost::intrusive_ptr<ExpressionContext>& expCtx, BSONObj stageObj); /** * Registers a DocumentSource with a parsing function, so that when a stage with the given name @@ -444,15 +432,6 @@ protected: explicit DocumentSource(const boost::intrusive_ptr<ExpressionContext>& pExpCtx); /** - * DocumentSources which need to update their internal state when attaching to a new - * ExpressionContext should override this method. - * - * Any stage subclassing from DocumentSource should override this method if it contains - * expressions or accumulators which need to attach to the newly injected ExpressionContext. - */ - virtual void doInjectExpressionContext() {} - - /** * Attempt to perform an optimization with the following source in the pipeline. 'container' * refers to the entire pipeline, and 'itr' points to this stage within the pipeline. The caller * must guarantee that std::next(itr) != container->end(). diff --git a/src/mongo/db/pipeline/document_source_add_fields.cpp b/src/mongo/db/pipeline/document_source_add_fields.cpp index 0f7366538f9..445feaa1a51 100644 --- a/src/mongo/db/pipeline/document_source_add_fields.cpp +++ b/src/mongo/db/pipeline/document_source_add_fields.cpp @@ -49,8 +49,7 @@ intrusive_ptr<DocumentSource> DocumentSourceAddFields::create( BSONObj addFieldsSpec, const intrusive_ptr<ExpressionContext>& expCtx) { intrusive_ptr<DocumentSourceSingleDocumentTransformation> addFields( new DocumentSourceSingleDocumentTransformation( - expCtx, ParsedAddFields::create(addFieldsSpec), "$addFields")); - addFields->injectExpressionContext(expCtx); + expCtx, ParsedAddFields::create(expCtx, addFieldsSpec), "$addFields")); return addFields; } diff --git a/src/mongo/db/pipeline/document_source_add_fields.h b/src/mongo/db/pipeline/document_source_add_fields.h index b1ce0415ab3..77fe4d9d4e9 100644 --- a/src/mongo/db/pipeline/document_source_add_fields.h +++ b/src/mongo/db/pipeline/document_source_add_fields.h @@ -51,6 +51,8 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx); private: + // It is illegal to construct a DocumentSourceAddFields directly, use create() or + // createFromBson() instead. DocumentSourceAddFields() = default; }; diff --git a/src/mongo/db/pipeline/document_source_bucket.cpp b/src/mongo/db/pipeline/document_source_bucket.cpp index 05152b7e313..7c671ea2998 100644 --- a/src/mongo/db/pipeline/document_source_bucket.cpp +++ b/src/mongo/db/pipeline/document_source_bucket.cpp @@ -43,9 +43,11 @@ REGISTER_MULTI_STAGE_ALIAS(bucket, DocumentSourceBucket::createFromBson); namespace { -intrusive_ptr<ExpressionConstant> getExpressionConstant(BSONElement expressionElem, - VariablesParseState vps) { - auto expr = Expression::parseOperand(expressionElem, vps)->optimize(); +intrusive_ptr<ExpressionConstant> getExpressionConstant( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expressionElem, + VariablesParseState vps) { + auto expr = Expression::parseOperand(expCtx, expressionElem, vps)->optimize(); return dynamic_cast<ExpressionConstant*>(expr.get()); } } // namespace @@ -95,7 +97,7 @@ vector<intrusive_ptr<DocumentSource>> DocumentSourceBucket::createFromBson( argument.type() == BSONType::Array); for (auto&& boundaryElem : argument.embeddedObject()) { - auto exprConst = getExpressionConstant(boundaryElem, vps); + auto exprConst = getExpressionConstant(pExpCtx, boundaryElem, vps); uassert(40191, str::stream() << "The $bucket 'boundaries' field must be an array of " "constant values, but found value: " @@ -144,7 +146,7 @@ vector<intrusive_ptr<DocumentSource>> DocumentSourceBucket::createFromBson( } else if ("default" == argName) { // If there is a default, make sure that it parses to a constant expression then add // default to switch. - auto exprConst = getExpressionConstant(argument, vps); + auto exprConst = getExpressionConstant(pExpCtx, argument, vps); uassert(40195, str::stream() << "The $bucket 'default' field must be a constant expression, but found: " diff --git a/src/mongo/db/pipeline/document_source_bucket.h b/src/mongo/db/pipeline/document_source_bucket.h index cd7fe31b14e..aa83ee07cc2 100644 --- a/src/mongo/db/pipeline/document_source_bucket.h +++ b/src/mongo/db/pipeline/document_source_bucket.h @@ -44,6 +44,7 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); private: + // It is illegal to construct a DocumentSourceBucket directly, use createFromBson() instead. DocumentSourceBucket() = default; }; diff --git a/src/mongo/db/pipeline/document_source_bucket_auto.cpp b/src/mongo/db/pipeline/document_source_bucket_auto.cpp index c9a09354434..b02e8d0226f 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto.cpp +++ b/src/mongo/db/pipeline/document_source_bucket_auto.cpp @@ -196,7 +196,8 @@ void DocumentSourceBucketAuto::populateBuckets() { } // Initialize the current bucket. - Bucket currentBucket(currentValue.first, currentValue.first, _accumulatorFactories); + Bucket currentBucket( + pExpCtx, currentValue.first, currentValue.first, _accumulatorFactories); // Add the first value into the current bucket. addDocumentToBucket(currentValue, currentBucket); @@ -267,13 +268,14 @@ void DocumentSourceBucketAuto::populateBuckets() { } } -DocumentSourceBucketAuto::Bucket::Bucket(Value min, +DocumentSourceBucketAuto::Bucket::Bucket(const boost::intrusive_ptr<ExpressionContext>& expCtx, + Value min, Value max, vector<Accumulator::Factory> accumulatorFactories) : _min(min), _max(max) { _accums.reserve(accumulatorFactories.size()); for (auto&& factory : accumulatorFactories) { - _accums.push_back(factory()); + _accums.push_back(factory(expCtx)); } } @@ -344,7 +346,7 @@ Value DocumentSourceBucketAuto::serialize(bool explain) const { const size_t nOutputFields = _fieldNames.size(); MutableDocument outputSpec(nOutputFields); for (size_t i = 0; i < nOutputFields; i++) { - intrusive_ptr<Accumulator> accum = _accumulatorFactories[i](); + intrusive_ptr<Accumulator> accum = _accumulatorFactories[i](pExpCtx); outputSpec[_fieldNames[i]] = Value{Document{{accum->getOpName(), _expressions[i]->serialize(explain)}}}; } @@ -405,14 +407,16 @@ DocumentSourceBucketAuto::DocumentSourceBucketAuto( namespace { -boost::intrusive_ptr<Expression> parseGroupByExpression(const BSONElement& groupByField, - const VariablesParseState& vps) { +boost::intrusive_ptr<Expression> parseGroupByExpression( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONElement& groupByField, + const VariablesParseState& vps) { if (groupByField.type() == BSONType::Object && groupByField.embeddedObject().firstElementFieldName()[0] == '$') { - return Expression::parseObject(groupByField.embeddedObject(), vps); + return Expression::parseObject(expCtx, groupByField.embeddedObject(), vps); } else if (groupByField.type() == BSONType::String && groupByField.valueStringData()[0] == '$') { - return ExpressionFieldPath::parse(groupByField.str(), vps); + return ExpressionFieldPath::parse(expCtx, groupByField.str(), vps); } else { uasserted( 40239, @@ -440,7 +444,7 @@ intrusive_ptr<DocumentSource> DocumentSourceBucketAuto::createFromBson( for (auto&& argument : elem.Obj()) { const auto argName = argument.fieldNameStringData(); if ("groupBy" == argName) { - groupByExpression = parseGroupByExpression(argument, vps); + groupByExpression = parseGroupByExpression(pExpCtx, argument, vps); } else if ("buckets" == argName) { Value bucketsValue = Value(argument); @@ -467,7 +471,7 @@ intrusive_ptr<DocumentSource> DocumentSourceBucketAuto::createFromBson( for (auto&& outputField : argument.embeddedObject()) { accumulationStatements.push_back( - AccumulationStatement::parseAccumulationStatement(outputField, vps)); + AccumulationStatement::parseAccumulationStatement(pExpCtx, outputField, vps)); } } else if ("granularity" == argName) { uassert(40261, @@ -475,7 +479,7 @@ intrusive_ptr<DocumentSource> DocumentSourceBucketAuto::createFromBson( << "The $bucketAuto 'granularity' field must be a string, but found type: " << typeName(argument.type()), argument.type() == BSONType::String); - granularityRounder = GranularityRounder::getGranularityRounder(argument.str()); + granularityRounder = GranularityRounder::getGranularityRounder(pExpCtx, argument.str()); } else { uasserted(40245, str::stream() << "Unrecognized option to $bucketAuto: " << argName); } diff --git a/src/mongo/db/pipeline/document_source_bucket_auto.h b/src/mongo/db/pipeline/document_source_bucket_auto.h index 9279bcd52b9..61f7da30c71 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto.h +++ b/src/mongo/db/pipeline/document_source_bucket_auto.h @@ -93,7 +93,10 @@ private: // struct for holding information about a bucket. struct Bucket { - Bucket(Value min, Value max, std::vector<Accumulator::Factory> accumulatorFactories); + Bucket(const boost::intrusive_ptr<ExpressionContext>& expCtx, + Value min, + Value max, + std::vector<Accumulator::Factory> accumulatorFactories); Value _min; Value _max; std::vector<boost::intrusive_ptr<Accumulator>> _accums; diff --git a/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp b/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp index 530948d2d3d..2d71b676024 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp +++ b/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp @@ -355,7 +355,7 @@ TEST_F(BucketAutoTests, ShouldBeAbleToCorrectlySpillToDisk) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, @@ -397,7 +397,7 @@ TEST_F(BucketAutoTests, ShouldBeAbleToPauseLoadingWhileSpilled) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, @@ -643,7 +643,7 @@ void assertCannotSpillToDisk(const boost::intrusive_ptr<ExpressionContext>& expC VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, @@ -685,7 +685,7 @@ TEST_F(BucketAutoTests, ShouldCorrectlyTrackMemoryUsageBetweenPauses) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, diff --git a/src/mongo/db/pipeline/document_source_bucket_test.cpp b/src/mongo/db/pipeline/document_source_bucket_test.cpp index b93916f4347..bdfb3635bb7 100644 --- a/src/mongo/db/pipeline/document_source_bucket_test.cpp +++ b/src/mongo/db/pipeline/document_source_bucket_test.cpp @@ -156,9 +156,6 @@ class InvalidBucketSpec : public AggregationContextFixture { public: vector<intrusive_ptr<DocumentSource>> createBucket(BSONObj bucketSpec) { auto sources = DocumentSourceBucket::createFromBson(bucketSpec.firstElement(), getExpCtx()); - for (auto&& source : sources) { - source->injectExpressionContext(getExpCtx()); - } return sources; } }; diff --git a/src/mongo/db/pipeline/document_source_count.h b/src/mongo/db/pipeline/document_source_count.h index e1817fbfca3..50fbccdf41b 100644 --- a/src/mongo/db/pipeline/document_source_count.h +++ b/src/mongo/db/pipeline/document_source_count.h @@ -44,6 +44,7 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); private: + // It is illegal to construct a DocumentSourceCount directly, use createFromBson() instead. DocumentSourceCount() = default; }; diff --git a/src/mongo/db/pipeline/document_source_cursor.cpp b/src/mongo/db/pipeline/document_source_cursor.cpp index 5c5be2aef1f..2def98d25bd 100644 --- a/src/mongo/db/pipeline/document_source_cursor.cpp +++ b/src/mongo/db/pipeline/document_source_cursor.cpp @@ -221,12 +221,6 @@ Value DocumentSourceCursor::serialize(bool explain) const { return Value(DOC(getSourceName() << out.freezeToValue())); } -void DocumentSourceCursor::doInjectExpressionContext() { - if (_limit) { - _limit->injectExpressionContext(pExpCtx); - } -} - void DocumentSourceCursor::detachFromOperationContext() { if (_exec) { _exec->detachFromOperationContext(); @@ -259,7 +253,6 @@ intrusive_ptr<DocumentSourceCursor> DocumentSourceCursor::create( const intrusive_ptr<ExpressionContext>& pExpCtx) { intrusive_ptr<DocumentSourceCursor> source( new DocumentSourceCursor(ns, std::move(exec), pExpCtx)); - source->injectExpressionContext(pExpCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_cursor.h b/src/mongo/db/pipeline/document_source_cursor.h index ebbf9eff64d..799834f1487 100644 --- a/src/mongo/db/pipeline/document_source_cursor.h +++ b/src/mongo/db/pipeline/document_source_cursor.h @@ -132,9 +132,6 @@ public: const PlanSummaryStats& getPlanSummaryStats() const; -protected: - void doInjectExpressionContext() final; - private: DocumentSourceCursor(const std::string& ns, std::unique_ptr<PlanExecutor> exec, diff --git a/src/mongo/db/pipeline/document_source_facet.cpp b/src/mongo/db/pipeline/document_source_facet.cpp index 7b8cd82de0f..9b03b233051 100644 --- a/src/mongo/db/pipeline/document_source_facet.cpp +++ b/src/mongo/db/pipeline/document_source_facet.cpp @@ -203,12 +203,6 @@ intrusive_ptr<DocumentSource> DocumentSourceFacet::optimize() { return this; } -void DocumentSourceFacet::doInjectExpressionContext() { - for (auto&& facet : _facets) { - facet.pipeline->injectExpressionContext(pExpCtx); - } -} - void DocumentSourceFacet::doInjectMongodInterface(std::shared_ptr<MongodInterface> mongod) { for (auto&& facet : _facets) { for (auto&& stage : facet.pipeline->getSources()) { diff --git a/src/mongo/db/pipeline/document_source_facet.h b/src/mongo/db/pipeline/document_source_facet.h index b0a9a0867ee..921b2e87609 100644 --- a/src/mongo/db/pipeline/document_source_facet.h +++ b/src/mongo/db/pipeline/document_source_facet.h @@ -43,7 +43,7 @@ namespace mongo { class BSONElement; class TeeBuffer; class DocumentSourceTeeConsumer; -struct ExpressionContext; +class ExpressionContext; class NamespaceString; /** @@ -98,11 +98,6 @@ public: boost::intrusive_ptr<DocumentSource> optimize() final; /** - * Injects the expression context into inner pipelines. - */ - void doInjectExpressionContext() final; - - /** * Takes a union of all sub-pipelines, and adds them to 'deps'. */ GetDepsReturn getDependencies(DepsTracker* deps) const final; diff --git a/src/mongo/db/pipeline/document_source_geo_near.cpp b/src/mongo/db/pipeline/document_source_geo_near.cpp index 8c4a2ca3ff5..17e9997bbf5 100644 --- a/src/mongo/db/pipeline/document_source_geo_near.cpp +++ b/src/mongo/db/pipeline/document_source_geo_near.cpp @@ -178,7 +178,6 @@ void DocumentSourceGeoNear::runCommand() { intrusive_ptr<DocumentSourceGeoNear> DocumentSourceGeoNear::create( const intrusive_ptr<ExpressionContext>& pCtx) { intrusive_ptr<DocumentSourceGeoNear> source(new DocumentSourceGeoNear(pCtx)); - source->injectExpressionContext(pCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_graph_lookup.cpp b/src/mongo/db/pipeline/document_source_graph_lookup.cpp index 7743e9c6058..03515560057 100644 --- a/src/mongo/db/pipeline/document_source_graph_lookup.cpp +++ b/src/mongo/db/pipeline/document_source_graph_lookup.cpp @@ -164,7 +164,7 @@ DocumentSource::GetNextResult DocumentSourceGraphLookUp::getNextUnwound() { void DocumentSourceGraphLookUp::dispose() { _cache.clear(); - _frontier->clear(); + _frontier.clear(); _visited.clear(); pSource->dispose(); } @@ -180,7 +180,7 @@ void DocumentSourceGraphLookUp::doBreadthFirstSearch() { auto matchStage = makeMatchStageFromFrontier(&cached); ValueUnorderedSet queried = pExpCtx->getValueComparator().makeUnorderedValueSet(); - _frontier->swap(queried); + _frontier.swap(queried); _frontierUsageBytes = 0; // Process cached values, populating '_frontier' for the next iteration of search. @@ -219,7 +219,7 @@ void DocumentSourceGraphLookUp::doBreadthFirstSearch() { } while (shouldPerformAnotherQuery && depth < std::numeric_limits<long long>::max() && (!_maxDepth || depth <= *_maxDepth)); - _frontier->clear(); + _frontier.clear(); _frontierUsageBytes = 0; } @@ -264,12 +264,12 @@ bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long lon Value recurseOn = Value(elem); if (recurseOn.isArray()) { for (auto&& subElem : recurseOn.getArray()) { - _frontier->insert(subElem); + _frontier.insert(subElem); _frontierUsageBytes += subElem.getApproximateSize(); } } else if (!recurseOn.missing()) { // Don't recurse on a missing value. - _frontier->insert(recurseOn); + _frontier.insert(recurseOn); _frontierUsageBytes += recurseOn.getApproximateSize(); } } @@ -304,13 +304,13 @@ void DocumentSourceGraphLookUp::addToCache(const BSONObj& result, boost::optional<BSONObj> DocumentSourceGraphLookUp::makeMatchStageFromFrontier(BSONObjSet* cached) { // Add any cached values to 'cached' and remove them from '_frontier'. - for (auto it = _frontier->begin(); it != _frontier->end();) { + for (auto it = _frontier.begin(); it != _frontier.end();) { if (auto entry = _cache[*it]) { for (auto&& obj : *entry) { cached->insert(obj); } size_t valueSize = it->getApproximateSize(); - it = _frontier->erase(it); + it = _frontier.erase(it); // If the cached value increased in size while in the cache, we don't want to underflow // '_frontierUsageBytes'. @@ -340,7 +340,7 @@ boost::optional<BSONObj> DocumentSourceGraphLookUp::makeMatchStageFromFrontier(B BSONObjBuilder subObj(connectToObj.subobjStart(_connectToField.fullPath())); { BSONArrayBuilder in(subObj.subarrayStart("$in")); - for (auto&& value : *_frontier) { + for (auto&& value : _frontier) { in << value; } } @@ -349,7 +349,7 @@ boost::optional<BSONObj> DocumentSourceGraphLookUp::makeMatchStageFromFrontier(B } } - return _frontier->empty() ? boost::none : boost::optional<BSONObj>(match.obj()); + return _frontier.empty() ? boost::none : boost::optional<BSONObj>(match.obj()); } void DocumentSourceGraphLookUp::performSearch() { @@ -363,11 +363,11 @@ void DocumentSourceGraphLookUp::performSearch() { // If _startWith evaluates to an array, treat each value as a separate starting point. if (startingValue.isArray()) { for (auto value : startingValue.getArray()) { - _frontier->insert(value); + _frontier.insert(value); _frontierUsageBytes += value.getApproximateSize(); } } else { - _frontier->insert(startingValue); + _frontier.insert(startingValue); _frontierUsageBytes += startingValue.getApproximateSize(); } @@ -460,24 +460,6 @@ void DocumentSourceGraphLookUp::serializeToArray(std::vector<Value>& array, bool } } -void DocumentSourceGraphLookUp::doInjectExpressionContext() { - _startWith->injectExpressionContext(pExpCtx); - - auto it = pExpCtx->resolvedNamespaces.find(_from.coll()); - invariant(it != pExpCtx->resolvedNamespaces.end()); - const auto& resolvedNamespace = it->second; - _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); - _fromPipeline = resolvedNamespace.pipeline; - - // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage - // we'll eventually construct from the input document. - _fromPipeline.reserve(_fromPipeline.size() + 1); - _fromPipeline.push_back(BSONObj()); - - _frontier = pExpCtx->getValueComparator().makeUnorderedValueSet(); - _cache.setValueComparator(pExpCtx->getValueComparator()); -} - void DocumentSourceGraphLookUp::doDetachFromOperationContext() { _fromExpCtx->opCtx = nullptr; } @@ -506,9 +488,19 @@ DocumentSourceGraphLookUp::DocumentSourceGraphLookUp( _additionalFilter(additionalFilter), _depthField(depthField), _maxDepth(maxDepth), + _frontier(pExpCtx->getValueComparator().makeUnorderedValueSet()), _visited(ValueComparator::kInstance.makeUnorderedValueMap<BSONObj>()), - _cache(expCtx->getValueComparator()), - _unwind(unwindSrc) {} + _cache(pExpCtx->getValueComparator()), + _unwind(unwindSrc) { + const auto& resolvedNamespace = pExpCtx->getResolvedNamespace(_from); + _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); + _fromPipeline = resolvedNamespace.pipeline; + + // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage + // we'll eventually construct from the input document. + _fromPipeline.reserve(_fromPipeline.size() + 1); + _fromPipeline.push_back(BSONObj()); +} intrusive_ptr<DocumentSourceGraphLookUp> DocumentSourceGraphLookUp::create( const intrusive_ptr<ExpressionContext>& expCtx, @@ -533,8 +525,6 @@ intrusive_ptr<DocumentSourceGraphLookUp> DocumentSourceGraphLookUp::create( maxDepth, unwindSrc)); source->_variables.reset(new Variables()); - - source->injectExpressionContext(expCtx); return source; } @@ -556,7 +546,7 @@ intrusive_ptr<DocumentSource> DocumentSourceGraphLookUp::createFromBson( const auto argName = argument.fieldNameStringData(); if (argName == "startWith") { - startWith = Expression::parseOperand(argument, vps); + startWith = Expression::parseOperand(expCtx, argument, vps); continue; } else if (argName == "maxDepth") { uassert(40100, diff --git a/src/mongo/db/pipeline/document_source_graph_lookup.h b/src/mongo/db/pipeline/document_source_graph_lookup.h index 01473bf44fc..be12637d404 100644 --- a/src/mongo/db/pipeline/document_source_graph_lookup.h +++ b/src/mongo/db/pipeline/document_source_graph_lookup.h @@ -94,9 +94,6 @@ public: static boost::intrusive_ptr<DocumentSource> createFromBson( BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); -protected: - void doInjectExpressionContext() final; - private: DocumentSourceGraphLookUp( const boost::intrusive_ptr<ExpressionContext>& expCtx, @@ -188,9 +185,7 @@ private: size_t _frontierUsageBytes = 0; // Only used during the breadth-first search, tracks the set of values on the current frontier. - // We use boost::optional to defer initialization until the ExpressionContext containing the - // correct comparator is injected. - boost::optional<ValueUnorderedSet> _frontier; + ValueUnorderedSet _frontier; // Tracks nodes that have been discovered for a given input. Keys are the '_id' value of the // document from the foreign collection, value is the document itself. The keys are compared diff --git a/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp b/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp index 6350a8c39e1..53e3e6ea1ac 100644 --- a/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp +++ b/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp @@ -70,7 +70,6 @@ public: } pipeline.getValue()->addInitialSource(DocumentSourceMock::create(_results)); - pipeline.getValue()->injectExpressionContext(expCtx); pipeline.getValue()->optimizePipeline(); return pipeline; @@ -90,17 +89,18 @@ TEST_F(DocumentSourceGraphLookUpTest, std::deque<DocumentSource::GetNextResult> fromContents{Document{{"to", 0}}}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - boost::none); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + boost::none); graphLookupStage->setSource(inputMock.get()); graphLookupStage->injectMongodInterface( std::make_shared<MockMongodImplementation>(std::move(fromContents))); @@ -119,17 +119,18 @@ TEST_F(DocumentSourceGraphLookUpTest, Document{{"_id", "a"_sd}, {"to", 0}, {"from", 1}}, Document{{"to", 1}}}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - boost::none); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + boost::none); graphLookupStage->setSource(inputMock.get()); graphLookupStage->injectMongodInterface( std::make_shared<MockMongodImplementation>(std::move(fromContents))); @@ -147,18 +148,19 @@ TEST_F(DocumentSourceGraphLookUpTest, std::deque<DocumentSource::GetNextResult> fromContents{Document{{"to", 0}}}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); auto unwindStage = DocumentSourceUnwind::create(expCtx, "results", false, boost::none); - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - unwindStage); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + unwindStage); graphLookupStage->injectMongodInterface( std::make_shared<MockMongodImplementation>(std::move(fromContents))); graphLookupStage->setSource(inputMock.get()); @@ -190,17 +192,18 @@ TEST_F(DocumentSourceGraphLookUpTest, Document(to1), Document(to2), Document(to0from1), Document(to0from2)}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - boost::none); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + boost::none); graphLookupStage->setSource(inputMock.get()); graphLookupStage->injectMongodInterface( std::make_shared<MockMongodImplementation>(std::move(fromContents))); @@ -254,14 +257,14 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePauses) { Document{{"_id", "a"_sd}, {"to", 0}, {"from", 1}}, Document{{"_id", "b"_sd}, {"to", 1}}}; NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, fromNs, "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -322,7 +325,7 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePausesWhileUnwinding) { Document{{"_id", "a"_sd}, {"to", 0}, {"from", 1}}, Document{{"_id", "b"_sd}, {"to", 1}}}; NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); const bool preserveNullAndEmptyArrays = false; const boost::optional<std::string> includeArrayIndex = boost::none; @@ -335,7 +338,7 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePausesWhileUnwinding) { "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -387,14 +390,14 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePausesWhileUnwinding) { TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportAsFieldIsModified) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, fromNs, "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -409,7 +412,7 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportAsFieldIsModified) TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportFieldsModifiedByAbsorbedUnwind) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); auto unwindStage = DocumentSourceUnwind::create(expCtx, "results", false, std::string("arrIndex")); auto graphLookupStage = @@ -418,7 +421,7 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportFieldsModifiedByAbs "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -437,7 +440,7 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupWithComparisonExpressionForStar auto inputMock = DocumentSourceMock::create(Document({{"_id", 0}, {"a", 1}, {"b", 2}})); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); std::deque<DocumentSource::GetNextResult> fromContents{Document{{"_id", 0}, {"to", true}}, Document{{"_id", 1}, {"to", false}}}; @@ -447,9 +450,10 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupWithComparisonExpressionForStar "results", "from", "to", - ExpressionCompare::create(ExpressionCompare::GT, - ExpressionFieldPath::create("a"), - ExpressionFieldPath::create("b")), + ExpressionCompare::create(expCtx, + ExpressionCompare::GT, + ExpressionFieldPath::create(expCtx, "a"), + ExpressionFieldPath::create(expCtx, "b")), boost::none, boost::none, boost::none, diff --git a/src/mongo/db/pipeline/document_source_group.cpp b/src/mongo/db/pipeline/document_source_group.cpp index cf91c8eff5f..a70a65c06b9 100644 --- a/src/mongo/db/pipeline/document_source_group.cpp +++ b/src/mongo/db/pipeline/document_source_group.cpp @@ -198,23 +198,6 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::optimize() { return this; } -void DocumentSourceGroup::doInjectExpressionContext() { - // Groups map must respect new comparator. - _groups = pExpCtx->getValueComparator().makeUnorderedValueMap<Accumulators>(); - - for (auto&& idExpr : _idExpressions) { - idExpr->injectExpressionContext(pExpCtx); - } - - for (auto&& expr : vpExpression) { - expr->injectExpressionContext(pExpCtx); - } - - for (auto&& accum : _currentAccumulators) { - accum->injectExpressionContext(pExpCtx); - } -} - Value DocumentSourceGroup::serialize(bool explain) const { MutableDocument insides; @@ -235,7 +218,7 @@ Value DocumentSourceGroup::serialize(bool explain) const { // Add the remaining fields. const size_t n = vFieldName.size(); for (size_t i = 0; i < n; ++i) { - intrusive_ptr<Accumulator> accum = vpAccumulatorFactory[i](); + intrusive_ptr<Accumulator> accum = vpAccumulatorFactory[i](pExpCtx); insides[vFieldName[i]] = Value(DOC(accum->getOpName() << vpExpression[i]->serialize(explain))); } @@ -280,7 +263,6 @@ intrusive_ptr<DocumentSourceGroup> DocumentSourceGroup::create( groupStage->addAccumulator(statement); } groupStage->_variables = stdx::make_unique<Variables>(numVariables); - groupStage->injectExpressionContext(pExpCtx); return groupStage; } @@ -292,6 +274,7 @@ DocumentSourceGroup::DocumentSourceGroup(const intrusive_ptr<ExpressionContext>& _inputSort(BSONObj()), _streaming(false), _initialized(false), + _groups(pExpCtx->getValueComparator().makeUnorderedValueMap<Accumulators>()), _spilled(false), _extSortAllowed(pExpCtx->extSortAllowed && !pExpCtx->inRouter) {} @@ -303,7 +286,7 @@ void DocumentSourceGroup::addAccumulator(AccumulationStatement accumulationState namespace { -intrusive_ptr<Expression> parseIdExpression(const intrusive_ptr<ExpressionContext> expCtx, +intrusive_ptr<Expression> parseIdExpression(const intrusive_ptr<ExpressionContext>& expCtx, BSONElement groupField, const VariablesParseState& vps) { if (groupField.type() == Object && !groupField.Obj().isEmpty()) { @@ -312,18 +295,18 @@ intrusive_ptr<Expression> parseIdExpression(const intrusive_ptr<ExpressionContex const BSONObj idKeyObj = groupField.Obj(); if (idKeyObj.firstElementFieldName()[0] == '$') { // grouping on a $op expression - return Expression::parseObject(idKeyObj, vps); + return Expression::parseObject(expCtx, idKeyObj, vps); } else { for (auto&& field : idKeyObj) { uassert(17390, "$group does not support inclusion-style expressions", !field.isNumber() && field.type() != Bool); } - return ExpressionObject::parse(idKeyObj, vps); + return ExpressionObject::parse(expCtx, idKeyObj, vps); } } else if (groupField.type() == String && groupField.valuestr()[0] == '$') { // grouping on a field path. - return ExpressionFieldPath::parse(groupField.str(), vps); + return ExpressionFieldPath::parse(expCtx, groupField.str(), vps); } else { // constant id - single group return ExpressionConstant::create(expCtx, Value(groupField)); @@ -377,7 +360,7 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::createFromBson( } else { // Any other field will be treated as an accumulator specification. pGroup->addAccumulator( - AccumulationStatement::parseAccumulationStatement(groupField, vps)); + AccumulationStatement::parseAccumulationStatement(pExpCtx, groupField, vps)); } } @@ -493,8 +476,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { // Set up accumulators. _currentAccumulators.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { - _currentAccumulators.push_back(vpAccumulatorFactory[i]()); - _currentAccumulators.back()->injectExpressionContext(pExpCtx); + _currentAccumulators.push_back(vpAccumulatorFactory[i](pExpCtx)); } // We only need to load the first document. @@ -543,8 +525,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { // Add the accumulators group.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { - group.push_back(vpAccumulatorFactory[i]()); - group.back()->injectExpressionContext(pExpCtx); + group.push_back(vpAccumulatorFactory[i](pExpCtx)); } } else { for (size_t i = 0; i < numAccumulators; i++) { @@ -599,8 +580,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { // prepare current to accumulate data _currentAccumulators.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { - _currentAccumulators.push_back(vpAccumulatorFactory[i]()); - _currentAccumulators.back()->injectExpressionContext(pExpCtx); + _currentAccumulators.push_back(vpAccumulatorFactory[i](pExpCtx)); } verify(_sorterIterator->more()); // we put data in, we should get something out. @@ -876,7 +856,7 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::getMergeSource() { VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); /* the merger will use the same grouping key */ - pMerger->setIdExpression(ExpressionFieldPath::parse("$$ROOT._id", vps)); + pMerger->setIdExpression(ExpressionFieldPath::parse(pExpCtx, "$$ROOT._id", vps)); const size_t n = vFieldName.size(); for (size_t i = 0; i < n; ++i) { @@ -888,13 +868,13 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::getMergeSource() { expression or constant. Here, we accumulate the output of the same name from the prior group. */ - pMerger->addAccumulator({vFieldName[i], - vpAccumulatorFactory[i], - ExpressionFieldPath::parse("$$ROOT." + vFieldName[i], vps)}); + pMerger->addAccumulator( + {vFieldName[i], + vpAccumulatorFactory[i], + ExpressionFieldPath::parse(pExpCtx, "$$ROOT." + vFieldName[i], vps)}); } pMerger->_variables.reset(new Variables(idGenerator.getIdCount())); - pMerger->injectExpressionContext(pExpCtx); return pMerger; } diff --git a/src/mongo/db/pipeline/document_source_group.h b/src/mongo/db/pipeline/document_source_group.h index 76d9282ebb9..e79ce631817 100644 --- a/src/mongo/db/pipeline/document_source_group.h +++ b/src/mongo/db/pipeline/document_source_group.h @@ -96,9 +96,6 @@ public: boost::intrusive_ptr<DocumentSource> getShardSource() final; boost::intrusive_ptr<DocumentSource> getMergeSource() final; -protected: - void doInjectExpressionContext() final; - private: explicit DocumentSourceGroup(const boost::intrusive_ptr<ExpressionContext>& pExpCtx, size_t maxMemoryUsageBytes = kDefaultMaxMemoryUsageBytes); diff --git a/src/mongo/db/pipeline/document_source_group_test.cpp b/src/mongo/db/pipeline/document_source_group_test.cpp index 599e282e6b9..25b10ddcb5b 100644 --- a/src/mongo/db/pipeline/document_source_group_test.cpp +++ b/src/mongo/db/pipeline/document_source_group_test.cpp @@ -45,7 +45,7 @@ #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/expression.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value_comparator.h" #include "mongo/db/query/query_test_service_context.h" #include "mongo/dbtests/dbtests.h" @@ -110,8 +110,8 @@ TEST_F(DocumentSourceGroupTest, ShouldBeAbleToPauseLoadingWhileSpilled) { VariablesParseState vps(&idGen); AccumulationStatement pushStatement{"spaceHog", AccumulationStatement::getFactory("$push"), - ExpressionFieldPath::parse("$largeStr", vps)}; - auto groupByExpression = ExpressionFieldPath::parse("$_id", vps); + ExpressionFieldPath::parse(expCtx, "$largeStr", vps)}; + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", vps); auto group = DocumentSourceGroup::create( expCtx, groupByExpression, {pushStatement}, idGen.getIdCount(), maxMemoryUsageBytes); @@ -150,8 +150,8 @@ TEST_F(DocumentSourceGroupTest, ShouldErrorIfNotAllowedToSpillToDiskAndResultSet VariablesParseState vps(&idGen); AccumulationStatement pushStatement{"spaceHog", AccumulationStatement::getFactory("$push"), - ExpressionFieldPath::parse("$largeStr", vps)}; - auto groupByExpression = ExpressionFieldPath::parse("$_id", vps); + ExpressionFieldPath::parse(expCtx, "$largeStr", vps)}; + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", vps); auto group = DocumentSourceGroup::create( expCtx, groupByExpression, {pushStatement}, idGen.getIdCount(), maxMemoryUsageBytes); @@ -173,8 +173,8 @@ TEST_F(DocumentSourceGroupTest, ShouldCorrectlyTrackMemoryUsageBetweenPauses) { VariablesParseState vps(&idGen); AccumulationStatement pushStatement{"spaceHog", AccumulationStatement::getFactory("$push"), - ExpressionFieldPath::parse("$largeStr", vps)}; - auto groupByExpression = ExpressionFieldPath::parse("$_id", vps); + ExpressionFieldPath::parse(expCtx, "$largeStr", vps)}; + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", vps); auto group = DocumentSourceGroup::create( expCtx, groupByExpression, {pushStatement}, idGen.getIdCount(), maxMemoryUsageBytes); @@ -204,7 +204,8 @@ public: Base() : _queryServiceContext(stdx::make_unique<QueryTestServiceContext>()), _opCtx(_queryServiceContext->makeOperationContext()), - _ctx(new ExpressionContext(_opCtx.get(), AggregationRequest(NamespaceString(ns), {}))), + _ctx(new ExpressionContextForTest(_opCtx.get(), + AggregationRequest(NamespaceString(ns), {}))), _tempDir("DocumentSourceGroupTest") {} protected: @@ -212,15 +213,14 @@ protected: BSONObj namedSpec = BSON("$group" << spec); BSONElement specElement = namedSpec.firstElement(); - intrusive_ptr<ExpressionContext> expressionContext = - new ExpressionContext(_opCtx.get(), AggregationRequest(NamespaceString(ns), {})); + intrusive_ptr<ExpressionContextForTest> expressionContext = + new ExpressionContextForTest(_opCtx.get(), AggregationRequest(NamespaceString(ns), {})); expressionContext->inShard = inShard; expressionContext->inRouter = inRouter; // Won't spill to disk properly if it needs to. expressionContext->tempDir = _tempDir.path(); _group = DocumentSourceGroup::createFromBson(specElement, expressionContext); - _group->injectExpressionContext(expressionContext); assertRoundTrips(_group); } DocumentSourceGroup* group() { @@ -234,7 +234,7 @@ protected: ASSERT(source->getNext().isEOF()); } - intrusive_ptr<ExpressionContext> ctx() const { + intrusive_ptr<ExpressionContextForTest> ctx() const { return _ctx; } @@ -251,7 +251,7 @@ private: } std::unique_ptr<QueryTestServiceContext> _queryServiceContext; ServiceContext::UniqueOperationContext _opCtx; - intrusive_ptr<ExpressionContext> _ctx; + intrusive_ptr<ExpressionContextForTest> _ctx; intrusive_ptr<DocumentSource> _group; TempDir _tempDir; }; diff --git a/src/mongo/db/pipeline/document_source_limit.cpp b/src/mongo/db/pipeline/document_source_limit.cpp index 634d635b22d..797c6ed6652 100644 --- a/src/mongo/db/pipeline/document_source_limit.cpp +++ b/src/mongo/db/pipeline/document_source_limit.cpp @@ -93,7 +93,6 @@ intrusive_ptr<DocumentSourceLimit> DocumentSourceLimit::create( const intrusive_ptr<ExpressionContext>& pExpCtx, long long limit) { uassert(15958, "the limit must be positive", limit > 0); intrusive_ptr<DocumentSourceLimit> source(new DocumentSourceLimit(pExpCtx, limit)); - source->injectExpressionContext(pExpCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_lookup.cpp b/src/mongo/db/pipeline/document_source_lookup.cpp index 5371542505c..a0923db276e 100644 --- a/src/mongo/db/pipeline/document_source_lookup.cpp +++ b/src/mongo/db/pipeline/document_source_lookup.cpp @@ -54,7 +54,16 @@ DocumentSourceLookUp::DocumentSourceLookUp(NamespaceString fromNs, _as(std::move(as)), _localField(std::move(localField)), _foreignField(foreignField), - _foreignFieldFieldName(std::move(foreignField)) {} + _foreignFieldFieldName(std::move(foreignField)) { + const auto& resolvedNamespace = pExpCtx->getResolvedNamespace(_fromNs); + _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); + _fromPipeline = resolvedNamespace.pipeline; + + // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage + // we'll eventually construct from the input document. + _fromPipeline.reserve(_fromPipeline.size() + 1); + _fromPipeline.push_back(BSONObj()); +} std::unique_ptr<LiteParsedDocumentSourceOneForeignCollection> DocumentSourceLookUp::liteParse( const AggregationRequest& request, const BSONElement& spec) { @@ -454,19 +463,6 @@ DocumentSource::GetDepsReturn DocumentSourceLookUp::getDependencies(DepsTracker* return SEE_NEXT; } -void DocumentSourceLookUp::doInjectExpressionContext() { - auto it = pExpCtx->resolvedNamespaces.find(_fromNs.coll()); - invariant(it != pExpCtx->resolvedNamespaces.end()); - const auto& resolvedNamespace = it->second; - _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); - _fromPipeline = resolvedNamespace.pipeline; - - // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage - // we'll eventually construct from the input document. - _fromPipeline.reserve(_fromPipeline.size() + 1); - _fromPipeline.push_back(BSONObj()); -} - void DocumentSourceLookUp::doDetachFromOperationContext() { if (_pipeline) { // We have a pipeline we're going to be executing across multiple calls to getNext(), so we diff --git a/src/mongo/db/pipeline/document_source_lookup.h b/src/mongo/db/pipeline/document_source_lookup.h index 13118828b65..e179abfe73e 100644 --- a/src/mongo/db/pipeline/document_source_lookup.h +++ b/src/mongo/db/pipeline/document_source_lookup.h @@ -113,9 +113,6 @@ public: _handlingUnwind = true; } -protected: - void doInjectExpressionContext() final; - private: DocumentSourceLookUp(NamespaceString fromNs, std::string as, diff --git a/src/mongo/db/pipeline/document_source_lookup_test.cpp b/src/mongo/db/pipeline/document_source_lookup_test.cpp index 0f54e3741c4..9da414ec6f6 100644 --- a/src/mongo/db/pipeline/document_source_lookup_test.cpp +++ b/src/mongo/db/pipeline/document_source_lookup_test.cpp @@ -54,6 +54,10 @@ using std::vector; using DocumentSourceLookUpTest = AggregationContextFixture; TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnAsField) { + auto expCtx = getExpCtx(); + NamespaceString fromNs("test", "a"); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); + intrusive_ptr<DocumentSourceMock> source = DocumentSourceMock::create(); source->sorts = {BSON("a" << 1 << "d.e" << 1 << "c" << 1)}; auto lookup = DocumentSourceLookUp::createFromBson(Document{{"$lookup", @@ -63,7 +67,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnAsField) { {"as", "d.e"_sd}}}} .toBson() .firstElement(), - getExpCtx()); + expCtx); lookup->setSource(source.get()); BSONObjSet outputSort = lookup->getOutputSorts(); @@ -73,6 +77,10 @@ TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnAsField) { } TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnSuffixOfAsField) { + auto expCtx = getExpCtx(); + NamespaceString fromNs("test", "a"); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); + intrusive_ptr<DocumentSourceMock> source = DocumentSourceMock::create(); source->sorts = {BSON("a" << 1 << "d.e" << 1 << "c" << 1)}; auto lookup = DocumentSourceLookUp::createFromBson(Document{{"$lookup", @@ -82,7 +90,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnSuffixOfAsField) { {"as", "d"_sd}}}} .toBson() .firstElement(), - getExpCtx()); + expCtx); lookup->setSource(source.get()); BSONObjSet outputSort = lookup->getOutputSorts(); @@ -158,7 +166,6 @@ public: } pipeline.getValue()->addInitialSource(DocumentSourceMock::create(_mockResults)); - pipeline.getValue()->injectExpressionContext(expCtx); pipeline.getValue()->optimizePipeline(); return pipeline; @@ -171,7 +178,7 @@ private: TEST_F(DocumentSourceLookUpTest, ShouldPropagatePauses) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", @@ -191,7 +198,6 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePauses) { DocumentSource::GetNextResult::makePauseExecution()}); lookup->setSource(mockLocalSource.get()); - lookup->injectExpressionContext(expCtx); // Mock out the foreign collection. deque<DocumentSource::GetNextResult> mockForeignContents{Document{{"_id", 0}}, @@ -222,7 +228,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePauses) { TEST_F(DocumentSourceLookUpTest, ShouldPropagatePausesWhileUnwinding) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", @@ -247,8 +253,6 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePausesWhileUnwinding) { DocumentSource::GetNextResult::makePauseExecution()}); lookup->setSource(mockLocalSource.get()); - lookup->injectExpressionContext(expCtx); - // Mock out the foreign collection. deque<DocumentSource::GetNextResult> mockForeignContents{Document{{"_id", 0}}, Document{{"_id", 1}}}; @@ -276,7 +280,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePausesWhileUnwinding) { TEST_F(DocumentSourceLookUpTest, LookupReportsAsFieldIsModified) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", @@ -297,7 +301,7 @@ TEST_F(DocumentSourceLookUpTest, LookupReportsAsFieldIsModified) { TEST_F(DocumentSourceLookUpTest, LookupReportsFieldsModifiedByAbsorbedUnwind) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", diff --git a/src/mongo/db/pipeline/document_source_match.cpp b/src/mongo/db/pipeline/document_source_match.cpp index 18cb22db71d..494d6589099 100644 --- a/src/mongo/db/pipeline/document_source_match.cpp +++ b/src/mongo/db/pipeline/document_source_match.cpp @@ -401,7 +401,7 @@ DocumentSourceMatch::splitSourceBy(const std::set<std::string>& fields) { boost::intrusive_ptr<DocumentSourceMatch> DocumentSourceMatch::descendMatchOnPath( MatchExpression* matchExpr, const std::string& descendOn, - intrusive_ptr<ExpressionContext> expCtx) { + const intrusive_ptr<ExpressionContext>& expCtx) { expression::mapOver(matchExpr, [&descendOn](MatchExpression* node, std::string path) -> void { // Cannot call this method on a $match including a $elemMatch. invariant(node->matchType() != MatchExpression::ELEM_MATCH_OBJECT && @@ -453,7 +453,6 @@ intrusive_ptr<DocumentSourceMatch> DocumentSourceMatch::create( BSONObj filter, const intrusive_ptr<ExpressionContext>& expCtx) { uassertNoDisallowedClauses(filter); intrusive_ptr<DocumentSourceMatch> match(new DocumentSourceMatch(filter, expCtx)); - match->injectExpressionContext(expCtx); return match; } @@ -490,10 +489,6 @@ void DocumentSourceMatch::addDependencies(DepsTracker* deps) const { }); } -void DocumentSourceMatch::doInjectExpressionContext() { - _expression->setCollator(pExpCtx->getCollator()); -} - DocumentSourceMatch::DocumentSourceMatch(const BSONObj& query, const intrusive_ptr<ExpressionContext>& pExpCtx) : DocumentSource(pExpCtx), _predicate(query.getOwned()), _isTextQuery(isTextQuery(query)) { diff --git a/src/mongo/db/pipeline/document_source_match.h b/src/mongo/db/pipeline/document_source_match.h index e7f3d4227d9..badfe1ed8b9 100644 --- a/src/mongo/db/pipeline/document_source_match.h +++ b/src/mongo/db/pipeline/document_source_match.h @@ -140,9 +140,7 @@ public: static boost::intrusive_ptr<DocumentSourceMatch> descendMatchOnPath( MatchExpression* matchExpr, const std::string& path, - boost::intrusive_ptr<ExpressionContext> expCtx); - - void doInjectExpressionContext(); + const boost::intrusive_ptr<ExpressionContext>& expCtx); private: DocumentSourceMatch(const BSONObj& query, diff --git a/src/mongo/db/pipeline/document_source_merge_cursors.cpp b/src/mongo/db/pipeline/document_source_merge_cursors.cpp index 866259c663e..b03db6560df 100644 --- a/src/mongo/db/pipeline/document_source_merge_cursors.cpp +++ b/src/mongo/db/pipeline/document_source_merge_cursors.cpp @@ -57,7 +57,6 @@ intrusive_ptr<DocumentSource> DocumentSourceMergeCursors::create( const intrusive_ptr<ExpressionContext>& pExpCtx) { intrusive_ptr<DocumentSourceMergeCursors> source( new DocumentSourceMergeCursors(std::move(cursorDescriptors), pExpCtx)); - source->injectExpressionContext(pExpCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_mock.cpp b/src/mongo/db/pipeline/document_source_mock.cpp index 686f4ef6f2e..3918bbd5ad7 100644 --- a/src/mongo/db/pipeline/document_source_mock.cpp +++ b/src/mongo/db/pipeline/document_source_mock.cpp @@ -31,6 +31,8 @@ #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" namespace mongo { @@ -38,7 +40,7 @@ using boost::intrusive_ptr; using std::deque; DocumentSourceMock::DocumentSourceMock(deque<GetNextResult> results) - : DocumentSource(nullptr), + : DocumentSource(new ExpressionContextForTest()), queue(std::move(results)), sorts(SimpleBSONObjComparator::kInstance.makeBSONObjSet()) {} diff --git a/src/mongo/db/pipeline/document_source_mock.h b/src/mongo/db/pipeline/document_source_mock.h index 7a600483661..7b966f5c418 100644 --- a/src/mongo/db/pipeline/document_source_mock.h +++ b/src/mongo/db/pipeline/document_source_mock.h @@ -79,10 +79,6 @@ public: return this; } - void doInjectExpressionContext() override { - isExpCtxInjected = true; - } - // Return documents from front of queue. std::deque<GetNextResult> queue; diff --git a/src/mongo/db/pipeline/document_source_project.cpp b/src/mongo/db/pipeline/document_source_project.cpp index fd44d534dde..a0ee096f36e 100644 --- a/src/mongo/db/pipeline/document_source_project.cpp +++ b/src/mongo/db/pipeline/document_source_project.cpp @@ -49,8 +49,7 @@ REGISTER_DOCUMENT_SOURCE(project, intrusive_ptr<DocumentSource> DocumentSourceProject::create( BSONObj projectSpec, const intrusive_ptr<ExpressionContext>& expCtx) { intrusive_ptr<DocumentSource> project(new DocumentSourceSingleDocumentTransformation( - expCtx, ParsedAggregationProjection::create(projectSpec), "$project")); - project->injectExpressionContext(expCtx); + expCtx, ParsedAggregationProjection::create(expCtx, projectSpec), "$project")); return project; } diff --git a/src/mongo/db/pipeline/document_source_project.h b/src/mongo/db/pipeline/document_source_project.h index 869dce512b1..4ed9db432d8 100644 --- a/src/mongo/db/pipeline/document_source_project.h +++ b/src/mongo/db/pipeline/document_source_project.h @@ -53,6 +53,8 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); private: + // It is illegal to construct a DocumentSourceProject directly, use create() or createFromBson() + // instead. DocumentSourceProject() = default; }; diff --git a/src/mongo/db/pipeline/document_source_redact.cpp b/src/mongo/db/pipeline/document_source_redact.cpp index 86b1c00960d..0096017a2f8 100644 --- a/src/mongo/db/pipeline/document_source_redact.cpp +++ b/src/mongo/db/pipeline/document_source_redact.cpp @@ -163,10 +163,6 @@ intrusive_ptr<DocumentSource> DocumentSourceRedact::optimize() { return this; } -void DocumentSourceRedact::doInjectExpressionContext() { - _expression->injectExpressionContext(pExpCtx); -} - Value DocumentSourceRedact::serialize(bool explain) const { return Value(DOC(getSourceName() << _expression.get()->serialize(explain))); } @@ -179,7 +175,7 @@ intrusive_ptr<DocumentSource> DocumentSourceRedact::createFromBson( Variables::Id decendId = vps.defineVariable("DESCEND"); Variables::Id pruneId = vps.defineVariable("PRUNE"); Variables::Id keepId = vps.defineVariable("KEEP"); - intrusive_ptr<Expression> expression = Expression::parseOperand(elem, vps); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, elem, vps); intrusive_ptr<DocumentSourceRedact> source = new DocumentSourceRedact(expCtx, expression); // TODO figure out how much of this belongs in constructor and how much here. diff --git a/src/mongo/db/pipeline/document_source_redact.h b/src/mongo/db/pipeline/document_source_redact.h index f056411f765..b9cf6242d42 100644 --- a/src/mongo/db/pipeline/document_source_redact.h +++ b/src/mongo/db/pipeline/document_source_redact.h @@ -47,8 +47,6 @@ public: Pipeline::SourceContainer::iterator doOptimizeAt(Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) final; - void doInjectExpressionContext() final; - static boost::intrusive_ptr<DocumentSource> createFromBson( BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx); diff --git a/src/mongo/db/pipeline/document_source_replace_root.cpp b/src/mongo/db/pipeline/document_source_replace_root.cpp index 6f1844ed281..1ae82d5cf3a 100644 --- a/src/mongo/db/pipeline/document_source_replace_root.cpp +++ b/src/mongo/db/pipeline/document_source_replace_root.cpp @@ -87,17 +87,14 @@ public: return DocumentSource::EXHAUSTIVE_FIELDS; } - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& pExpCtx) final { - _newRoot->injectExpressionContext(pExpCtx); - } - DocumentSource::GetModPathsReturn getModifiedPaths() const final { // Replaces the entire root, so all paths are modified. return {DocumentSource::GetModPathsReturn::Type::kAllPaths, std::set<std::string>{}}; } // Create the replaceRoot transformer. Uasserts on invalid input. - static std::unique_ptr<ReplaceRootTransformation> create(const BSONElement& spec) { + static std::unique_ptr<ReplaceRootTransformation> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONElement& spec) { // Confirm that the stage was called with an object. uassert(40229, @@ -108,12 +105,12 @@ public: // Create the pointer, parse the stage, and return. std::unique_ptr<ReplaceRootTransformation> parsedReplaceRoot = stdx::make_unique<ReplaceRootTransformation>(); - parsedReplaceRoot->parse(spec); + parsedReplaceRoot->parse(expCtx, spec); return parsedReplaceRoot; } // Check for valid replaceRoot options, and populate internal state variables. - void parse(const BSONElement& spec) { + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONElement& spec) { // We need a VariablesParseState in order to parse the 'newRoot' expression. VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); @@ -124,7 +121,7 @@ public: if (argName == "newRoot") { // Allows for field path, object, and other expressions. - _newRoot = Expression::parseOperand(argument, vps); + _newRoot = Expression::parseOperand(expCtx, argument, vps); } else { uasserted(40230, str::stream() << "unrecognized option to $replaceRoot stage: " << argName @@ -150,7 +147,7 @@ intrusive_ptr<DocumentSource> DocumentSourceReplaceRoot::createFromBson( BSONElement elem, const intrusive_ptr<ExpressionContext>& pExpCtx) { return new DocumentSourceSingleDocumentTransformation( - pExpCtx, ReplaceRootTransformation::create(elem), "$replaceRoot"); + pExpCtx, ReplaceRootTransformation::create(pExpCtx, elem), "$replaceRoot"); } } // namespace mongo diff --git a/src/mongo/db/pipeline/document_source_replace_root.h b/src/mongo/db/pipeline/document_source_replace_root.h index f3313ec5f7a..4656a6cdc2d 100644 --- a/src/mongo/db/pipeline/document_source_replace_root.h +++ b/src/mongo/db/pipeline/document_source_replace_root.h @@ -48,6 +48,8 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); private: + // It is illegal to construct a DocumentSourceReplaceRoot directly, use createFromBson() + // instead. DocumentSourceReplaceRoot() = default; }; diff --git a/src/mongo/db/pipeline/document_source_sample.cpp b/src/mongo/db/pipeline/document_source_sample.cpp index a436d3e8a53..c7d98db7260 100644 --- a/src/mongo/db/pipeline/document_source_sample.cpp +++ b/src/mongo/db/pipeline/document_source_sample.cpp @@ -87,10 +87,6 @@ Value DocumentSourceSample::serialize(bool explain) const { return Value(DOC(getSourceName() << DOC("size" << _size))); } -void DocumentSourceSample::doInjectExpressionContext() { - _sortStage->injectExpressionContext(pExpCtx); -} - namespace { const BSONObj randSortSpec = BSON("$rand" << BSON("$meta" << "randVal")); diff --git a/src/mongo/db/pipeline/document_source_sample.h b/src/mongo/db/pipeline/document_source_sample.h index ed0eec9c180..28f1dd97b05 100644 --- a/src/mongo/db/pipeline/document_source_sample.h +++ b/src/mongo/db/pipeline/document_source_sample.h @@ -50,8 +50,6 @@ public: return _size; } - void doInjectExpressionContext() final; - static boost::intrusive_ptr<DocumentSource> createFromBson( BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx); diff --git a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp index f4b2299c9e8..9fc3b5bf105 100644 --- a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp +++ b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp @@ -52,6 +52,7 @@ DocumentSourceSampleFromRandomCursor::DocumentSourceSampleFromRandomCursor( : DocumentSource(pExpCtx), _size(size), _idField(std::move(idField)), + _seenDocs(pExpCtx->getValueComparator().makeUnorderedValueSet()), _nDocsInColl(nDocsInCollection) {} const char* DocumentSourceSampleFromRandomCursor::getSourceName() const { @@ -76,7 +77,7 @@ double smallestFromSampleOfUniform(PseudoRandom* prng, size_t N) { DocumentSource::GetNextResult DocumentSourceSampleFromRandomCursor::getNext() { pExpCtx->checkForInterrupt(); - if (_seenDocs->size() >= static_cast<size_t>(_size)) + if (_seenDocs.size() >= static_cast<size_t>(_size)) return GetNextResult::makeEOF(); auto nextResult = getNextNonDuplicateDocument(); @@ -114,7 +115,7 @@ DocumentSource::GetNextResult DocumentSourceSampleFromRandomCursor::getNextNonDu << nextInput.getDocument().toString(), !idField.missing()); - if (_seenDocs->insert(std::move(idField)).second) { + if (_seenDocs.insert(std::move(idField)).second) { return nextInput; } LOG(1) << "$sample encountered duplicate document: " @@ -147,10 +148,6 @@ DocumentSource::GetDepsReturn DocumentSourceSampleFromRandomCursor::getDependenc return SEE_NEXT; } -void DocumentSourceSampleFromRandomCursor::doInjectExpressionContext() { - _seenDocs = pExpCtx->getValueComparator().makeUnorderedValueSet(); -} - intrusive_ptr<DocumentSourceSampleFromRandomCursor> DocumentSourceSampleFromRandomCursor::create( const intrusive_ptr<ExpressionContext>& expCtx, long long size, @@ -158,7 +155,6 @@ intrusive_ptr<DocumentSourceSampleFromRandomCursor> DocumentSourceSampleFromRand long long nDocsInCollection) { intrusive_ptr<DocumentSourceSampleFromRandomCursor> source( new DocumentSourceSampleFromRandomCursor(expCtx, size, idField, nDocsInCollection)); - source->injectExpressionContext(expCtx); return source; } } // mongo diff --git a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h index 6e8a2ae1d39..0d0ac39ca49 100644 --- a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h +++ b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h @@ -44,8 +44,6 @@ public: Value serialize(bool explain = false) const final; GetDepsReturn getDependencies(DepsTracker* deps) const final; - void doInjectExpressionContext() final; - static boost::intrusive_ptr<DocumentSourceSampleFromRandomCursor> create( const boost::intrusive_ptr<ExpressionContext>& expCtx, long long size, @@ -71,9 +69,8 @@ private: std::string _idField; // Keeps track of the documents that have been returned, since a random cursor is allowed to - // return duplicates. We use boost::optional to defer initialization until the ExpressionContext - // containing the correct comparator is injected. - boost::optional<ValueUnorderedSet> _seenDocs; + // return duplicates. + ValueUnorderedSet _seenDocs; // The approximate number of documents in the collection (includes orphans). const long long _nDocsInColl; diff --git a/src/mongo/db/pipeline/document_source_single_document_transformation.cpp b/src/mongo/db/pipeline/document_source_single_document_transformation.cpp index 67dd7338395..35204d272a4 100644 --- a/src/mongo/db/pipeline/document_source_single_document_transformation.cpp +++ b/src/mongo/db/pipeline/document_source_single_document_transformation.cpp @@ -98,10 +98,6 @@ DocumentSource::GetDepsReturn DocumentSourceSingleDocumentTransformation::getDep return _parsedTransform->addDependencies(deps); } -void DocumentSourceSingleDocumentTransformation::doInjectExpressionContext() { - _parsedTransform->injectExpressionContext(pExpCtx); -} - DocumentSource::GetModPathsReturn DocumentSourceSingleDocumentTransformation::getModifiedPaths() const { return _parsedTransform->getModifiedPaths(); diff --git a/src/mongo/db/pipeline/document_source_single_document_transformation.h b/src/mongo/db/pipeline/document_source_single_document_transformation.h index b2db5b0c3ae..1251cb454a4 100644 --- a/src/mongo/db/pipeline/document_source_single_document_transformation.h +++ b/src/mongo/db/pipeline/document_source_single_document_transformation.h @@ -57,8 +57,6 @@ public: virtual void optimize() = 0; virtual Document serialize(bool explain) const = 0; virtual DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const = 0; - virtual void injectExpressionContext( - const boost::intrusive_ptr<ExpressionContext>& pExpCtx) = 0; virtual GetModPathsReturn getModifiedPaths() const = 0; }; @@ -75,7 +73,6 @@ public: Value serialize(bool explain) const final; Pipeline::SourceContainer::iterator doOptimizeAt(Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) final; - void doInjectExpressionContext() final; DocumentSource::GetDepsReturn getDependencies(DepsTracker* deps) const final; GetModPathsReturn getModifiedPaths() const final; diff --git a/src/mongo/db/pipeline/document_source_skip.cpp b/src/mongo/db/pipeline/document_source_skip.cpp index 27bb6d1d2a6..ff528f0634e 100644 --- a/src/mongo/db/pipeline/document_source_skip.cpp +++ b/src/mongo/db/pipeline/document_source_skip.cpp @@ -96,7 +96,6 @@ Pipeline::SourceContainer::iterator DocumentSourceSkip::doOptimizeAt( intrusive_ptr<DocumentSourceSkip> DocumentSourceSkip::create( const intrusive_ptr<ExpressionContext>& pExpCtx, long long nToSkip) { intrusive_ptr<DocumentSourceSkip> skip(new DocumentSourceSkip(pExpCtx, nToSkip)); - skip->injectExpressionContext(pExpCtx); return skip; } diff --git a/src/mongo/db/pipeline/document_source_sort.cpp b/src/mongo/db/pipeline/document_source_sort.cpp index 4f4f17da4f2..6554a892016 100644 --- a/src/mongo/db/pipeline/document_source_sort.cpp +++ b/src/mongo/db/pipeline/document_source_sort.cpp @@ -113,7 +113,7 @@ long long DocumentSourceSort::getLimit() const { void DocumentSourceSort::addKey(StringData fieldPath, bool ascending) { VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - vSortKey.push_back(ExpressionFieldPath::parse("$$ROOT." + fieldPath.toString(), vps)); + vSortKey.push_back(ExpressionFieldPath::parse(pExpCtx, "$$ROOT." + fieldPath.toString(), vps)); vAscending.push_back(ascending); } @@ -176,7 +176,6 @@ intrusive_ptr<DocumentSourceSort> DocumentSourceSort::create( uint64_t maxMemoryUsageBytes) { intrusive_ptr<DocumentSourceSort> pSort(new DocumentSourceSort(pExpCtx)); pSort->_maxMemoryUsageBytes = maxMemoryUsageBytes; - pSort->injectExpressionContext(pExpCtx); pSort->_sort = sortOrder.getOwned(); for (auto&& keyField : sortOrder) { @@ -197,7 +196,7 @@ intrusive_ptr<DocumentSourceSort> DocumentSourceSort::create( VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - pSort->vSortKey.push_back(ExpressionMeta::parse(metaDoc.firstElement(), vps)); + pSort->vSortKey.push_back(ExpressionMeta::parse(pExpCtx, metaDoc.firstElement(), vps)); // If sorting by textScore, sort highest scores first. If sorting by randVal, order // doesn't matter, so just always use descending. diff --git a/src/mongo/db/pipeline/document_source_sort_by_count.h b/src/mongo/db/pipeline/document_source_sort_by_count.h index 2444d048259..69bf8d2c5e0 100644 --- a/src/mongo/db/pipeline/document_source_sort_by_count.h +++ b/src/mongo/db/pipeline/document_source_sort_by_count.h @@ -44,6 +44,8 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); private: + // It is illegal to construct a DocumentSourceSortByCount directly, use createFromBson() + // instead. DocumentSourceSortByCount() = default; }; diff --git a/src/mongo/db/pipeline/document_source_tee_consumer.h b/src/mongo/db/pipeline/document_source_tee_consumer.h index 928ac8cba30..bfdbfbda1cb 100644 --- a/src/mongo/db/pipeline/document_source_tee_consumer.h +++ b/src/mongo/db/pipeline/document_source_tee_consumer.h @@ -38,7 +38,7 @@ namespace mongo { class Document; -struct ExpressionContext; +class ExpressionContext; class Value; /** diff --git a/src/mongo/db/pipeline/document_source_unwind.cpp b/src/mongo/db/pipeline/document_source_unwind.cpp index 368b29afbd4..683069597e0 100644 --- a/src/mongo/db/pipeline/document_source_unwind.cpp +++ b/src/mongo/db/pipeline/document_source_unwind.cpp @@ -183,7 +183,6 @@ intrusive_ptr<DocumentSourceUnwind> DocumentSourceUnwind::create( FieldPath(unwindPath), preserveNullAndEmptyArrays, indexPath ? FieldPath(*indexPath) : boost::optional<FieldPath>())); - source->injectExpressionContext(expCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_unwind_test.cpp b/src/mongo/db/pipeline/document_source_unwind_test.cpp index 48df219770a..8dcf1ed9a32 100644 --- a/src/mongo/db/pipeline/document_source_unwind_test.cpp +++ b/src/mongo/db/pipeline/document_source_unwind_test.cpp @@ -42,7 +42,7 @@ #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document_source_unwind.h" #include "mongo/db/pipeline/document_value_test_util.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value_comparator.h" #include "mongo/db/query/query_test_service_context.h" #include "mongo/db/service_context.h" @@ -70,7 +70,8 @@ public: CheckResultsBase() : _queryServiceContext(stdx::make_unique<QueryTestServiceContext>()), _opCtx(_queryServiceContext->makeOperationContext()), - _ctx(new ExpressionContext(_opCtx.get(), AggregationRequest(NamespaceString(ns), {}))) {} + _ctx(new ExpressionContextForTest(_opCtx.get(), + AggregationRequest(NamespaceString(ns), {}))) {} virtual ~CheckResultsBase() {} @@ -141,7 +142,7 @@ protected: return expectedIndexedResultSetString(); } - intrusive_ptr<ExpressionContext> ctx() const { + intrusive_ptr<ExpressionContextForTest> ctx() const { return _ctx; } @@ -248,7 +249,7 @@ private: unique_ptr<QueryTestServiceContext> _queryServiceContext; ServiceContext::UniqueOperationContext _opCtx; - intrusive_ptr<ExpressionContext> _ctx; + intrusive_ptr<ExpressionContextForTest> _ctx; intrusive_ptr<DocumentSourceUnwind> _unwind; }; diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index e9de0eb3472..d9a087e08c1 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -189,17 +189,20 @@ string Expression::removeFieldPrefix(const string& prefixedField) { return string(pPrefixedField + 1); } -intrusive_ptr<Expression> Expression::parseObject(BSONObj obj, const VariablesParseState& vps) { +intrusive_ptr<Expression> Expression::parseObject( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps) { if (obj.isEmpty()) { - return ExpressionObject::create({}); + return ExpressionObject::create(expCtx, {}); } if (obj.firstElementFieldName()[0] == '$') { // Assume this is an expression like {$add: [...]}. - return parseExpression(obj, vps); + return parseExpression(expCtx, obj, vps); } - return ExpressionObject::parse(obj, vps); + return ExpressionObject::parse(expCtx, obj, vps); } namespace { @@ -214,7 +217,10 @@ void Expression::registerExpression(string key, Parser parser) { parserMap[key] = parser; } -intrusive_ptr<Expression> Expression::parseExpression(BSONObj obj, const VariablesParseState& vps) { +intrusive_ptr<Expression> Expression::parseExpression( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps) { uassert(15983, str::stream() << "An object representing an expression must have exactly one " "field: " @@ -227,36 +233,40 @@ intrusive_ptr<Expression> Expression::parseExpression(BSONObj obj, const Variabl uassert(ErrorCodes::InvalidPipelineOperator, str::stream() << "Unrecognized expression '" << opName << "'", op != parserMap.end()); - return op->second(obj.firstElement(), vps); + return op->second(expCtx, obj.firstElement(), vps); } -Expression::ExpressionVector ExpressionNary::parseArguments(BSONElement exprElement, - const VariablesParseState& vps) { +Expression::ExpressionVector ExpressionNary::parseArguments( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement exprElement, + const VariablesParseState& vps) { ExpressionVector out; if (exprElement.type() == Array) { BSONForEach(elem, exprElement.Obj()) { - out.push_back(Expression::parseOperand(elem, vps)); + out.push_back(Expression::parseOperand(expCtx, elem, vps)); } } else { // Assume it's an operand that accepts a single argument. - out.push_back(Expression::parseOperand(exprElement, vps)); + out.push_back(Expression::parseOperand(expCtx, exprElement, vps)); } return out; } -intrusive_ptr<Expression> Expression::parseOperand(BSONElement exprElement, - const VariablesParseState& vps) { +intrusive_ptr<Expression> Expression::parseOperand( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement exprElement, + const VariablesParseState& vps) { BSONType type = exprElement.type(); if (type == String && exprElement.valuestr()[0] == '$') { /* if we got here, this is a field path expression */ - return ExpressionFieldPath::parse(exprElement.str(), vps); + return ExpressionFieldPath::parse(expCtx, exprElement.str(), vps); } else if (type == Object) { - return Expression::parseObject(exprElement.Obj(), vps); + return Expression::parseObject(expCtx, exprElement.Obj(), vps); } else if (type == Array) { - return ExpressionArray::parse(exprElement, vps); + return ExpressionArray::parse(expCtx, exprElement, vps); } else { - return ExpressionConstant::parse(exprElement, vps); + return ExpressionConstant::parse(expCtx, exprElement, vps); } } @@ -615,13 +625,13 @@ const char* ExpressionCeil::getOpName() const { intrusive_ptr<ExpressionCoerceToBool> ExpressionCoerceToBool::create( const intrusive_ptr<ExpressionContext>& expCtx, const intrusive_ptr<Expression>& pExpression) { - intrusive_ptr<ExpressionCoerceToBool> pNew(new ExpressionCoerceToBool(pExpression)); - pNew->injectExpressionContext(expCtx); + intrusive_ptr<ExpressionCoerceToBool> pNew(new ExpressionCoerceToBool(expCtx, pExpression)); return pNew; } -ExpressionCoerceToBool::ExpressionCoerceToBool(const intrusive_ptr<Expression>& pTheExpression) - : Expression(), pExpression(pTheExpression) {} +ExpressionCoerceToBool::ExpressionCoerceToBool(const intrusive_ptr<ExpressionContext>& expCtx, + const intrusive_ptr<Expression>& pTheExpression) + : Expression(expCtx), pExpression(pTheExpression) {} intrusive_ptr<Expression> ExpressionCoerceToBool::optimize() { /* optimize the operand */ @@ -656,65 +666,68 @@ Value ExpressionCoerceToBool::serialize(bool explain) const { return Value(DOC(name << DOC_ARRAY(pExpression->serialize(explain)))); } -void ExpressionCoerceToBool::doInjectExpressionContext() { - // Inject our ExpressionContext into the operand. - pExpression->injectExpressionContext(getExpressionContext()); -} - /* ----------------------- ExpressionCompare --------------------------- */ REGISTER_EXPRESSION(cmp, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::CMP)); REGISTER_EXPRESSION(eq, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::EQ)); REGISTER_EXPRESSION(gt, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::GT)); REGISTER_EXPRESSION(gte, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::GTE)); REGISTER_EXPRESSION(lt, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::LT)); REGISTER_EXPRESSION(lte, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::LTE)); REGISTER_EXPRESSION(ne, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::NE)); -intrusive_ptr<Expression> ExpressionCompare::parse(BSONElement bsonExpr, - const VariablesParseState& vps, - CmpOp op) { - intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(op); - ExpressionVector args = parseArguments(bsonExpr, vps); +intrusive_ptr<Expression> ExpressionCompare::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps, + CmpOp op) { + intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(expCtx, op); + ExpressionVector args = parseArguments(expCtx, bsonExpr, vps); expr->validateArguments(args); expr->vpOperand = args; return expr; } -ExpressionCompare::ExpressionCompare(CmpOp theCmpOp) : cmpOp(theCmpOp) {} - boost::intrusive_ptr<ExpressionCompare> ExpressionCompare::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, CmpOp cmpOp, const boost::intrusive_ptr<Expression>& exprLeft, const boost::intrusive_ptr<Expression>& exprRight) { - boost::intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(cmpOp); + boost::intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(expCtx, cmpOp); expr->vpOperand = {exprLeft, exprRight}; return expr; } @@ -828,23 +841,26 @@ Value ExpressionCond::evaluateInternal(Variables* vars) const { return vpOperand[idx]->evaluateInternal(vars); } -intrusive_ptr<Expression> ExpressionCond::parse(BSONElement expr, const VariablesParseState& vps) { +intrusive_ptr<Expression> ExpressionCond::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps) { if (expr.type() != Object) { - return Base::parse(expr, vps); + return Base::parse(expCtx, expr, vps); } verify(str::equals(expr.fieldName(), "$cond")); - intrusive_ptr<ExpressionCond> ret = new ExpressionCond(); + intrusive_ptr<ExpressionCond> ret = new ExpressionCond(expCtx); ret->vpOperand.resize(3); const BSONObj args = expr.embeddedObject(); BSONForEach(arg, args) { if (str::equals(arg.fieldName(), "if")) { - ret->vpOperand[0] = parseOperand(arg, vps); + ret->vpOperand[0] = parseOperand(expCtx, arg, vps); } else if (str::equals(arg.fieldName(), "then")) { - ret->vpOperand[1] = parseOperand(arg, vps); + ret->vpOperand[1] = parseOperand(expCtx, arg, vps); } else if (str::equals(arg.fieldName(), "else")) { - ret->vpOperand[2] = parseOperand(arg, vps); + ret->vpOperand[2] = parseOperand(expCtx, arg, vps); } else { uasserted(17083, str::stream() << "Unrecognized parameter to $cond: " << arg.fieldName()); @@ -865,20 +881,23 @@ const char* ExpressionCond::getOpName() const { /* ---------------------- ExpressionConstant --------------------------- */ -intrusive_ptr<Expression> ExpressionConstant::parse(BSONElement exprElement, - const VariablesParseState& vps) { - return new ExpressionConstant(Value(exprElement)); +intrusive_ptr<Expression> ExpressionConstant::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement exprElement, + const VariablesParseState& vps) { + return new ExpressionConstant(expCtx, Value(exprElement)); } intrusive_ptr<ExpressionConstant> ExpressionConstant::create( const intrusive_ptr<ExpressionContext>& expCtx, const Value& pValue) { - intrusive_ptr<ExpressionConstant> pEC(new ExpressionConstant(pValue)); - pEC->injectExpressionContext(expCtx); + intrusive_ptr<ExpressionConstant> pEC(new ExpressionConstant(expCtx, pValue)); return pEC; } -ExpressionConstant::ExpressionConstant(const Value& pTheValue) : pValue(pTheValue) {} +ExpressionConstant::ExpressionConstant(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const Value& pTheValue) + : Expression(expCtx), pValue(pTheValue) {} intrusive_ptr<Expression> ExpressionConstant::optimize() { @@ -907,8 +926,10 @@ const char* ExpressionConstant::getOpName() const { /* ---------------------- ExpressionDateToString ----------------------- */ REGISTER_EXPRESSION(dateToString, ExpressionDateToString::parse); -intrusive_ptr<Expression> ExpressionDateToString::parse(BSONElement expr, - const VariablesParseState& vps) { +intrusive_ptr<Expression> ExpressionDateToString::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps) { verify(str::equals(expr.fieldName(), "$dateToString")); uassert(18629, "$dateToString only supports an object as its argument", expr.type() == Object); @@ -939,11 +960,14 @@ intrusive_ptr<Expression> ExpressionDateToString::parse(BSONElement expr, validateFormat(format); - return new ExpressionDateToString(format, parseOperand(dateElem, vps)); + return new ExpressionDateToString(expCtx, format, parseOperand(expCtx, dateElem, vps)); } -ExpressionDateToString::ExpressionDateToString(const string& format, intrusive_ptr<Expression> date) - : _format(format), _date(date) {} +ExpressionDateToString::ExpressionDateToString( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const string& format, + intrusive_ptr<Expression> date) + : Expression(expCtx), _format(format), _date(date) {} intrusive_ptr<Expression> ExpressionDateToString::optimize() { _date = _date->optimize(); @@ -1103,10 +1127,6 @@ void ExpressionDateToString::addDependencies(DepsTracker* deps) const { _date->addDependencies(deps); } -void ExpressionDateToString::doInjectExpressionContext() { - _date->injectExpressionContext(getExpressionContext()); -} - /* ---------------------- ExpressionDayOfMonth ------------------------- */ Value ExpressionDayOfMonth::evaluateInternal(Variables* vars) const { @@ -1198,16 +1218,20 @@ const char* ExpressionExp::getOpName() const { /* ---------------------- ExpressionObject --------------------------- */ -ExpressionObject::ExpressionObject(vector<pair<string, intrusive_ptr<Expression>>>&& expressions) - : _expressions(std::move(expressions)) {} +ExpressionObject::ExpressionObject(const boost::intrusive_ptr<ExpressionContext>& expCtx, + vector<pair<string, intrusive_ptr<Expression>>>&& expressions) + : Expression(expCtx), _expressions(std::move(expressions)) {} intrusive_ptr<ExpressionObject> ExpressionObject::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, vector<pair<string, intrusive_ptr<Expression>>>&& expressions) { - return new ExpressionObject(std::move(expressions)); + return new ExpressionObject(expCtx, std::move(expressions)); } -intrusive_ptr<ExpressionObject> ExpressionObject::parse(BSONObj obj, - const VariablesParseState& vps) { +intrusive_ptr<ExpressionObject> ExpressionObject::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps) { // Make sure we don't have any duplicate field names. stdx::unordered_set<string> specifiedFields; @@ -1223,10 +1247,10 @@ intrusive_ptr<ExpressionObject> ExpressionObject::parse(BSONObj obj, << obj.toString(), specifiedFields.find(fieldName) == specifiedFields.end()); specifiedFields.insert(fieldName); - expressions.emplace_back(fieldName, parseOperand(elem, vps)); + expressions.emplace_back(fieldName, parseOperand(expCtx, elem, vps)); } - return new ExpressionObject{std::move(expressions)}; + return new ExpressionObject{expCtx, std::move(expressions)}; } intrusive_ptr<Expression> ExpressionObject::optimize() { @@ -1258,22 +1282,19 @@ Value ExpressionObject::serialize(bool explain) const { return outputDoc.freezeToValue(); } -void ExpressionObject::doInjectExpressionContext() { - for (auto&& pair : _expressions) { - pair.second->injectExpressionContext(getExpressionContext()); - } -} - /* --------------------- ExpressionFieldPath --------------------------- */ // this is the old deprecated version only used by tests not using variables -intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::create(const string& fieldPath) { - return new ExpressionFieldPath("CURRENT." + fieldPath, Variables::ROOT_ID); +intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const string& fieldPath) { + return new ExpressionFieldPath(expCtx, "CURRENT." + fieldPath, Variables::ROOT_ID); } // this is the new version that supports every syntax -intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::parse(const string& raw, - const VariablesParseState& vps) { +intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const string& raw, + const VariablesParseState& vps) { uassert(16873, str::stream() << "FieldPath '" << raw << "' doesn't start with $", raw.c_str()[0] == '$'); // c_str()[0] is always a valid reference. @@ -1287,16 +1308,18 @@ intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::parse(const string& raw, const StringData fieldPath = rawSD.substr(2); // strip off $$ const StringData varName = fieldPath.substr(0, fieldPath.find('.')); Variables::uassertValidNameForUserRead(varName); - return new ExpressionFieldPath(fieldPath.toString(), vps.getVariable(varName)); + return new ExpressionFieldPath(expCtx, fieldPath.toString(), vps.getVariable(varName)); } else { - return new ExpressionFieldPath("CURRENT." + raw.substr(1), // strip the "$" prefix + return new ExpressionFieldPath(expCtx, + "CURRENT." + raw.substr(1), // strip the "$" prefix vps.getVariable("CURRENT")); } } - -ExpressionFieldPath::ExpressionFieldPath(const string& theFieldPath, Variables::Id variable) - : _fieldPath(theFieldPath), _variable(variable) {} +ExpressionFieldPath::ExpressionFieldPath(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const string& theFieldPath, + Variables::Id variable) + : Expression(expCtx), _fieldPath(theFieldPath), _variable(variable) {} intrusive_ptr<Expression> ExpressionFieldPath::optimize() { /* nothing can be done for these */ @@ -1384,8 +1407,10 @@ Value ExpressionFieldPath::serialize(bool explain) const { /* ------------------------- ExpressionFilter ----------------------------- */ REGISTER_EXPRESSION(filter, ExpressionFilter::parse); -intrusive_ptr<Expression> ExpressionFilter::parse(BSONElement expr, - const VariablesParseState& vpsIn) { +intrusive_ptr<Expression> ExpressionFilter::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { verify(str::equals(expr.fieldName(), "$filter")); uassert(28646, "$filter only supports an object as its argument", expr.type() == Object); @@ -1411,7 +1436,7 @@ intrusive_ptr<Expression> ExpressionFilter::parse(BSONElement expr, uassert(28650, "Missing 'cond' parameter to $filter", !condElem.eoo()); // Parse "input", only has outer variables. - intrusive_ptr<Expression> input = parseOperand(inputElem, vpsIn); + intrusive_ptr<Expression> input = parseOperand(expCtx, inputElem, vpsIn); // Parse "as". VariablesParseState vpsSub(vpsIn); // vpsSub gets our variable, vpsIn doesn't. @@ -1423,16 +1448,19 @@ intrusive_ptr<Expression> ExpressionFilter::parse(BSONElement expr, Variables::Id varId = vpsSub.defineVariable(varName); // Parse "cond", has access to "as" variable. - intrusive_ptr<Expression> cond = parseOperand(condElem, vpsSub); + intrusive_ptr<Expression> cond = parseOperand(expCtx, condElem, vpsSub); - return new ExpressionFilter(std::move(varName), varId, std::move(input), std::move(cond)); + return new ExpressionFilter( + expCtx, std::move(varName), varId, std::move(input), std::move(cond)); } -ExpressionFilter::ExpressionFilter(string varName, +ExpressionFilter::ExpressionFilter(const boost::intrusive_ptr<ExpressionContext>& expCtx, + string varName, Variables::Id varId, intrusive_ptr<Expression> input, intrusive_ptr<Expression> filter) - : _varName(std::move(varName)), + : Expression(expCtx), + _varName(std::move(varName)), _varId(varId), _input(std::move(input)), _filter(std::move(filter)) {} @@ -1483,11 +1511,6 @@ void ExpressionFilter::addDependencies(DepsTracker* deps) const { _filter->addDependencies(deps); } -void ExpressionFilter::doInjectExpressionContext() { - _input->injectExpressionContext(getExpressionContext()); - _filter->injectExpressionContext(getExpressionContext()); -} - /* ------------------------- ExpressionFloor -------------------------- */ Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const { @@ -1512,7 +1535,10 @@ const char* ExpressionFloor::getOpName() const { /* ------------------------- ExpressionLet ----------------------------- */ REGISTER_EXPRESSION(let, ExpressionLet::parse); -intrusive_ptr<Expression> ExpressionLet::parse(BSONElement expr, const VariablesParseState& vpsIn) { +intrusive_ptr<Expression> ExpressionLet::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { verify(str::equals(expr.fieldName(), "$let")); uassert(16874, "$let only supports an object as its argument", expr.type() == Object); @@ -1543,17 +1569,20 @@ intrusive_ptr<Expression> ExpressionLet::parse(BSONElement expr, const Variables Variables::uassertValidNameForUserWrite(varName); Variables::Id id = vpsSub.defineVariable(varName); - vars[id] = NameAndExpression(varName, parseOperand(varElem, vpsIn)); // only has outer vars + vars[id] = NameAndExpression(varName, + parseOperand(expCtx, varElem, vpsIn)); // only has outer vars } // parse "in" - intrusive_ptr<Expression> subExpression = parseOperand(inElem, vpsSub); // has our vars + intrusive_ptr<Expression> subExpression = parseOperand(expCtx, inElem, vpsSub); // has our vars - return new ExpressionLet(vars, subExpression); + return new ExpressionLet(expCtx, vars, subExpression); } -ExpressionLet::ExpressionLet(const VariableMap& vars, intrusive_ptr<Expression> subExpression) - : _variables(vars), _subExpression(subExpression) {} +ExpressionLet::ExpressionLet(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const VariableMap& vars, + intrusive_ptr<Expression> subExpression) + : Expression(expCtx), _variables(vars), _subExpression(subExpression) {} intrusive_ptr<Expression> ExpressionLet::optimize() { if (_variables.empty()) { @@ -1603,18 +1632,14 @@ void ExpressionLet::addDependencies(DepsTracker* deps) const { _subExpression->addDependencies(deps); } -void ExpressionLet::doInjectExpressionContext() { - _subExpression->injectExpressionContext(getExpressionContext()); - for (auto&& variable : _variables) { - variable.second.expression->injectExpressionContext(getExpressionContext()); - } -} - /* ------------------------- ExpressionMap ----------------------------- */ REGISTER_EXPRESSION(map, ExpressionMap::parse); -intrusive_ptr<Expression> ExpressionMap::parse(BSONElement expr, const VariablesParseState& vpsIn) { +intrusive_ptr<Expression> ExpressionMap::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { verify(str::equals(expr.fieldName(), "$map")); uassert(16878, "$map only supports an object as its argument", expr.type() == Object); @@ -1641,7 +1666,8 @@ intrusive_ptr<Expression> ExpressionMap::parse(BSONElement expr, const Variables uassert(16882, "Missing 'in' parameter to $map", !inElem.eoo()); // parse "input" - intrusive_ptr<Expression> input = parseOperand(inputElem, vpsIn); // only has outer vars + intrusive_ptr<Expression> input = + parseOperand(expCtx, inputElem, vpsIn); // only has outer vars // parse "as" VariablesParseState vpsSub(vpsIn); // vpsSub gets our vars, vpsIn doesn't. @@ -1653,16 +1679,18 @@ intrusive_ptr<Expression> ExpressionMap::parse(BSONElement expr, const Variables Variables::Id varId = vpsSub.defineVariable(varName); // parse "in" - intrusive_ptr<Expression> in = parseOperand(inElem, vpsSub); // has access to map variable + intrusive_ptr<Expression> in = + parseOperand(expCtx, inElem, vpsSub); // has access to map variable - return new ExpressionMap(varName, varId, input, in); + return new ExpressionMap(expCtx, varName, varId, input, in); } -ExpressionMap::ExpressionMap(const string& varName, +ExpressionMap::ExpressionMap(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const string& varName, Variables::Id varId, intrusive_ptr<Expression> input, intrusive_ptr<Expression> each) - : _varName(varName), _varId(varId), _input(input), _each(each) {} + : Expression(expCtx), _varName(varName), _varId(varId), _input(input), _each(each) {} intrusive_ptr<Expression> ExpressionMap::optimize() { // TODO handle when _input is constant @@ -1711,27 +1739,26 @@ void ExpressionMap::addDependencies(DepsTracker* deps) const { _each->addDependencies(deps); } -void ExpressionMap::doInjectExpressionContext() { - _input->injectExpressionContext(getExpressionContext()); - _each->injectExpressionContext(getExpressionContext()); -} - /* ------------------------- ExpressionMeta ----------------------------- */ REGISTER_EXPRESSION(meta, ExpressionMeta::parse); -intrusive_ptr<Expression> ExpressionMeta::parse(BSONElement expr, - const VariablesParseState& vpsIn) { +intrusive_ptr<Expression> ExpressionMeta::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { uassert(17307, "$meta only supports string arguments", expr.type() == String); if (expr.valueStringData() == "textScore") { - return new ExpressionMeta(MetaType::TEXT_SCORE); + return new ExpressionMeta(expCtx, MetaType::TEXT_SCORE); } else if (expr.valueStringData() == "randVal") { - return new ExpressionMeta(MetaType::RAND_VAL); + return new ExpressionMeta(expCtx, MetaType::RAND_VAL); } else { uasserted(17308, "Unsupported argument to $meta: " + expr.String()); } } -ExpressionMeta::ExpressionMeta(MetaType metaType) : _metaType(metaType) {} +ExpressionMeta::ExpressionMeta(const boost::intrusive_ptr<ExpressionContext>& expCtx, + MetaType metaType) + : Expression(expCtx), _metaType(metaType) {} Value ExpressionMeta::serialize(bool explain) const { switch (_metaType) { @@ -2399,12 +2426,6 @@ Value ExpressionNary::serialize(bool explain) const { return Value(DOC(getOpName() << array)); } -void ExpressionNary::doInjectExpressionContext() { - for (auto&& operand : vpOperand) { - operand->injectExpressionContext(getExpressionContext()); - } -} - /* ------------------------- ExpressionNot ----------------------------- */ Value ExpressionNot::evaluateInternal(Variables* vars) const { @@ -2491,8 +2512,9 @@ const char* ExpressionOr::getOpName() const { /* ----------------------- ExpressionPow ---------------------------- */ -intrusive_ptr<Expression> ExpressionPow::create(Value base, Value exp) { - intrusive_ptr<ExpressionPow> expr(new ExpressionPow()); +intrusive_ptr<Expression> ExpressionPow::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, Value base, Value exp) { + intrusive_ptr<ExpressionPow> expr(new ExpressionPow(expCtx)); expr->vpOperand.push_back( ExpressionConstant::create(expr->getExpressionContext(), std::move(base))); expr->vpOperand.push_back( @@ -2723,14 +2745,16 @@ const char* ExpressionRange::getOpName() const { /* ------------------------ ExpressionReduce ------------------------------ */ REGISTER_EXPRESSION(reduce, ExpressionReduce::parse); -intrusive_ptr<Expression> ExpressionReduce::parse(BSONElement expr, - const VariablesParseState& vps) { +intrusive_ptr<Expression> ExpressionReduce::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps) { uassert(40075, str::stream() << "$reduce requires an object as an argument, found: " << typeName(expr.type()), expr.type() == Object); - intrusive_ptr<ExpressionReduce> reduce(new ExpressionReduce()); + intrusive_ptr<ExpressionReduce> reduce(new ExpressionReduce(expCtx)); // vpsSub is used only to parse 'in', which must have access to $$this and $$value. VariablesParseState vpsSub(vps); @@ -2741,11 +2765,11 @@ intrusive_ptr<Expression> ExpressionReduce::parse(BSONElement expr, auto field = elem.fieldNameStringData(); if (field == "input") { - reduce->_input = parseOperand(elem, vps); + reduce->_input = parseOperand(expCtx, elem, vps); } else if (field == "initialValue") { - reduce->_initial = parseOperand(elem, vps); + reduce->_initial = parseOperand(expCtx, elem, vps); } else if (field == "in") { - reduce->_in = parseOperand(elem, vpsSub); + reduce->_in = parseOperand(expCtx, elem, vpsSub); } else { uasserted(40076, str::stream() << "$reduce found an unknown argument: " << field); } @@ -2802,12 +2826,6 @@ Value ExpressionReduce::serialize(bool explain) const { {"in", _in->serialize(explain)}}}}); } -void ExpressionReduce::doInjectExpressionContext() { - _input->injectExpressionContext(getExpressionContext()); - _initial->injectExpressionContext(getExpressionContext()); - _in->injectExpressionContext(getExpressionContext()); -} - /* ------------------------ ExpressionReverseArray ------------------------ */ Value ExpressionReverseArray::evaluateInternal(Variables* vars) const { @@ -3031,8 +3049,10 @@ Value ExpressionSetIsSubset::evaluateInternal(Variables* vars) const { */ class ExpressionSetIsSubset::Optimized : public ExpressionSetIsSubset { public: - Optimized(const ValueSet& cachedRhsSet, const ExpressionVector& operands) - : _cachedRhsSet(cachedRhsSet) { + Optimized(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const ValueSet& cachedRhsSet, + const ExpressionVector& operands) + : ExpressionSetIsSubset(expCtx), _cachedRhsSet(cachedRhsSet) { vpOperand = operands; } @@ -3068,9 +3088,10 @@ intrusive_ptr<Expression> ExpressionSetIsSubset::optimize() { << typeName(rhs.getType()), rhs.isArray()); - intrusive_ptr<Expression> optimizedWithConstant(new Optimized( - arrayToSet(rhs, getExpressionContext()->getValueComparator()), vpOperand)); - optimizedWithConstant->injectExpressionContext(getExpressionContext()); + intrusive_ptr<Expression> optimizedWithConstant( + new Optimized(this->getExpressionContext(), + arrayToSet(rhs, getExpressionContext()->getValueComparator()), + vpOperand)); return optimizedWithConstant; } return optimized; @@ -3579,14 +3600,16 @@ Value ExpressionSwitch::evaluateInternal(Variables* vars) const { return _default->evaluateInternal(vars); } -boost::intrusive_ptr<Expression> ExpressionSwitch::parse(BSONElement expr, - const VariablesParseState& vps) { +boost::intrusive_ptr<Expression> ExpressionSwitch::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps) { uassert(40060, str::stream() << "$switch requires an object as an argument, found: " << typeName(expr.type()), expr.type() == Object); - intrusive_ptr<ExpressionSwitch> expression(new ExpressionSwitch()); + intrusive_ptr<ExpressionSwitch> expression(new ExpressionSwitch(expCtx)); for (auto&& elem : expr.Obj()) { auto field = elem.fieldNameStringData(); @@ -3610,9 +3633,9 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::parse(BSONElement expr, auto branchField = branchElement.fieldNameStringData(); if (branchField == "case") { - branchExpression.first = parseOperand(branchElement, vps); + branchExpression.first = parseOperand(expCtx, branchElement, vps); } else if (branchField == "then") { - branchExpression.second = parseOperand(branchElement, vps); + branchExpression.second = parseOperand(expCtx, branchElement, vps); } else { uasserted(40063, str::stream() << "$switch found an unknown argument to a branch: " @@ -3631,7 +3654,7 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::parse(BSONElement expr, } } else if (field == "default") { // Optional, arbitrary expression. - expression->_default = parseOperand(elem, vps); + expression->_default = parseOperand(expCtx, elem, vps); } else { uasserted(40067, str::stream() << "$switch found an unknown argument: " << field); } @@ -3686,17 +3709,6 @@ Value ExpressionSwitch::serialize(bool explain) const { return Value(Document{{"$switch", Document{{"branches", Value(serializedBranches)}}}}); } -void ExpressionSwitch::doInjectExpressionContext() { - if (_default) { - _default->injectExpressionContext(getExpressionContext()); - } - - for (auto&& pair : _branches) { - pair.first->injectExpressionContext(getExpressionContext()); - pair.second->injectExpressionContext(getExpressionContext()); - } -} - /* ------------------------- ExpressionToLower ----------------------------- */ Value ExpressionToLower::evaluateInternal(Variables* vars) const { @@ -3956,13 +3968,16 @@ const char* ExpressionYear::getOpName() const { /* -------------------------- ExpressionZip ------------------------------ */ REGISTER_EXPRESSION(zip, ExpressionZip::parse); -intrusive_ptr<Expression> ExpressionZip::parse(BSONElement expr, const VariablesParseState& vps) { +intrusive_ptr<Expression> ExpressionZip::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps) { uassert(34460, str::stream() << "$zip only supports an object as an argument, found " << typeName(expr.type()), expr.type() == Object); - intrusive_ptr<ExpressionZip> newZip(new ExpressionZip()); + intrusive_ptr<ExpressionZip> newZip(new ExpressionZip(expCtx)); for (auto&& elem : expr.Obj()) { const auto field = elem.fieldNameStringData(); @@ -3972,7 +3987,7 @@ intrusive_ptr<Expression> ExpressionZip::parse(BSONElement expr, const Variables << typeName(elem.type()), elem.type() == Array); for (auto&& subExpr : elem.Array()) { - newZip->_inputs.push_back(parseOperand(subExpr, vps)); + newZip->_inputs.push_back(parseOperand(expCtx, subExpr, vps)); } } else if (field == "defaults") { uassert(34462, @@ -3980,7 +3995,7 @@ intrusive_ptr<Expression> ExpressionZip::parse(BSONElement expr, const Variables << typeName(elem.type()), elem.type() == Array); for (auto&& subExpr : elem.Array()) { - newZip->_defaults.push_back(parseOperand(subExpr, vps)); + newZip->_defaults.push_back(parseOperand(expCtx, subExpr, vps)); } } else if (field == "useLongestLength") { uassert(34463, @@ -4123,14 +4138,4 @@ void ExpressionZip::addDependencies(DepsTracker* deps) const { }); } -void ExpressionZip::doInjectExpressionContext() { - for (auto&& expr : _inputs) { - expr->injectExpressionContext(getExpressionContext()); - } - - for (auto&& expr : _defaults) { - expr->injectExpressionContext(getExpressionContext()); - } -} - } // namespace mongo diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 00e771784c2..ef9a3ae3f19 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -172,8 +172,8 @@ private: class Expression : public IntrusiveCounterUnsigned { public: - using Parser = - stdx::function<boost::intrusive_ptr<Expression>(BSONElement, const VariablesParseState&)>; + using Parser = stdx::function<boost::intrusive_ptr<Expression>( + const boost::intrusive_ptr<ExpressionContext>&, BSONElement, const VariablesParseState&)>; virtual ~Expression(){}; @@ -235,8 +235,10 @@ public: * Calls parseExpression() on any sub-document (including possibly the entire document) which * consists of a single field name starting with a '$'. */ - static boost::intrusive_ptr<Expression> parseObject(BSONObj obj, - const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parseObject( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps); /** * Parses a BSONObj which has already been determined to be a functional expression. @@ -244,8 +246,10 @@ public: * Throws an error if 'obj' does not contain exactly one field, or if that field's name does not * match a registered expression name. */ - static boost::intrusive_ptr<Expression> parseExpression(BSONObj obj, - const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parseExpression( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps); /** * Parses a BSONElement which is an argument to an Expression. @@ -254,8 +258,10 @@ public: * parseObject(), ExpressionFieldPath::parse(), ExpressionArray::parse(), or * ExpressionConstant::parse() as necessary. */ - static boost::intrusive_ptr<Expression> parseOperand(BSONElement exprElement, - const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parseOperand( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement exprElement, + const VariablesParseState& vps); /* Produce a field path std::string with the field prefix removed. @@ -282,24 +288,10 @@ public: */ static void registerExpression(std::string key, Parser parser); - /** - * Injects the ExpressionContext so that it may be used during evaluation of the Expression. - * Construction of expressions is done at parse time, but the ExpressionContext isn't finalized - * until later, at which point it is injected using this method. - */ - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { - _expCtx = expCtx; - doInjectExpressionContext(); - } - protected: - typedef std::vector<boost::intrusive_ptr<Expression>> ExpressionVector; + Expression(const boost::intrusive_ptr<ExpressionContext>& expCtx) : _expCtx(expCtx) {} - /** - * Expressions which need to update their internal state when attaching to a new - * ExpressionContext should override this method. - */ - virtual void doInjectExpressionContext() {} + typedef std::vector<boost::intrusive_ptr<Expression>> ExpressionVector; const boost::intrusive_ptr<ExpressionContext>& getExpressionContext() const { return _expCtx; @@ -344,12 +336,13 @@ public: /// Allow subclasses the opportunity to validate arguments at parse time. virtual void validateArguments(const ExpressionVector& args) const {} - static ExpressionVector parseArguments(BSONElement bsonExpr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static ExpressionVector parseArguments(const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps); protected: - ExpressionNary() {} + explicit ExpressionNary(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Expression(expCtx) {} ExpressionVector vpOperand; }; @@ -358,19 +351,29 @@ protected: template <typename SubClass> class ExpressionNaryBase : public ExpressionNary { public: - static boost::intrusive_ptr<Expression> parse(BSONElement bsonExpr, - const VariablesParseState& vps) { - boost::intrusive_ptr<ExpressionNaryBase> expr = new SubClass(); - ExpressionVector args = parseArguments(bsonExpr, vps); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps) { + boost::intrusive_ptr<ExpressionNaryBase> expr = new SubClass(expCtx); + ExpressionVector args = parseArguments(expCtx, bsonExpr, vps); expr->validateArguments(args); expr->vpOperand = args; return expr; } + +protected: + explicit ExpressionNaryBase(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionNary(expCtx) {} }; /// Inherit from this class if your expression takes a variable number of arguments. template <typename SubClass> -class ExpressionVariadic : public ExpressionNaryBase<SubClass> {}; +class ExpressionVariadic : public ExpressionNaryBase<SubClass> { +public: + explicit ExpressionVariadic(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionNaryBase<SubClass>(expCtx) {} +}; /** * Inherit from this class if your expression can take a range of arguments, e.g. if it has some @@ -379,6 +382,9 @@ class ExpressionVariadic : public ExpressionNaryBase<SubClass> {}; template <typename SubClass, int MinArgs, int MaxArgs> class ExpressionRangedArity : public ExpressionNaryBase<SubClass> { public: + explicit ExpressionRangedArity(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionNaryBase<SubClass>(expCtx) {} + void validateArguments(const Expression::ExpressionVector& args) const override { uassert(28667, mongoutils::str::stream() << "Expression " << this->getOpName() @@ -397,6 +403,9 @@ public: template <typename SubClass, int NArgs> class ExpressionFixedArity : public ExpressionNaryBase<SubClass> { public: + explicit ExpressionFixedArity(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionNaryBase<SubClass>(expCtx) {} + void validateArguments(const Expression::ExpressionVector& args) const override { uassert(16020, mongoutils::str::stream() << "Expression " << this->getOpName() << " takes exactly " @@ -416,9 +425,11 @@ template <typename Accumulator> class ExpressionFromAccumulator : public ExpressionVariadic<ExpressionFromAccumulator<Accumulator>> { public: + explicit ExpressionFromAccumulator(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionFromAccumulator<Accumulator>>(expCtx) {} + Value evaluateInternal(Variables* vars) const final { - Accumulator accum; - accum.injectExpressionContext(this->getExpressionContext()); + Accumulator accum(this->getExpressionContext()); const size_t n = this->vpOperand.size(); // If a single array arg is given, loop through it passing each member to the accumulator. // If a single, non-array arg is given, pass it directly to the accumulator. @@ -446,15 +457,15 @@ public: if (this->vpOperand.size() == 1) { return false; } - return Accumulator().isAssociative(); + return Accumulator(this->getExpressionContext()).isAssociative(); } bool isCommutative() const final { - return Accumulator().isCommutative(); + return Accumulator(this->getExpressionContext()).isCommutative(); } const char* getOpName() const final { - return Accumulator().getOpName(); + return Accumulator(this->getExpressionContext()).getOpName(); } }; @@ -464,6 +475,9 @@ public: template <typename SubClass> class ExpressionSingleNumericArg : public ExpressionFixedArity<SubClass, 1> { public: + explicit ExpressionSingleNumericArg(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<SubClass, 1>(expCtx) {} + virtual ~ExpressionSingleNumericArg() {} Value evaluateInternal(Variables* vars) const final { @@ -484,6 +498,10 @@ public: class ExpressionAbs final : public ExpressionSingleNumericArg<ExpressionAbs> { +public: + explicit ExpressionAbs(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionAbs>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -491,6 +509,9 @@ class ExpressionAbs final : public ExpressionSingleNumericArg<ExpressionAbs> { class ExpressionAdd final : public ExpressionVariadic<ExpressionAdd> { public: + explicit ExpressionAdd(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionAdd>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -506,6 +527,9 @@ public: class ExpressionAllElementsTrue final : public ExpressionFixedArity<ExpressionAllElementsTrue, 1> { public: + explicit ExpressionAllElementsTrue(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionAllElementsTrue, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -513,6 +537,9 @@ public: class ExpressionAnd final : public ExpressionVariadic<ExpressionAnd> { public: + explicit ExpressionAnd(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionAnd>(expCtx) {} + boost::intrusive_ptr<Expression> optimize() final; Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -529,6 +556,9 @@ public: class ExpressionAnyElementTrue final : public ExpressionFixedArity<ExpressionAnyElementTrue, 1> { public: + explicit ExpressionAnyElementTrue(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionAnyElementTrue, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -536,7 +566,9 @@ public: class ExpressionArray final : public ExpressionVariadic<ExpressionArray> { public: - // virtuals from ExpressionNary + explicit ExpressionArray(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionArray>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; Value serialize(bool explain) const final; const char* getOpName() const final; @@ -545,6 +577,9 @@ public: class ExpressionArrayElemAt final : public ExpressionFixedArity<ExpressionArrayElemAt, 2> { public: + explicit ExpressionArrayElemAt(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionArrayElemAt, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -552,6 +587,9 @@ public: class ExpressionCeil final : public ExpressionSingleNumericArg<ExpressionCeil> { public: + explicit ExpressionCeil(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionCeil>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -568,10 +606,9 @@ public: const boost::intrusive_ptr<ExpressionContext>& expCtx, const boost::intrusive_ptr<Expression>& pExpression); - void doInjectExpressionContext() final; - private: - explicit ExpressionCoerceToBool(const boost::intrusive_ptr<Expression>& pExpression); + ExpressionCoerceToBool(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const boost::intrusive_ptr<Expression>& pExpression); boost::intrusive_ptr<Expression> pExpression; }; @@ -593,16 +630,20 @@ public: CMP = 6, // return -1, 0, 1 for a < b, a == b, a > b }; + ExpressionCompare(const boost::intrusive_ptr<ExpressionContext>& expCtx, CmpOp cmpOp) + : ExpressionFixedArity<ExpressionCompare, 2>(expCtx), cmpOp(cmpOp) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; - static boost::intrusive_ptr<Expression> parse(BSONElement bsonExpr, - const VariablesParseState& vps, - CmpOp cmpOp); - - explicit ExpressionCompare(CmpOp cmpOp); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps, + CmpOp cmpOp); static boost::intrusive_ptr<ExpressionCompare> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, CmpOp cmpOp, const boost::intrusive_ptr<Expression>& exprLeft, const boost::intrusive_ptr<Expression>& exprRight); @@ -614,6 +655,9 @@ private: class ExpressionConcat final : public ExpressionVariadic<ExpressionConcat> { public: + explicit ExpressionConcat(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionConcat>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -625,6 +669,9 @@ public: class ExpressionConcatArrays final : public ExpressionVariadic<ExpressionConcatArrays> { public: + explicit ExpressionConcatArrays(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionConcatArrays>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -635,13 +682,19 @@ public: class ExpressionCond final : public ExpressionFixedArity<ExpressionCond, 3> { - typedef ExpressionFixedArity<ExpressionCond, 3> Base; - public: + explicit ExpressionCond(const boost::intrusive_ptr<ExpressionContext>& expCtx) : Base(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); + +private: + typedef ExpressionFixedArity<ExpressionCond, 3> Base; }; @@ -657,8 +710,10 @@ public: static boost::intrusive_ptr<ExpressionConstant> create( const boost::intrusive_ptr<ExpressionContext>& expCtx, const Value& pValue); - static boost::intrusive_ptr<Expression> parse(BSONElement bsonExpr, - const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps); /* Get the constant value represented by this Expression. @@ -670,7 +725,7 @@ public: } private: - explicit ExpressionConstant(const Value& pValue); + ExpressionConstant(const boost::intrusive_ptr<ExpressionContext>& expCtx, const Value& pValue); Value pValue; }; @@ -682,12 +737,14 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: - ExpressionDateToString(const std::string& format, // the format string + ExpressionDateToString(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const std::string& format, // the format string boost::intrusive_ptr<Expression> date); // the date to format // Will uassert on invalid data @@ -705,6 +762,9 @@ private: class ExpressionDayOfMonth final : public ExpressionFixedArity<ExpressionDayOfMonth, 1> { public: + explicit ExpressionDayOfMonth(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionDayOfMonth, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -716,6 +776,9 @@ public: class ExpressionDayOfWeek final : public ExpressionFixedArity<ExpressionDayOfWeek, 1> { public: + explicit ExpressionDayOfWeek(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionDayOfWeek, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -728,6 +791,9 @@ public: class ExpressionDayOfYear final : public ExpressionFixedArity<ExpressionDayOfYear, 1> { public: + explicit ExpressionDayOfYear(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionDayOfYear, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -740,12 +806,19 @@ public: class ExpressionDivide final : public ExpressionFixedArity<ExpressionDivide, 2> { public: + explicit ExpressionDivide(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionDivide, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionExp final : public ExpressionSingleNumericArg<ExpressionExp> { +public: + explicit ExpressionExp(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionExp>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -771,18 +844,23 @@ public: indicator @returns the newly created field path expression */ - static boost::intrusive_ptr<ExpressionFieldPath> create(const std::string& fieldPath); + static boost::intrusive_ptr<ExpressionFieldPath> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const std::string& fieldPath); /// Like create(), but works with the raw std::string from the user with the "$" prefixes. - static boost::intrusive_ptr<ExpressionFieldPath> parse(const std::string& raw, - const VariablesParseState& vps); + static boost::intrusive_ptr<ExpressionFieldPath> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const std::string& raw, + const VariablesParseState& vps); const FieldPath& getFieldPath() const { return _fieldPath; } private: - ExpressionFieldPath(const std::string& fieldPath, Variables::Id variable); + ExpressionFieldPath(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const std::string& fieldPath, + Variables::Id variable); /* Internal implementation of evaluateInternal(), used recursively. @@ -814,12 +892,14 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: - ExpressionFilter(std::string varName, + ExpressionFilter(const boost::intrusive_ptr<ExpressionContext>& expCtx, + std::string varName, Variables::Id varId, boost::intrusive_ptr<Expression> input, boost::intrusive_ptr<Expression> filter); @@ -837,6 +917,9 @@ private: class ExpressionFloor final : public ExpressionSingleNumericArg<ExpressionFloor> { public: + explicit ExpressionFloor(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionFloor>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -844,6 +927,9 @@ public: class ExpressionHour final : public ExpressionFixedArity<ExpressionHour, 1> { public: + explicit ExpressionHour(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionHour, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -855,6 +941,9 @@ public: class ExpressionIfNull final : public ExpressionFixedArity<ExpressionIfNull, 2> { public: + explicit ExpressionIfNull(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIfNull, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -862,6 +951,9 @@ public: class ExpressionIn final : public ExpressionFixedArity<ExpressionIn, 2> { public: + explicit ExpressionIn(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIn, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -869,6 +961,9 @@ public: class ExpressionIndexOfArray final : public ExpressionRangedArity<ExpressionIndexOfArray, 2, 4> { public: + explicit ExpressionIndexOfArray(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionRangedArity<ExpressionIndexOfArray, 2, 4>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -876,6 +971,9 @@ public: class ExpressionIndexOfBytes final : public ExpressionRangedArity<ExpressionIndexOfBytes, 2, 4> { public: + explicit ExpressionIndexOfBytes(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionRangedArity<ExpressionIndexOfBytes, 2, 4>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -886,6 +984,9 @@ public: */ class ExpressionIndexOfCP final : public ExpressionRangedArity<ExpressionIndexOfCP, 2, 4> { public: + explicit ExpressionIndexOfCP(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionRangedArity<ExpressionIndexOfCP, 2, 4>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -898,9 +999,10 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); struct NameAndExpression { NameAndExpression() {} @@ -914,23 +1016,37 @@ public: typedef std::map<Variables::Id, NameAndExpression> VariableMap; private: - ExpressionLet(const VariableMap& vars, boost::intrusive_ptr<Expression> subExpression); + ExpressionLet(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const VariableMap& vars, + boost::intrusive_ptr<Expression> subExpression); VariableMap _variables; boost::intrusive_ptr<Expression> _subExpression; }; class ExpressionLn final : public ExpressionSingleNumericArg<ExpressionLn> { +public: + explicit ExpressionLn(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionLn>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; class ExpressionLog final : public ExpressionFixedArity<ExpressionLog, 2> { +public: + explicit ExpressionLog(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionLog, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionLog10 final : public ExpressionSingleNumericArg<ExpressionLog10> { +public: + explicit ExpressionLog10(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionLog10>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -942,12 +1058,14 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: ExpressionMap( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const std::string& varName, // name of variable to set Variables::Id varId, // id of variable to set boost::intrusive_ptr<Expression> input, // yields array to iterate @@ -965,7 +1083,10 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: enum MetaType { @@ -973,13 +1094,16 @@ private: RAND_VAL, }; - ExpressionMeta(MetaType metaType); + ExpressionMeta(const boost::intrusive_ptr<ExpressionContext>& expCtx, MetaType metaType); MetaType _metaType; }; class ExpressionMillisecond final : public ExpressionFixedArity<ExpressionMillisecond, 1> { public: + explicit ExpressionMillisecond(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionMillisecond, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -989,6 +1113,9 @@ public: class ExpressionMinute final : public ExpressionFixedArity<ExpressionMinute, 1> { public: + explicit ExpressionMinute(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionMinute, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1000,6 +1127,9 @@ public: class ExpressionMod final : public ExpressionFixedArity<ExpressionMod, 2> { public: + explicit ExpressionMod(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionMod, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1007,6 +1137,9 @@ public: class ExpressionMultiply final : public ExpressionVariadic<ExpressionMultiply> { public: + explicit ExpressionMultiply(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionMultiply>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1022,6 +1155,9 @@ public: class ExpressionMonth final : public ExpressionFixedArity<ExpressionMonth, 1> { public: + explicit ExpressionMonth(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionMonth, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1034,6 +1170,9 @@ public: class ExpressionNot final : public ExpressionFixedArity<ExpressionNot, 1> { public: + explicit ExpressionNot(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionNot, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1055,13 +1194,16 @@ public: Value serialize(bool explain) const final; static boost::intrusive_ptr<ExpressionObject> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, std::vector<std::pair<std::string, boost::intrusive_ptr<Expression>>>&& expressions); /** * Parses and constructs an ExpressionObject from 'obj'. */ - static boost::intrusive_ptr<ExpressionObject> parse(BSONObj obj, - const VariablesParseState& vps); + static boost::intrusive_ptr<ExpressionObject> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps); /** * This ExpressionObject must outlive the returned vector. @@ -1071,10 +1213,9 @@ public: return _expressions; } - void doInjectExpressionContext() final; - private: ExpressionObject( + const boost::intrusive_ptr<ExpressionContext>& expCtx, std::vector<std::pair<std::string, boost::intrusive_ptr<Expression>>>&& expressions); // The mapping from field name to expression within this object. This needs to respect the order @@ -1085,6 +1226,9 @@ private: class ExpressionOr final : public ExpressionVariadic<ExpressionOr> { public: + explicit ExpressionOr(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionOr>(expCtx) {} + boost::intrusive_ptr<Expression> optimize() final; Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1100,7 +1244,11 @@ public: class ExpressionPow final : public ExpressionFixedArity<ExpressionPow, 2> { public: - static boost::intrusive_ptr<Expression> create(Value base, Value exp); + explicit ExpressionPow(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionPow, 2>(expCtx) {} + + static boost::intrusive_ptr<Expression> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, Value base, Value exp); private: Value evaluateInternal(Variables* vars) const final; @@ -1109,6 +1257,10 @@ private: class ExpressionRange final : public ExpressionRangedArity<ExpressionRange, 2, 3> { +public: + explicit ExpressionRange(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionRangedArity<ExpressionRange, 2, 3>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1116,15 +1268,18 @@ class ExpressionRange final : public ExpressionRangedArity<ExpressionRange, 2, 3 class ExpressionReduce final : public Expression { public: + explicit ExpressionReduce(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Expression(expCtx) {} + void addDependencies(DepsTracker* deps) const final; Value evaluateInternal(Variables* vars) const final; boost::intrusive_ptr<Expression> optimize() final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, - const VariablesParseState& vpsIn); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn); Value serialize(bool explain) const final; - void doInjectExpressionContext() final; - private: boost::intrusive_ptr<Expression> _input; boost::intrusive_ptr<Expression> _initial; @@ -1137,6 +1292,9 @@ private: class ExpressionSecond final : public ExpressionFixedArity<ExpressionSecond, 1> { public: + explicit ExpressionSecond(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSecond, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1148,6 +1306,9 @@ public: class ExpressionSetDifference final : public ExpressionFixedArity<ExpressionSetDifference, 2> { public: + explicit ExpressionSetDifference(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSetDifference, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1155,6 +1316,9 @@ public: class ExpressionSetEquals final : public ExpressionVariadic<ExpressionSetEquals> { public: + explicit ExpressionSetEquals(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionSetEquals>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; void validateArguments(const ExpressionVector& args) const final; @@ -1163,6 +1327,9 @@ public: class ExpressionSetIntersection final : public ExpressionVariadic<ExpressionSetIntersection> { public: + explicit ExpressionSetIntersection(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionSetIntersection>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1179,6 +1346,9 @@ public: // Not final, inherited from for optimizations. class ExpressionSetIsSubset : public ExpressionFixedArity<ExpressionSetIsSubset, 2> { public: + explicit ExpressionSetIsSubset(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSetIsSubset, 2>(expCtx) {} + boost::intrusive_ptr<Expression> optimize() override; Value evaluateInternal(Variables* vars) const override; const char* getOpName() const final; @@ -1190,7 +1360,9 @@ private: class ExpressionSetUnion final : public ExpressionVariadic<ExpressionSetUnion> { public: - // intrusive_ptr<Expression> optimize() final; + explicit ExpressionSetUnion(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionSetUnion>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1206,6 +1378,9 @@ public: class ExpressionSize final : public ExpressionFixedArity<ExpressionSize, 1> { public: + explicit ExpressionSize(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSize, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1213,6 +1388,9 @@ public: class ExpressionReverseArray final : public ExpressionFixedArity<ExpressionReverseArray, 1> { public: + explicit ExpressionReverseArray(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionReverseArray, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1220,6 +1398,9 @@ public: class ExpressionSlice final : public ExpressionRangedArity<ExpressionSlice, 2, 3> { public: + explicit ExpressionSlice(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionRangedArity<ExpressionSlice, 2, 3>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1227,6 +1408,9 @@ public: class ExpressionIsArray final : public ExpressionFixedArity<ExpressionIsArray, 1> { public: + explicit ExpressionIsArray(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIsArray, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1234,12 +1418,19 @@ public: class ExpressionSplit final : public ExpressionFixedArity<ExpressionSplit, 2> { public: + explicit ExpressionSplit(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSplit, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionSqrt final : public ExpressionSingleNumericArg<ExpressionSqrt> { +public: + explicit ExpressionSqrt(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionSqrt>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -1247,6 +1438,9 @@ class ExpressionSqrt final : public ExpressionSingleNumericArg<ExpressionSqrt> { class ExpressionStrcasecmp final : public ExpressionFixedArity<ExpressionStrcasecmp, 2> { public: + explicit ExpressionStrcasecmp(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionStrcasecmp, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1254,6 +1448,9 @@ public: class ExpressionSubstrBytes : public ExpressionFixedArity<ExpressionSubstrBytes, 3> { public: + explicit ExpressionSubstrBytes(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSubstrBytes, 3>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const; }; @@ -1261,18 +1458,29 @@ public: class ExpressionSubstrCP final : public ExpressionFixedArity<ExpressionSubstrCP, 3> { public: + explicit ExpressionSubstrCP(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSubstrCP, 3>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionStrLenBytes final : public ExpressionFixedArity<ExpressionStrLenBytes, 1> { +public: + explicit ExpressionStrLenBytes(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionStrLenBytes, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionStrLenCP final : public ExpressionFixedArity<ExpressionStrLenCP, 1> { +public: + explicit ExpressionStrLenCP(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionStrLenCP, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1280,6 +1488,9 @@ class ExpressionStrLenCP final : public ExpressionFixedArity<ExpressionStrLenCP, class ExpressionSubtract final : public ExpressionFixedArity<ExpressionSubtract, 2> { public: + explicit ExpressionSubtract(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSubtract, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1287,15 +1498,18 @@ public: class ExpressionSwitch final : public Expression { public: + explicit ExpressionSwitch(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Expression(expCtx) {} + void addDependencies(DepsTracker* deps) const final; Value evaluateInternal(Variables* vars) const final; boost::intrusive_ptr<Expression> optimize() final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, - const VariablesParseState& vpsIn); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn); Value serialize(bool explain) const final; - void doInjectExpressionContext() final; - private: using ExpressionPair = std::pair<boost::intrusive_ptr<Expression>, boost::intrusive_ptr<Expression>>; @@ -1307,6 +1521,9 @@ private: class ExpressionToLower final : public ExpressionFixedArity<ExpressionToLower, 1> { public: + explicit ExpressionToLower(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionToLower, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1314,6 +1531,9 @@ public: class ExpressionToUpper final : public ExpressionFixedArity<ExpressionToUpper, 1> { public: + explicit ExpressionToUpper(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionToUpper, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1321,6 +1541,9 @@ public: class ExpressionTrunc final : public ExpressionSingleNumericArg<ExpressionTrunc> { public: + explicit ExpressionTrunc(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionTrunc>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -1328,6 +1551,9 @@ public: class ExpressionType final : public ExpressionFixedArity<ExpressionType, 1> { public: + explicit ExpressionType(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionType, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1335,6 +1561,9 @@ public: class ExpressionWeek final : public ExpressionFixedArity<ExpressionWeek, 1> { public: + explicit ExpressionWeek(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionWeek, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1344,6 +1573,9 @@ public: class ExpressionIsoWeekYear final : public ExpressionFixedArity<ExpressionIsoWeekYear, 1> { public: + explicit ExpressionIsoWeekYear(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIsoWeekYear, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1353,6 +1585,9 @@ public: class ExpressionIsoDayOfWeek final : public ExpressionFixedArity<ExpressionIsoDayOfWeek, 1> { public: + explicit ExpressionIsoDayOfWeek(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIsoDayOfWeek, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1362,6 +1597,9 @@ public: class ExpressionIsoWeek final : public ExpressionFixedArity<ExpressionIsoWeek, 1> { public: + explicit ExpressionIsoWeek(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIsoWeek, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1371,6 +1609,9 @@ public: class ExpressionYear final : public ExpressionFixedArity<ExpressionYear, 1> { public: + explicit ExpressionYear(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionYear, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1383,15 +1624,18 @@ public: class ExpressionZip final : public Expression { public: + explicit ExpressionZip(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Expression(expCtx) {} + void addDependencies(DepsTracker* deps) const final; Value evaluateInternal(Variables* vars) const final; boost::intrusive_ptr<Expression> optimize() final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, - const VariablesParseState& vpsIn); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn); Value serialize(bool explain) const final; - void doInjectExpressionContext() final; - private: bool _useLongestLength = false; ExpressionVector _inputs; diff --git a/src/mongo/db/pipeline/expression_context.cpp b/src/mongo/db/pipeline/expression_context.cpp index 865bf08ec6c..5bccef201ae 100644 --- a/src/mongo/db/pipeline/expression_context.cpp +++ b/src/mongo/db/pipeline/expression_context.cpp @@ -39,27 +39,27 @@ ExpressionContext::ResolvedNamespace::ResolvedNamespace(NamespaceString ns, std::vector<BSONObj> pipeline) : ns(std::move(ns)), pipeline(std::move(pipeline)) {} -ExpressionContext::ExpressionContext(OperationContext* opCtx, const AggregationRequest& request) +ExpressionContext::ExpressionContext(OperationContext* opCtx, + const AggregationRequest& request, + std::unique_ptr<CollatorInterface> collator, + StringMap<ResolvedNamespace> resolvedNamespaces) : isExplain(request.isExplain()), inShard(request.isFromRouter()), extSortAllowed(request.shouldAllowDiskUse()), bypassDocumentValidation(request.shouldBypassDocumentValidation()), ns(request.getNamespaceString()), opCtx(opCtx), - collation(request.getCollation()) { - if (!collation.isEmpty()) { - auto statusWithCollator = - CollatorFactoryInterface::get(opCtx->getServiceContext())->makeFromBSON(collation); - uassertStatusOK(statusWithCollator.getStatus()); - setCollator(std::move(statusWithCollator.getValue())); - } -} + collation(request.getCollation()), + _collator(std::move(collator)), + _documentComparator(_collator.get()), + _valueComparator(_collator.get()), + _resolvedNamespaces(std::move(resolvedNamespaces)) {} void ExpressionContext::checkForInterrupt() { // This check could be expensive, at least in relative terms, so don't check every time. - if (--interruptCounter == 0) { + if (--_interruptCounter == 0) { opCtx->checkForInterrupt(); - interruptCounter = kInterruptCheckPeriod; + _interruptCounter = kInterruptCheckPeriod; } } @@ -90,9 +90,9 @@ intrusive_ptr<ExpressionContext> ExpressionContext::copyWith(NamespaceString ns) expCtx->setCollator(_collator->clone()); } - expCtx->resolvedNamespaces = resolvedNamespaces; + expCtx->_resolvedNamespaces = _resolvedNamespaces; - // Note that we intentionally skip copying the value of 'interruptCounter' because 'expCtx' is + // Note that we intentionally skip copying the value of '_interruptCounter' because 'expCtx' is // intended to be used for executing a separate aggregation pipeline. return expCtx; diff --git a/src/mongo/db/pipeline/expression_context.h b/src/mongo/db/pipeline/expression_context.h index 67114fa52b0..af6d7d199c9 100644 --- a/src/mongo/db/pipeline/expression_context.h +++ b/src/mongo/db/pipeline/expression_context.h @@ -45,7 +45,7 @@ namespace mongo { -struct ExpressionContext : public IntrusiveCounterUnsigned { +class ExpressionContext : public RefCountable { public: struct ResolvedNamespace { ResolvedNamespace() = default; @@ -55,9 +55,14 @@ public: std::vector<BSONObj> pipeline; }; - ExpressionContext() = default; - - ExpressionContext(OperationContext* opCtx, const AggregationRequest& request); + /** + * Constructs an ExpressionContext to be used for Pipeline parsing and evaluation. + * 'resolvedNamespaces' maps collection names (not full namespaces) to ResolvedNamespaces. + */ + ExpressionContext(OperationContext* opCtx, + const AggregationRequest& request, + std::unique_ptr<CollatorInterface> collator, + StringMap<ExpressionContext::ResolvedNamespace> resolvedNamespaces); /** * Used by a pipeline to check for interrupts so that killOp() works. Throws a UserAssertion if @@ -65,8 +70,6 @@ public: */ void checkForInterrupt(); - void setCollator(std::unique_ptr<CollatorInterface> coll); - const CollatorInterface* getCollator() const { return _collator.get(); } @@ -85,6 +88,16 @@ public: */ boost::intrusive_ptr<ExpressionContext> copyWith(NamespaceString ns) const; + /** + * Returns the ResolvedNamespace corresponding to 'nss'. It is an error to call this method on a + * namespace not involved in the pipeline. + */ + const ResolvedNamespace& getResolvedNamespace(const NamespaceString& nss) const { + auto it = _resolvedNamespaces.find(nss.coll()); + invariant(it != _resolvedNamespaces.end()); + return it->second; + }; + bool isExplain = false; bool inShard = false; bool inRouter = false; @@ -100,19 +113,34 @@ public: // collation. BSONObj collation; - StringMap<ResolvedNamespace> resolvedNamespaces; - +protected: static const int kInterruptCheckPeriod = 128; - int interruptCounter = kInterruptCheckPeriod; // when 0, check interruptStatus -private: - // Collator used to compare elements. 'collator' is initialized from 'collation', except in the - // case where 'collation' is empty and there is a collection default collation. + /** + * Should only be used by 'ExpressionContextForTest'. + */ + ExpressionContext() = default; + + /** + * Sets '_collator' and resets '_documentComparator' and '_valueComparator'. + * + * Use with caution - it is illegal to change the collation once a Pipeline has been parsed with + * this ExpressionContext. + */ + void setCollator(std::unique_ptr<CollatorInterface> collator); + + // Collator used for comparisons. std::unique_ptr<CollatorInterface> _collator; // Used for all comparisons of Document/Value during execution of the aggregation operation. + // Must not be changed after parsing a Pipeline with this ExpressionContext. DocumentComparator _documentComparator; ValueComparator _valueComparator; + + // A map from namespace to the resolved namespace, in case any views are involved. + StringMap<ResolvedNamespace> _resolvedNamespaces; + + int _interruptCounter = kInterruptCheckPeriod; }; } // namespace mongo diff --git a/src/mongo/db/pipeline/expression_context_for_test.h b/src/mongo/db/pipeline/expression_context_for_test.h new file mode 100644 index 00000000000..d093cfe2bfa --- /dev/null +++ b/src/mongo/db/pipeline/expression_context_for_test.h @@ -0,0 +1,64 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects + * for all of the code used other than as permitted herein. If you modify + * file(s) with this exception, you may extend this exception to your + * version of the file(s), but you are not obligated to do so. If you do not + * wish to do so, delete this exception statement from your version. If you + * delete this exception statement from all source files in the program, + * then also delete it in the license file. + */ + +#pragma once + +#include "mongo/db/pipeline/expression_context.h" + +namespace mongo { + +/** + * An ExpressionContext that can have state like the collation and resolved namespace map + * manipulated after construction. In contrast, a regular ExpressionContext requires the collation + * and resolved namespaces to be provided on construction and does not allow them to be subsequently + * mutated. + */ +class ExpressionContextForTest : public ExpressionContext { +public: + ExpressionContextForTest() = default; + + ExpressionContextForTest(OperationContext* txn, const AggregationRequest& request) + : ExpressionContext(txn, request, nullptr, {}) {} + + /** + * Changes the collation used by this ExpressionContext. Must not be changed after parsing a + * Pipeline with this ExpressionContext. + */ + void setCollator(std::unique_ptr<CollatorInterface> collator) { + ExpressionContext::setCollator(std::move(collator)); + } + + /** + * Sets the resolved definition for an involved namespace. + */ + void setResolvedNamespace(const NamespaceString& nss, ResolvedNamespace resolvedNamespace) { + _resolvedNamespaces[nss.coll()] = std::move(resolvedNamespace); + } +}; + +} // namespace mongo diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 13e96dfef0c..315a2faeac7 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -36,7 +36,7 @@ #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/expression.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value_comparator.h" #include "mongo/dbtests/dbtests.h" #include "mongo/unittest/unittest.h" @@ -62,12 +62,11 @@ static void assertExpectedResults(string expression, initializer_list<pair<vector<Value>, Value>> operations) { for (auto&& op : operations) { try { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const BSONObj obj = BSON(expression << Value(op.first)); - auto expression = Expression::parseExpression(obj, vps); - expression->injectExpressionContext(expCtx); + auto expression = Expression::parseExpression(expCtx, obj, vps); Value result = expression->evaluate(Document()); ASSERT_VALUE_EQ(op.second, result); ASSERT_EQUALS(op.second.getType(), result.getType()); @@ -190,7 +189,10 @@ public: private: Testable(bool isAssociative, bool isCommutative) - : _isAssociative(isAssociative), _isCommutative(isCommutative) {} + : ExpressionNary( + boost::intrusive_ptr<ExpressionContextForTest>(new ExpressionContextForTest())), + _isAssociative(isAssociative), + _isCommutative(isCommutative) {} bool _isAssociative; bool _isCommutative; }; @@ -224,12 +226,13 @@ protected: } void addOperandArrayToExpr(const intrusive_ptr<Testable>& expr, const BSONArray& operands) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); BSONObjIterator i(operands); while (i.more()) { BSONElement element = i.next(); - expr->addOperand(Expression::parseOperand(element, vps)); + expr->addOperand(Expression::parseOperand(expCtx, element, vps)); } } @@ -244,7 +247,7 @@ TEST_F(ExpressionNaryTest, AddedConstantOperandIsSerialized) { } TEST_F(ExpressionNaryTest, AddedFieldPathOperandIsSerialized) { - _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create("ab.c")); + _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create(nullptr, "ab.c")); assertContents(_notAssociativeNorCommutative, BSON_ARRAY("$ab.c")); } @@ -258,7 +261,7 @@ TEST_F(ExpressionNaryTest, ValidateConstantExpressionDependency) { } TEST_F(ExpressionNaryTest, ValidateFieldPathExpressionDependency) { - _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create("ab.c")); + _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create(nullptr, "ab.c")); assertDependencies(_notAssociativeNorCommutative, BSON_ARRAY("ab.c")); } @@ -267,10 +270,12 @@ TEST_F(ExpressionNaryTest, ValidateObjectExpressionDependency) { << "$x" << "q" << "$r")); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONElement specElement = spec.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - _notAssociativeNorCommutative->addOperand(Expression::parseObject(specElement.Obj(), vps)); + _notAssociativeNorCommutative->addOperand( + Expression::parseObject(expCtx, specElement.Obj(), vps)); assertDependencies(_notAssociativeNorCommutative, BSON_ARRAY("r" << "x")); @@ -683,7 +688,8 @@ TEST_F(ExpressionNaryTest, FlattenInnerOperandsOptimizationOnCommutativeAndAssoc class ExpressionCeilTest : public ExpressionNaryTestOneArg { public: virtual void assertEvaluates(Value input, Value output) override { - _expr = new ExpressionCeil(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + _expr = new ExpressionCeil(expCtx); ExpressionNaryTestOneArg::assertEvaluates(input, output); } }; @@ -741,7 +747,8 @@ TEST_F(ExpressionCeilTest, NullArg) { class ExpressionFloorTest : public ExpressionNaryTestOneArg { public: virtual void assertEvaluates(Value input, Value output) override { - _expr = new ExpressionFloor(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + _expr = new ExpressionFloor(expCtx); ExpressionNaryTestOneArg::assertEvaluates(input, output); } }; @@ -856,7 +863,8 @@ TEST(ExpressionReverseArrayTest, ReturnsNullWithNullishInput) { class ExpressionTruncTest : public ExpressionNaryTestOneArg { public: virtual void assertEvaluates(Value input, Value output) override { - _expr = new ExpressionTrunc(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + _expr = new ExpressionTrunc(expCtx); ExpressionNaryTestOneArg::assertEvaluates(input, output); } }; @@ -917,7 +925,8 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(expCtx); populateOperands(expression); ASSERT_BSONOBJ_EQ(expectedResult(), toBson(expression->evaluate(Document()))); } @@ -932,7 +941,8 @@ protected: class NullDocument { public: void run() { - intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(expCtx); expression->addOperand(ExpressionConstant::create(nullptr, Value(2))); ASSERT_BSONOBJ_EQ(BSON("" << 2), toBson(expression->evaluate(Document()))); } @@ -950,7 +960,8 @@ class NoOperands : public ExpectedResultBase { class String { public: void run() { - intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(expCtx); expression->addOperand(ExpressionConstant::create(nullptr, Value("a"_sd))); ASSERT_THROWS(expression->evaluate(Document()), UserException); } @@ -960,7 +971,8 @@ public: class Bool { public: void run() { - intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(expCtx); expression->addOperand(ExpressionConstant::create(nullptr, Value(true))); ASSERT_THROWS(expression->evaluate(Document()), UserException); } @@ -1193,13 +1205,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(fromBson(BSON("a" << 1))))); @@ -1217,13 +1228,12 @@ class OptimizeBase { public: virtual ~OptimizeBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); intrusive_ptr<Expression> optimized = expression->optimize(); ASSERT_BSONOBJ_EQ(expectedOptimized(), expressionToBson(optimized)); @@ -1515,7 +1525,7 @@ public: class Dependencies { public: void run() { - intrusive_ptr<Expression> nested = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> nested = ExpressionFieldPath::create(nullptr, "a.b"); intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nullptr, nested); DepsTracker dependencies; expression->addDependencies(&dependencies); @@ -1531,7 +1541,7 @@ class AddToBsonObj { public: void run() { intrusive_ptr<Expression> expression = - ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create("foo")); + ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create(nullptr, "foo")); // serialized as $and because CoerceToBool isn't an ExpressionNary assertBinaryEqual(fromjson("{field:{$and:['$foo']}}"), toBsonObj(expression)); @@ -1548,7 +1558,7 @@ class AddToBsonArray { public: void run() { intrusive_ptr<Expression> expression = - ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create("foo")); + ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create(nullptr, "foo")); // serialized as $and because CoerceToBool isn't an ExpressionNary assertBinaryEqual(BSON_ARRAY(fromjson("{$and:['$foo']}")), toBsonArray(expression)); @@ -1577,9 +1587,8 @@ public: BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); intrusive_ptr<Expression> optimized = expression->optimize(); ASSERT_BSONOBJ_EQ(constify(expectedOptimized()), expressionToBson(optimized)); } @@ -1610,9 +1619,8 @@ public: BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); // Check expression spec round trip. ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); // Check evaluation result. @@ -1648,11 +1656,12 @@ class ParseError { public: virtual ~ParseError() {} void run() { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - ASSERT_THROWS(Expression::parseOperand(specElement, vps), UserException); + ASSERT_THROWS(Expression::parseOperand(expCtx, specElement, vps), UserException); } protected: @@ -1855,9 +1864,8 @@ public: BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_VALUE_EQ(expression->evaluate(Document()), Value(true)); } }; @@ -2000,10 +2008,11 @@ public: void run() { BSONObj spec = BSON("IGNORED_FIELD_NAME" << "foo"); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONElement specElement = spec.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = ExpressionConstant::parse(specElement, vps); + intrusive_ptr<Expression> expression = ExpressionConstant::parse(expCtx, specElement, vps); assertBinaryEqual(BSON("" << "foo"), toBson(expression->evaluate(Document()))); @@ -2140,7 +2149,7 @@ namespace FieldPath { class Invalid { public: void run() { - ASSERT_THROWS(ExpressionFieldPath::create(""), UserException); + ASSERT_THROWS(ExpressionFieldPath::create(nullptr, ""), UserException); } }; @@ -2148,7 +2157,7 @@ public: class Optimize { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a"); // An attempt to optimize returns the Expression itself. ASSERT_EQUALS(expression, expression->optimize()); } @@ -2158,7 +2167,7 @@ public: class Dependencies { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); DepsTracker dependencies; expression->addDependencies(&dependencies); ASSERT_EQUALS(1U, dependencies.fields.size()); @@ -2172,7 +2181,7 @@ public: class Missing { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(Document()))); } }; @@ -2181,7 +2190,7 @@ public: class Present { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a"); assertBinaryEqual(fromjson("{'':123}"), toBson(expression->evaluate(fromBson(BSON("a" << 123))))); } @@ -2191,7 +2200,7 @@ public: class NestedBelowNull { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(fromjson("{a:null}"))))); } @@ -2201,7 +2210,7 @@ public: class NestedBelowUndefined { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(fromjson("{a:undefined}"))))); } @@ -2211,7 +2220,7 @@ public: class NestedBelowMissing { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(fromjson("{z:1}"))))); } @@ -2221,7 +2230,7 @@ public: class NestedBelowInt { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(BSON("a" << 2))))); } }; @@ -2230,7 +2239,7 @@ public: class NestedValue { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(BSON("" << 55), toBson(expression->evaluate(fromBson(BSON("a" << BSON("b" << 55)))))); } @@ -2240,7 +2249,7 @@ public: class NestedBelowEmptyObject { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(BSON("a" << BSONObj()))))); } @@ -2250,7 +2259,7 @@ public: class NestedBelowEmptyArray { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(BSON("" << BSONArray()), toBson(expression->evaluate(fromBson(BSON("a" << BSONArray()))))); } @@ -2260,7 +2269,7 @@ public: class NestedBelowArrayWithNull { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[null]}"))))); } @@ -2270,7 +2279,7 @@ public: class NestedBelowArrayWithUndefined { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[undefined]}"))))); } @@ -2280,7 +2289,7 @@ public: class NestedBelowArrayWithInt { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[1]}"))))); } @@ -2290,7 +2299,7 @@ public: class NestedWithinArray { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[9]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[{b:9}]}"))))); } @@ -2300,7 +2309,7 @@ public: class MultipleArrayValues { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[9,20]}"), toBson(expression->evaluate( fromBson(fromjson("{a:[{b:9},null,undefined,{g:4},{b:20},{}]}"))))); @@ -2311,7 +2320,7 @@ public: class ExpandNestedArrays { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b.c"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b.c"); assertBinaryEqual(fromjson("{'':[[1,2],3,[4],[[5]],[6,7]]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[{b:[{c:1},{c:2}]}," "{b:{c:3}}," @@ -2325,7 +2334,7 @@ public: class AddToBsonObj { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b.c"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b.c"); assertBinaryEqual(BSON("foo" << "$a.b.c"), BSON("foo" << expression->serialize(false))); @@ -2336,7 +2345,7 @@ public: class AddToBsonArray { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b.c"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b.c"); BSONArrayBuilder bab; bab << expression->serialize(false); assertBinaryEqual(BSON_ARRAY("$a.b.c"), bab.arr()); @@ -2358,16 +2367,19 @@ Document literal(T&& value) { // TEST(ExpressionObjectParse, ShouldAcceptEmptyObject) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSONObj(), vps); + auto object = ExpressionObject::parse(expCtx, BSONObj(), vps); ASSERT_VALUE_EQ(Value(Document{}), object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptLiteralsAsValues) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("a" << 5 << "b" + auto object = ExpressionObject::parse(expCtx, + BSON("a" << 5 << "b" << "string" << "c" << BSONNULL), @@ -2378,25 +2390,29 @@ TEST(ExpressionObjectParse, ShouldAcceptLiteralsAsValues) { } TEST(ExpressionObjectParse, ShouldAccept_idAsFieldName) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("_id" << 5), vps); + auto object = ExpressionObject::parse(expCtx, BSON("_id" << 5), vps); auto expectedResult = Value(Document{{"_id", literal(5)}}); ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptFieldNameContainingDollar) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("a$b" << 5), vps); + auto object = ExpressionObject::parse(expCtx, BSON("a$b" << 5), vps); auto expectedResult = Value(Document{{"a$b", literal(5)}}); ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptNestedObjects) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(fromjson("{a: {b: 1}, c: {d: {e: 1, f: 1}}}"), vps); + auto object = + ExpressionObject::parse(expCtx, fromjson("{a: {b: 1}, c: {d: {e: 1, f: 1}}}"), vps); auto expectedResult = Value(Document{{"a", Document{{"b", literal(1)}}}, {"c", Document{{"d", Document{{"e", literal(1)}, {"f", literal(1)}}}}}}); @@ -2404,18 +2420,20 @@ TEST(ExpressionObjectParse, ShouldAcceptNestedObjects) { } TEST(ExpressionObjectParse, ShouldAcceptArrays) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(fromjson("{a: [1, 2]}"), vps); + auto object = ExpressionObject::parse(expCtx, fromjson("{a: [1, 2]}"), vps); auto expectedResult = Value(Document{{"a", vector<Value>{Value(literal(1)), Value(literal(2))}}}); ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ObjectParsing, ShouldAcceptExpressionAsValue) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("a" << BSON("$and" << BSONArray())), vps); + auto object = ExpressionObject::parse(expCtx, BSON("a" << BSON("$and" << BSONArray())), vps); ASSERT_VALUE_EQ(object->serialize(false), Value(Document{{"a", Document{{"$and", BSONArray()}}}})); } @@ -2425,48 +2443,60 @@ TEST(ObjectParsing, ShouldAcceptExpressionAsValue) { // TEST(ExpressionObjectParse, ShouldRejectDottedFieldNames) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("a.b" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("c" << 3 << "a.b" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a.b" << 1 << "c" << 3), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a.b" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("c" << 3 << "a.b" << 1), vps), + UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a.b" << 1 << "c" << 3), vps), + UserException); } TEST(ExpressionObjectParse, ShouldRejectDuplicateFieldNames) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << 1 << "a" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << 1 << "b" << 2 << "a" << 1), vps), - UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << BSON("c" << 1) << "b" << 2 << "a" << 1), vps), - UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << 1 << "b" << 2 << "a" << BSON("c" << 1)), vps), + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a" << 1 << "a" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a" << 1 << "b" << 2 << "a" << 1), vps), UserException); + ASSERT_THROWS( + ExpressionObject::parse(expCtx, BSON("a" << BSON("c" << 1) << "b" << 2 << "a" << 1), vps), + UserException); + ASSERT_THROWS( + ExpressionObject::parse(expCtx, BSON("a" << 1 << "b" << 2 << "a" << BSON("c" << 1)), vps), + UserException); } TEST(ExpressionObjectParse, ShouldRejectInvalidFieldName) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("$a" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON(std::string("a\0b", 3) << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("$a" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON(std::string("a\0b", 3) << 1), vps), + UserException); } TEST(ExpressionObjectParse, ShouldRejectInvalidFieldPathAsValue) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" + ASSERT_THROWS(ExpressionObject::parse(expCtx, + BSON("a" << "$field."), vps), UserException); } TEST(ParseObject, ShouldRejectExpressionAsTheSecondField) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse( - BSON("a" << BSON("$and" << BSONArray()) << "$or" << BSONArray()), vps), - UserException); + ASSERT_THROWS( + ExpressionObject::parse( + expCtx, BSON("a" << BSON("$and" << BSONArray()) << "$or" << BSONArray()), vps), + UserException); } // @@ -2474,14 +2504,17 @@ TEST(ParseObject, ShouldRejectExpressionAsTheSecondField) { // TEST(ExpressionObjectEvaluate, EmptyObjectShouldEvaluateToEmptyDocument) { - auto object = ExpressionObject::create({}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, {}); ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document{{"a", 1}})); ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document{{"_id", "ID"_sd}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateEachField) { - auto object = ExpressionObject::create({{"a", makeConstant(1)}, {"b", makeConstant(5)}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = + ExpressionObject::create(expCtx, {{"a", makeConstant(1)}, {"b", makeConstant(5)}}); ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"a", 1}})); ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), @@ -2489,36 +2522,45 @@ TEST(ExpressionObjectEvaluate, ShouldEvaluateEachField) { } TEST(ExpressionObjectEvaluate, OrderOfFieldsInOutputShouldMatchOrderInSpecification) { - auto object = ExpressionObject::create({{"a", ExpressionFieldPath::create("a")}, - {"b", ExpressionFieldPath::create("b")}, - {"c", ExpressionFieldPath::create("c")}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, + {{"a", ExpressionFieldPath::create(expCtx, "a")}, + {"b", ExpressionFieldPath::create(expCtx, "b")}, + {"c", ExpressionFieldPath::create(expCtx, "c")}}); ASSERT_VALUE_EQ( Value(Document{{"a", "A"_sd}, {"b", "B"_sd}, {"c", "C"_sd}}), object->evaluate(Document{{"c", "C"_sd}, {"a", "A"_sd}, {"b", "B"_sd}, {"_id", "ID"_sd}})); } TEST(ExpressionObjectEvaluate, ShouldRemoveFieldsThatHaveMissingValues) { - auto object = ExpressionObject::create( - {{"a", ExpressionFieldPath::create("a.b")}, {"b", ExpressionFieldPath::create("missing")}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, + {{"a", ExpressionFieldPath::create(expCtx, "a.b")}, + {"b", ExpressionFieldPath::create(expCtx, "missing")}}); ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document{{"a", 1}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateFieldsWithinNestedObject) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); auto object = ExpressionObject::create( + expCtx, {{"a", ExpressionObject::create( - {{"b", makeConstant(1)}, {"c", ExpressionFieldPath::create("_id")}})}}); + expCtx, + {{"b", makeConstant(1)}, {"c", ExpressionFieldPath::create(expCtx, "_id")}})}}); ASSERT_VALUE_EQ(Value(Document{{"a", Document{{"b", 1}}}}), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document{{"a", Document{{"b", 1}, {"c", "ID"_sd}}}}), object->evaluate(Document{{"_id", "ID"_sd}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateToEmptyDocumentIfAllFieldsAreMissing) { - auto object = ExpressionObject::create({{"a", ExpressionFieldPath::create("missing")}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = + ExpressionObject::create(expCtx, {{"a", ExpressionFieldPath::create(expCtx, "missing")}}); ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document())); - auto objectWithNestedObject = ExpressionObject::create({{"nested", object}}); + auto objectWithNestedObject = ExpressionObject::create(expCtx, {{"nested", object}}); ASSERT_VALUE_EQ(Value(Document{{"nested", Document{}}}), objectWithNestedObject->evaluate(Document())); } @@ -2528,14 +2570,17 @@ TEST(ExpressionObjectEvaluate, ShouldEvaluateToEmptyDocumentIfAllFieldsAreMissin // TEST(ExpressionObjectDependencies, ConstantValuesShouldNotBeAddedToDependencies) { - auto object = ExpressionObject::create({{"a", makeConstant(5)}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, {{"a", makeConstant(5)}}); DepsTracker deps; object->addDependencies(&deps); ASSERT_EQ(deps.fields.size(), 0UL); } TEST(ExpressionObjectDependencies, FieldPathsShouldBeAddedToDependencies) { - auto object = ExpressionObject::create({{"x", ExpressionFieldPath::create("c.d")}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = + ExpressionObject::create(expCtx, {{"x", ExpressionFieldPath::create(expCtx, "c.d")}}); DepsTracker deps; object->addDependencies(&deps); ASSERT_EQ(deps.fields.size(), 1UL); @@ -2548,11 +2593,12 @@ TEST(ExpressionObjectDependencies, FieldPathsShouldBeAddedToDependencies) { TEST(ExpressionObjectOptimizations, OptimizingAnObjectShouldOptimizeSubExpressions) { // Build up the object {a: {$add: [1, 2]}}. + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); auto addExpression = - ExpressionAdd::parse(BSON("$add" << BSON_ARRAY(1 << 2)).firstElement(), vps); - auto object = ExpressionObject::create({{"a", addExpression}}); + ExpressionAdd::parse(expCtx, BSON("$add" << BSON_ARRAY(1 << 2)).firstElement(), vps); + auto object = ExpressionObject::create(expCtx, {{"a", addExpression}}); ASSERT_EQ(object->getChildExpressions().size(), 1UL); auto optimized = object->optimize(); @@ -2575,11 +2621,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(fromBson(BSON("a" << 1))))); @@ -2597,11 +2644,12 @@ class OptimizeBase { public: virtual ~OptimizeBase() {} void run() { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); intrusive_ptr<Expression> optimized = expression->optimize(); ASSERT_BSONOBJ_EQ(expectedOptimized(), expressionToBson(optimized)); @@ -2876,10 +2924,11 @@ namespace Object { * Parses the object given by 'specification', with the options given by 'parseContextOptions'. */ boost::intrusive_ptr<Expression> parseObject(BSONObj specification) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - return Expression::parseObject(specification, vps); + return Expression::parseObject(expCtx, specification, vps); }; TEST(ParseObject, ShouldAcceptEmptyObject) { @@ -2910,9 +2959,10 @@ using mongo::Expression; * Parses an expression from the given BSON specification. */ boost::intrusive_ptr<Expression> parseExpression(BSONObj specification) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - return Expression::parseExpression(specification, vps); + return Expression::parseExpression(expCtx, specification, vps); } TEST(ParseExpression, ShouldRecognizeConstExpression) { @@ -3016,10 +3066,11 @@ using mongo::Expression; * case the field name would be the name of the expression. */ intrusive_ptr<Expression> parseOperand(BSONObj specification) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONElement specElement = specification.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - return Expression::parseOperand(specElement, vps); + return Expression::parseOperand(expCtx, specElement, vps); } TEST(ParseOperand, ShouldRecognizeFieldPath) { @@ -3081,7 +3132,7 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); const Document spec = getSpec(); const Value args = spec["input"]; if (!spec["expected"].missing()) { @@ -3092,8 +3143,8 @@ public: const BSONObj obj = BSON(field.first << args); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - const intrusive_ptr<Expression> expr = Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + const intrusive_ptr<Expression> expr = + Expression::parseExpression(expCtx, obj, vps); Value result = expr->evaluate(Document()); if (result.getType() == Array) { result = sortSet(result); @@ -3121,8 +3172,7 @@ public: // NOTE: parse and evaluatation failures are treated the // same const intrusive_ptr<Expression> expr = - Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + Expression::parseExpression(expCtx, obj, vps); expr->evaluate(Document()); }, UserException); @@ -3399,13 +3449,12 @@ private: return BSON("$strcasecmp" << BSON_ARRAY(b() << a())); } void assertResult(int expectedResult, const BSONObj& spec) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult), toBson(expression->evaluate(Document()))); } @@ -3527,13 +3576,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3635,22 +3683,24 @@ class DropEndingNull : public ExpectedResultBase { namespace SubstrCP { TEST(ExpressionSubstrCPTest, DoesThrowWithBadContinuationByte) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const auto continuationByte = "\x80\x00"_sd; const auto expr = Expression::parseExpression( - BSON("$substrCP" << BSON_ARRAY(continuationByte << 0 << 1)), vps); + expCtx, BSON("$substrCP" << BSON_ARRAY(continuationByte << 0 << 1)), vps); ASSERT_THROWS({ expr->evaluate(Document()); }, UserException); } TEST(ExpressionSubstrCPTest, DoesThrowWithInvalidLeadingByte) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const auto leadingByte = "\xFF\x00"_sd; - const auto expr = - Expression::parseExpression(BSON("$substrCP" << BSON_ARRAY(leadingByte << 0 << 1)), vps); + const auto expr = Expression::parseExpression( + expCtx, BSON("$substrCP" << BSON_ARRAY(leadingByte << 0 << 1)), vps); ASSERT_THROWS({ expr->evaluate(Document()); }, UserException); } @@ -3792,13 +3842,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3851,13 +3900,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3909,7 +3957,7 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); const Document spec = getSpec(); const Value args = spec["input"]; if (!spec["expected"].missing()) { @@ -3920,8 +3968,8 @@ public: const BSONObj obj = BSON(field.first << args); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - const intrusive_ptr<Expression> expr = Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + const intrusive_ptr<Expression> expr = + Expression::parseExpression(expCtx, obj, vps); const Value result = expr->evaluate(Document()); if (ValueComparator().evaluate(result != expected)) { string errMsg = str::stream() @@ -3946,8 +3994,7 @@ public: // NOTE: parse and evaluatation failures are treated the // same const intrusive_ptr<Expression> expr = - Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + Expression::parseExpression(expCtx, obj, vps); expr->evaluate(Document()); }, UserException); diff --git a/src/mongo/db/pipeline/granularity_rounder.cpp b/src/mongo/db/pipeline/granularity_rounder.cpp index 0941c8b36fa..6071c6644db 100644 --- a/src/mongo/db/pipeline/granularity_rounder.cpp +++ b/src/mongo/db/pipeline/granularity_rounder.cpp @@ -50,11 +50,11 @@ void GranularityRounder::registerGranularityRounder(StringData name, Rounder rou } boost::intrusive_ptr<GranularityRounder> GranularityRounder::getGranularityRounder( - StringData granularity) { + const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData granularity) { auto it = rounderMap.find(granularity); uassert(40257, str::stream() << "Unknown rounding granularity '" << granularity << "'", it != rounderMap.end()); - return it->second(); + return it->second(expCtx); } } // namespace mongo diff --git a/src/mongo/db/pipeline/granularity_rounder.h b/src/mongo/db/pipeline/granularity_rounder.h index 862c604dcc0..24a8f53de97 100644 --- a/src/mongo/db/pipeline/granularity_rounder.h +++ b/src/mongo/db/pipeline/granularity_rounder.h @@ -32,6 +32,7 @@ #include "mongo/base/init.h" #include "mongo/db/jsobj.h" +#include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" #include "mongo/stdx/functional.h" #include "mongo/util/intrusive_counter.h" @@ -73,7 +74,8 @@ namespace mongo { */ class GranularityRounder : public RefCountable { public: - using Rounder = stdx::function<boost::intrusive_ptr<GranularityRounder>()>; + using Rounder = stdx::function<boost::intrusive_ptr<GranularityRounder>( + const boost::intrusive_ptr<ExpressionContext>&)>; /** * Registers a GranularityRounder with a parsing function so that when a granularity @@ -89,7 +91,8 @@ public: * Retrieves the GranularityRounder for the granularity given by 'granularity', and raises an * error if there is no such granularity registered. */ - static boost::intrusive_ptr<GranularityRounder> getGranularityRounder(StringData granularity); + static boost::intrusive_ptr<GranularityRounder> getGranularityRounder( + const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData granularity); /** * Rounds up 'value' to the first value greater than 'value' in the granularity series. If @@ -109,6 +112,16 @@ public: * Returns the name of the granularity series that the GranularityRounder is using for rounding. */ virtual std::string getName() = 0; + +protected: + GranularityRounder(const boost::intrusive_ptr<ExpressionContext>& expCtx) : _expCtx(expCtx) {} + + ExpressionContext* getExpCtx() { + return _expCtx.get(); + } + +private: + boost::intrusive_ptr<ExpressionContext> _expCtx; }; /** @@ -124,8 +137,10 @@ public: * 'baseSeries'. This method requires that baseSeries has at least 2 numbers and is in sorted * order. */ - static boost::intrusive_ptr<GranularityRounder> create(const std::vector<double> baseSeries, - std::string name); + static boost::intrusive_ptr<GranularityRounder> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const std::vector<double> baseSeries, + std::string name); Value roundUp(Value value); Value roundDown(Value value); @@ -138,7 +153,9 @@ public: const std::vector<double> getSeries() const; private: - GranularityRounderPreferredNumbers(std::vector<double> baseSeries, std::string name); + GranularityRounderPreferredNumbers(const boost::intrusive_ptr<ExpressionContext>& expCtx, + std::vector<double> baseSeries, + std::string name); // '_baseSeries' is the preferred number series that is used for rounding. A preferred numbers // series is infinite, but we represent it with a finite vector of numbers. When rounding, we @@ -153,13 +170,16 @@ private: */ class GranularityRounderPowersOfTwo final : public GranularityRounder { public: - static boost::intrusive_ptr<GranularityRounder> create(); + static boost::intrusive_ptr<GranularityRounder> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); Value roundUp(Value value); Value roundDown(Value value); std::string getName(); private: - GranularityRounderPowersOfTwo() = default; + GranularityRounderPowersOfTwo(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : GranularityRounder(expCtx) {} + std::string _name = "POWERSOF2"; }; } // namespace mongo diff --git a/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp b/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp index 56d3ae02267..078936e8145 100644 --- a/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp @@ -40,8 +40,9 @@ using std::string; REGISTER_GRANULARITY_ROUNDER(POWERSOF2, GranularityRounderPowersOfTwo::create); -intrusive_ptr<GranularityRounder> GranularityRounderPowersOfTwo::create() { - return new GranularityRounderPowersOfTwo(); +intrusive_ptr<GranularityRounder> GranularityRounderPowersOfTwo::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new GranularityRounderPowersOfTwo(expCtx); } namespace { @@ -80,7 +81,7 @@ Value GranularityRounderPowersOfTwo::roundUp(Value value) { } Variables vars; - return ExpressionPow::create(Value(2), exp)->evaluate(&vars); + return ExpressionPow::create(getExpCtx(), Value(2), exp)->evaluate(&vars); } Value GranularityRounderPowersOfTwo::roundDown(Value value) { @@ -113,7 +114,7 @@ Value GranularityRounderPowersOfTwo::roundDown(Value value) { } Variables vars; - return ExpressionPow::create(Value(2), exp)->evaluate(&vars); + return ExpressionPow::create(getExpCtx(), Value(2), exp)->evaluate(&vars); } string GranularityRounderPowersOfTwo::getName() { diff --git a/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp b/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp index adb7046c0c1..eefcdb749e5 100644 --- a/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp @@ -31,6 +31,7 @@ #include "mongo/db/pipeline/granularity_rounder.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" @@ -50,7 +51,8 @@ void testEquals(Value actual, Value expected, double delta = DELTA) { } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundUp(Value(0.5)), Value(1)); testEquals(rounder->roundUp(Value(1)), Value(2)); @@ -64,7 +66,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpPowersOfTwoToNextPowerOfTwo } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnDoubleIfExceedsNumberLong) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); long long input = 4611686018427387905; // 2^62 + 1 double output = 9223372036854775808.0; // 2^63 @@ -78,7 +81,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnDoubleIfExceedsNumberLong) { } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfExceedsNumberInt) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); int input = 1073741824; // 2^30 long long output = 2147483648; // 2^31 @@ -92,7 +96,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfExceedsNumberInt } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfRoundedDownDoubleIsSmallEnough) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); double input = 9223372036854775808.0; // 2^63 long long output = 4611686018427387904; // 2^62 @@ -106,7 +111,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfRoundedDownDoubl } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberIntIfRoundedDownNumberLongIsSmallEnough) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); long long input = 2147483648; // 2^31 int output = 1073741824; // 2^30 @@ -120,7 +126,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberIntIfRoundedDownNumber } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingUpNumberDecimal) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Decimal128 input = Decimal128(0.12); Decimal128 output = Decimal128(0.125); @@ -134,7 +141,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingUpN } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingDownNumberDecimal) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Decimal128 input = Decimal128(0.13); Decimal128 output = Decimal128(0.125); @@ -148,7 +156,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingDow } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpNonPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundUp(Value(3)), Value(4)); testEquals(rounder->roundUp(Value(5)), Value(8)); @@ -168,7 +177,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpNonPowersOfTwoToNextPowerOf } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundDown(Value(16)), Value(8)); testEquals(rounder->roundDown(Value(8)), Value(4)); @@ -183,7 +193,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownPowersOfTwoToNextPowerOfT } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownNonPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundDown(Value(10)), Value(8)); testEquals(rounder->roundDown(Value(9)), Value(8)); @@ -201,14 +212,16 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownNonPowersOfTwoToNextPower } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundZeroToZero) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundUp(Value(0)), Value(0)); testEquals(rounder->roundDown(Value(0)), Value(0)); } TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNonNumericValues) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); // Make sure that each GranularityRounder fails when rounding a non-numeric value. Value stringValue = Value("test"_sd); @@ -217,7 +230,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNonNumericValues) { } TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNaN) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Value nan = Value(std::nan("NaN")); ASSERT_THROWS_CODE(rounder->roundUp(nan), UserException, 40266); @@ -232,7 +246,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNaN) { } TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNegativeNumber) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Value negativeNumber = Value(-1); ASSERT_THROWS_CODE(rounder->roundUp(negativeNumber), UserException, 40267); diff --git a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp index 81a845bed6e..5b62d435a14 100644 --- a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp @@ -95,56 +95,57 @@ const vector<double> e192Series{ } // namespace // Register the GranularityRounders for the Renard number series. -REGISTER_GRANULARITY_ROUNDER(R5, []() { - return GranularityRounderPreferredNumbers::create(r5Series, "R5"); +REGISTER_GRANULARITY_ROUNDER(R5, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r5Series, "R5"); }); -REGISTER_GRANULARITY_ROUNDER(R10, []() { - return GranularityRounderPreferredNumbers::create(r10Series, "R10"); +REGISTER_GRANULARITY_ROUNDER(R10, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r10Series, "R10"); }); -REGISTER_GRANULARITY_ROUNDER(R20, []() { - return GranularityRounderPreferredNumbers::create(r20Series, "R20"); +REGISTER_GRANULARITY_ROUNDER(R20, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r20Series, "R20"); }); -REGISTER_GRANULARITY_ROUNDER(R40, []() { - return GranularityRounderPreferredNumbers::create(r40Series, "R40"); +REGISTER_GRANULARITY_ROUNDER(R40, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r40Series, "R40"); }); -REGISTER_GRANULARITY_ROUNDER(R80, []() { - return GranularityRounderPreferredNumbers::create(r80Series, "R80"); +REGISTER_GRANULARITY_ROUNDER(R80, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r80Series, "R80"); }); -REGISTER_GRANULARITY_ROUNDER_GENERAL("1-2-5", 1_2_5, []() { - return GranularityRounderPreferredNumbers::create(series125, "1-2-5"); -}); +REGISTER_GRANULARITY_ROUNDER_GENERAL( + "1-2-5", 1_2_5, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, series125, "1-2-5"); + }); // Register the GranularityRounders for the E series. -REGISTER_GRANULARITY_ROUNDER(E6, []() { - return GranularityRounderPreferredNumbers::create(e6Series, "E6"); +REGISTER_GRANULARITY_ROUNDER(E6, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e6Series, "E6"); }); -REGISTER_GRANULARITY_ROUNDER(E12, []() { - return GranularityRounderPreferredNumbers::create(e12Series, "E12"); +REGISTER_GRANULARITY_ROUNDER(E12, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e12Series, "E12"); }); -REGISTER_GRANULARITY_ROUNDER(E24, []() { - return GranularityRounderPreferredNumbers::create(e24Series, "E24"); +REGISTER_GRANULARITY_ROUNDER(E24, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e24Series, "E24"); }); -REGISTER_GRANULARITY_ROUNDER(E48, []() { - return GranularityRounderPreferredNumbers::create(e48Series, "E48"); +REGISTER_GRANULARITY_ROUNDER(E48, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e48Series, "E48"); }); -REGISTER_GRANULARITY_ROUNDER(E96, []() { - return GranularityRounderPreferredNumbers::create(e96Series, "E96"); +REGISTER_GRANULARITY_ROUNDER(E96, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e96Series, "E96"); }); -REGISTER_GRANULARITY_ROUNDER(E192, []() { - return GranularityRounderPreferredNumbers::create(e192Series, "E192"); +REGISTER_GRANULARITY_ROUNDER(E192, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e192Series, "E192"); }); -GranularityRounderPreferredNumbers::GranularityRounderPreferredNumbers(vector<double> baseSeries, - string name) - : _baseSeries(baseSeries), _name(name) { +GranularityRounderPreferredNumbers::GranularityRounderPreferredNumbers( + const boost::intrusive_ptr<ExpressionContext>& expCtx, vector<double> baseSeries, string name) + : GranularityRounder(expCtx), _baseSeries(baseSeries), _name(name) { invariant(_baseSeries.size() > 1); invariant(std::is_sorted(_baseSeries.begin(), _baseSeries.end())); } intrusive_ptr<GranularityRounder> GranularityRounderPreferredNumbers::create( - vector<double> baseSeries, string name) { - return new GranularityRounderPreferredNumbers(baseSeries, name); + const boost::intrusive_ptr<ExpressionContext>& expCtx, vector<double> baseSeries, string name) { + return new GranularityRounderPreferredNumbers(expCtx, baseSeries, name); } namespace { diff --git a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp index 17d10d07977..89df39621ad 100644 --- a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp @@ -32,6 +32,7 @@ #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/util/assert_util.h" namespace mongo { @@ -461,7 +462,8 @@ void testSeriesWrappingAroundDecimal(intrusive_ptr<GranularityRounder> rounder) TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpNumberInSeriesToNextNumberInSeries) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingUpInSeries(rounder); testRoundingUpInSeriesDecimal(rounder); @@ -471,7 +473,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpNumberInSeriesToNextNu TEST(GranularityRounderPreferredNumbersTest, ShouldRoundDownNumberInSeriesToPreviousNumberInSeries) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingDownInSeries(rounder); testRoundingDownInSeriesDecimal(rounder); @@ -480,7 +483,8 @@ TEST(GranularityRounderPreferredNumbersTest, TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpValueInBetweenSeriesNumbers) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingUpBetweenSeries(rounder); testRoundingUpBetweenSeriesDecimal(rounder); @@ -489,7 +493,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpValueInBetweenSeriesNu TEST(GranularityRounderPreferredNumbersTest, ShouldRoundDownValueInBetweenSeriesNumbers) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingDownBetweenSeries(rounder); testRoundingDownBetweenSeriesDecimal(rounder); @@ -498,7 +503,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundDownValueInBetweenSeries TEST(GranularityRounderPreferredNumbersTest, SeriesShouldWrapAroundWhenRounding) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testSeriesWrappingAround(rounder); testSeriesWrappingAroundDecimal(rounder); @@ -507,7 +513,8 @@ TEST(GranularityRounderPreferredNumbersTest, SeriesShouldWrapAroundWhenRounding) TEST(GranularityRounderPreferredNumbersTest, ShouldRoundZeroToZero) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder rounds zero to zero. testEquals(rounder->roundUp(Value(0)), Value(0)); @@ -520,7 +527,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundZeroToZero) { TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNonNumericValues) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder fails when rounding a non-numeric value. Value stringValue = Value("test"_sd); @@ -531,7 +539,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNonNumericValue TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNaN) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder fails when rounding NaN. Value nan = Value(std::nan("NaN")); @@ -549,7 +558,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNaN) { TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNegativeNumber) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder fails when rounding a negative number. Value negativeNumber = Value(-1); diff --git a/src/mongo/db/pipeline/lookup_set_cache.h b/src/mongo/db/pipeline/lookup_set_cache.h index d482b267ee8..1b7adbd00a6 100644 --- a/src/mongo/db/pipeline/lookup_set_cache.h +++ b/src/mongo/db/pipeline/lookup_set_cache.h @@ -81,9 +81,12 @@ public: * ValueComparator. This requires instantiating the multi_index_container with comparison and * hasher functions obtained from the comparator. */ - LookupSetCache(ValueComparator valueComparator) - : _valueComparator(std::move(valueComparator)), - _container(makeIndexedContainer(_valueComparator)) {} + explicit LookupSetCache(const ValueComparator& comparator) + : _container(boost::make_tuple(IndexedContainer::nth_index<0>::type::ctor_args(), + boost::make_tuple(0, + member<Cached, Value, &Cached::first>(), + comparator.getHasher(), + comparator.getEqualTo()))) {} /** * Insert "value" into the set with key "key". If "key" is already present in the cache, move it @@ -186,29 +189,7 @@ public: return boost::none; } - /** - * Binds the cache to a new comparator that should be used to make all subsequent Value - * comparisons. - * - * TODO SERVER-25535: Remove this method. - */ - void setValueComparator(ValueComparator valueComparator) { - _valueComparator = std::move(valueComparator); - _container = makeIndexedContainer(_valueComparator); - } - private: - IndexedContainer makeIndexedContainer(const ValueComparator& valueComparator) const { - return IndexedContainer( - boost::make_tuple(IndexedContainer::nth_index<0>::type::ctor_args(), - boost::make_tuple(0, - member<Cached, Value, &Cached::first>(), - valueComparator.getHasher(), - valueComparator.getEqualTo()))); - } - - ValueComparator _valueComparator; - IndexedContainer _container; size_t _memoryUsage = 0; diff --git a/src/mongo/db/pipeline/lookup_set_cache_test.cpp b/src/mongo/db/pipeline/lookup_set_cache_test.cpp index ce7503d4c6e..54a8b990ae7 100644 --- a/src/mongo/db/pipeline/lookup_set_cache_test.cpp +++ b/src/mongo/db/pipeline/lookup_set_cache_test.cpp @@ -52,9 +52,10 @@ BSONObj intToObj(int value) { return BSON("n" << value); } +const ValueComparator defaultComparator{nullptr}; + TEST(LookupSetCacheTest, InsertAndRetrieveWorksCorrectly) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(1)); cache.insert(Value(0), intToObj(2)); cache.insert(Value(0), intToObj(3)); @@ -70,8 +71,7 @@ TEST(LookupSetCacheTest, InsertAndRetrieveWorksCorrectly) { } TEST(LookupSetCacheTest, CacheDoesEvictInExpectedOrder) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -90,8 +90,7 @@ TEST(LookupSetCacheTest, CacheDoesEvictInExpectedOrder) { } TEST(LookupSetCacheTest, ReadDoesMoveKeyToFrontOfCache) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -106,8 +105,7 @@ TEST(LookupSetCacheTest, ReadDoesMoveKeyToFrontOfCache) { } TEST(LookupSetCacheTest, InsertDoesPutKeyInMiddle) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -120,8 +118,7 @@ TEST(LookupSetCacheTest, InsertDoesPutKeyInMiddle) { } TEST(LookupSetCacheTest, EvictDoesRespectMemoryUsage) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -134,8 +131,7 @@ TEST(LookupSetCacheTest, EvictDoesRespectMemoryUsage) { } TEST(LookupSetCacheTest, ComplexAccessPatternDoesBehaveCorrectly) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); for (int i = 0; i < 5; i++) { for (int j = 0; j < 5; j++) { @@ -173,7 +169,8 @@ TEST(LookupSetCacheTest, ComplexAccessPatternDoesBehaveCorrectly) { TEST(LookupSetCacheTest, CacheKeysRespectCollation) { CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kToLowerString); - LookupSetCache cache(&collator); + ValueComparator comparator{&collator}; + LookupSetCache cache(comparator); cache.insert(Value("foo"_sd), intToObj(1)); cache.insert(Value("FOO"_sd), intToObj(2)); @@ -196,7 +193,8 @@ TEST(LookupSetCacheTest, CacheKeysRespectCollation) { // foreign collection. TEST(LookupSetCacheTest, CachedValuesDontRespectCollation) { CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kToLowerString); - LookupSetCache cache(&collator); + ValueComparator comparator{&collator}; + LookupSetCache cache(comparator); cache.insert(Value("foo"_sd), BSON("foo" diff --git a/src/mongo/db/pipeline/parsed_add_fields.cpp b/src/mongo/db/pipeline/parsed_add_fields.cpp index 2bca98f21bd..ae78664a8a5 100644 --- a/src/mongo/db/pipeline/parsed_add_fields.cpp +++ b/src/mongo/db/pipeline/parsed_add_fields.cpp @@ -38,7 +38,8 @@ namespace mongo { namespace parsed_aggregation_projection { -std::unique_ptr<ParsedAddFields> ParsedAddFields::create(const BSONObj& spec) { +std::unique_ptr<ParsedAddFields> ParsedAddFields::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec) { // Verify that we don't have conflicting field paths, etc. Status status = ProjectionSpecValidator::validate(spec); if (!status.isOK()) { @@ -48,17 +49,19 @@ std::unique_ptr<ParsedAddFields> ParsedAddFields::create(const BSONObj& spec) { std::unique_ptr<ParsedAddFields> parsedAddFields = stdx::make_unique<ParsedAddFields>(); // Actually parse the specification. - parsedAddFields->parse(spec); + parsedAddFields->parse(expCtx, spec); return parsedAddFields; } -void ParsedAddFields::parse(const BSONObj& spec, const VariablesParseState& variablesParseState) { +void ParsedAddFields::parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, + const VariablesParseState& variablesParseState) { for (auto elem : spec) { auto fieldName = elem.fieldNameStringData(); if (elem.type() == BSONType::Object) { // This is either an expression, or a nested specification. - if (parseObjectAsExpression(fieldName, elem.Obj(), variablesParseState)) { + if (parseObjectAsExpression(expCtx, fieldName, elem.Obj(), variablesParseState)) { // It was an expression. } else { // The field name might be a dotted path. If so, we need to keep adding children @@ -72,12 +75,12 @@ void ParsedAddFields::parse(const BSONObj& spec, const VariablesParseState& vari // It is illegal to construct an empty FieldPath, so the above loop ends one // iteration too soon. Add the last path here. child = child->addOrGetChild(remainingPath.fullPath()); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); } } else { // This is a literal or regular value. _root->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } @@ -96,7 +99,8 @@ Document ParsedAddFields::applyProjection(Document inputDoc, Variables* vars) co return output.freeze(); } -bool ParsedAddFields::parseObjectAsExpression(StringData pathToObject, +bool ParsedAddFields::parseObjectAsExpression(const boost::intrusive_ptr<ExpressionContext>& expCtx, + StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState) { if (objSpec.firstElementFieldName()[0] == '$') { @@ -104,13 +108,14 @@ bool ParsedAddFields::parseObjectAsExpression(StringData pathToObject, // field. invariant(objSpec.nFields() == 1); _root->addComputedField(pathToObject, - Expression::parseExpression(objSpec, variablesParseState)); + Expression::parseExpression(expCtx, objSpec, variablesParseState)); return true; } return false; } -void ParsedAddFields::parseSubObject(const BSONObj& subObj, +void ParsedAddFields::parseSubObject(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& subObj, const VariablesParseState& variablesParseState, InclusionNode* node) { for (auto&& elem : subObj) { @@ -123,17 +128,18 @@ void ParsedAddFields::parseSubObject(const BSONObj& subObj, // This is either an expression, or a nested specification. auto fieldName = elem.fieldNameStringData().toString(); if (!parseObjectAsExpression( + expCtx, FieldPath::getFullyQualifiedPath(node->getPath(), fieldName), elem.Obj(), variablesParseState)) { // It was a nested subobject auto child = node->addOrGetChild(fieldName); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); } } else { // This is a literal or regular value. node->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } diff --git a/src/mongo/db/pipeline/parsed_add_fields.h b/src/mongo/db/pipeline/parsed_add_fields.h index 58b8445e727..982fe6a3017 100644 --- a/src/mongo/db/pipeline/parsed_add_fields.h +++ b/src/mongo/db/pipeline/parsed_add_fields.h @@ -58,7 +58,8 @@ public: * Verifies that there are no conflicting paths in the specification. * Overrides the ParsedAggregationProjection's create method. */ - static std::unique_ptr<ParsedAddFields> create(const BSONObj& spec); + static std::unique_ptr<ParsedAddFields> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec); ProjectionType getType() const final { return ProjectionType::kComputed; @@ -67,10 +68,10 @@ public: /** * Parses the addFields specification given by 'spec', populating internal data structures. */ - void parse(const BSONObj& spec) final { + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec) final { VariablesIdGenerator idGenerator; VariablesParseState variablesParseState(&idGenerator); - parse(spec, variablesParseState); + parse(expCtx, spec, variablesParseState); _variables = stdx::make_unique<Variables>(idGenerator.getIdCount()); } @@ -87,10 +88,6 @@ public: _root->optimize(); } - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) final { - _root->injectExpressionContext(expCtx); - } - DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const final { _root->addDependencies(deps); return DocumentSource::SEE_NEXT; @@ -124,7 +121,9 @@ private: /** * Parses 'spec' to determine which fields to add. */ - void parse(const BSONObj& spec, const VariablesParseState& variablesParseState); + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, + const VariablesParseState& variablesParseState); /** * Attempts to parse 'objSpec' as an expression like {$add: [...]}. Adds a computed field to @@ -134,7 +133,8 @@ private: * Throws an error if it was determined to be an expression specification, but failed to parse * as a valid expression. */ - bool parseObjectAsExpression(StringData pathToObject, + bool parseObjectAsExpression(const boost::intrusive_ptr<ExpressionContext>& expCtx, + StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState); @@ -142,7 +142,8 @@ private: * Traverses 'subObj' and parses each field. Adds any computed fields at this level * to 'node'. */ - void parseSubObject(const BSONObj& subObj, + void parseSubObject(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& subObj, const VariablesParseState& variablesParseState, InclusionNode* node); diff --git a/src/mongo/db/pipeline/parsed_add_fields_test.cpp b/src/mongo/db/pipeline/parsed_add_fields_test.cpp index 557e2450dce..d7b68afca1d 100644 --- a/src/mongo/db/pipeline/parsed_add_fields_test.cpp +++ b/src/mongo/db/pipeline/parsed_add_fields_test.cpp @@ -38,6 +38,7 @@ #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" @@ -47,69 +48,78 @@ namespace { using std::vector; // These ParsedAddFields spec tests are a subset of the ParsedAggregationProjection creation tests. -// ParsedAddField should behave the same way, but does not use the same creation, so we include +// ParsedAddFields should behave the same way, but does not use the same creation, so we include // an abbreviation of the same tests here. // Verify that ParsedAddFields rejects specifications with conflicting field paths. TEST(ParsedAddFieldsSpec, ThrowsOnCreationWithConflictingFieldPaths) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // These specs contain the same exact path. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << 1 << "a" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("b" << 1 << "b" << 2))), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("_id" << 3 << "_id" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << 1 << "a" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSON("b" << 1 << "b" << 2))), + UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("_id" << 3 << "_id" << true)), + UserException); // These specs contain overlapping paths. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << 1 << "a.b" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a.b.c" << 1 << "a" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("_id" << true << "_id.x" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << 1 << "a.b" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a.b.c" << 1 << "a" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("_id" << true << "_id.x" << true)), + UserException); } // Verify that ParsedAddFields rejects specifications that contain invalid field paths. TEST(ParsedAddFieldsSpec, ThrowsOnCreationWithInvalidFieldPath) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Dotted subfields are not allowed. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("b.c" << true))), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSON("b.c" << true))), UserException); // The user cannot start a field with $. - ASSERT_THROWS(ParsedAddFields::create(BSON("$dollar" << 0)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("c.$d" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("$dollar" << 0)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("c.$d" << true)), UserException); // Empty field names should throw an error. - ASSERT_THROWS(ParsedAddFields::create(BSON("" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("" << true))), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("" << BSON("a" << true))), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a." << true)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON(".a" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSON("" << true))), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("" << BSON("a" << true))), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a." << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON(".a" << true)), UserException); } // Verify that ParsedAddFields rejects specifications that contain empty objects or invalid // expressions. TEST(ParsedAddFieldsSpec, ThrowsOnCreationWithInvalidObjectsOrExpressions) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Invalid expressions should be rejected. - ASSERT_THROWS( - ParsedAddFields::create(BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), - UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("$gt" << BSON("bad" + ASSERT_THROWS(ParsedAddFields::create( + expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), + UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, + BSON("a" << BSON("$gt" << BSON("bad" << "arguments")))), UserException); ASSERT_THROWS(ParsedAddFields::create( - BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), + expCtx, BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), UserException); // Empty specifications are not allowed. - ASSERT_THROWS(ParsedAddFields::create(BSONObj()), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSONObj()), UserException); // Empty nested objects are not allowed. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSONObj())), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSONObj())), UserException); } TEST(ParsedAddFields, DoesNotErrorOnTwoNestedFields) { - ParsedAddFields::create(BSON("a.b" << true << "a.c" << true)); - ParsedAddFields::create(BSON("a.b" << true << "a" << BSON("c" << true))); + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ParsedAddFields::create(expCtx, BSON("a.b" << true << "a.c" << true)); + ParsedAddFields::create(expCtx, BSON("a.b" << true << "a" << BSON("c" << true))); } // Verify that replaced fields are not included as dependencies. TEST(ParsedAddFieldsDeps, RemovesReplaceFieldsFromDependencies) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << true)); + addition.parse(expCtx, BSON("a" << true)); DepsTracker deps; addition.addDependencies(&deps); @@ -121,8 +131,9 @@ TEST(ParsedAddFieldsDeps, RemovesReplaceFieldsFromDependencies) { // Verify that adding nested fields keeps the top-level field as a dependency. TEST(ParsedAddFieldsDeps, IncludesTopLevelFieldInDependenciesWhenAddingNestedFields) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("x.y" << true)); + addition.parse(expCtx, BSON("x.y" << true)); DepsTracker deps; addition.addDependencies(&deps); @@ -135,8 +146,10 @@ TEST(ParsedAddFieldsDeps, IncludesTopLevelFieldInDependenciesWhenAddingNestedFie // Verify that fields that an expression depends on are added to the dependencies. TEST(ParsedAddFieldsDeps, AddsDependenciesForComputedFields) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("x.y" + addition.parse(expCtx, + BSON("x.y" << "$z" << "a" << "$b")); @@ -155,8 +168,9 @@ TEST(ParsedAddFieldsDeps, AddsDependenciesForComputedFields) { // Verify that the serialization produces the correct output: converting numbers and literals to // their corresponding $const form. TEST(ParsedAddFieldsSerialize, SerializesToCorrectForm) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); + addition.parse(expCtx, fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); auto expectedSerialization = Document( fromjson("{a: {$add: [\"$a\", {$const: 2}]}, b: {d: {$const: 3}}, x: {y: {$const: 4}}}")); @@ -168,8 +182,9 @@ TEST(ParsedAddFieldsSerialize, SerializesToCorrectForm) { // Verify that serialize treats the _id field as any other field: including when explicity included. TEST(ParsedAddFieldsSerialize, AddsIdToSerializeWhenExplicitlyIncluded) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("_id" << false)); + addition.parse(expCtx, BSON("_id" << false)); // Adds explicit "_id" setting field, serializes expressions. auto expectedSerialization = Document(fromjson("{_id: {$const: false}}")); @@ -184,8 +199,9 @@ TEST(ParsedAddFieldsSerialize, AddsIdToSerializeWhenExplicitlyIncluded) { // yet they derive from the same parent class. If the parent class were to change, this test would // fail. TEST(ParsedAddFieldsSerialize, OmitsIdFromSerializeWhenNotIncluded) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << true)); + addition.parse(expCtx, BSON("a" << true)); // Does not implicitly include "_id" field. auto expectedSerialization = Document(fromjson("{a: {$const: true}}")); @@ -197,8 +213,9 @@ TEST(ParsedAddFieldsSerialize, OmitsIdFromSerializeWhenNotIncluded) { // Verify that the $addFields stage optimizes expressions into simpler forms when possible. TEST(ParsedAddFieldsOptimize, OptimizesTopLevelExpressions) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); + addition.parse(expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); addition.optimize(); auto expectedSerialization = Document{{"a", Document{{"$const", 3}}}}; @@ -209,8 +226,9 @@ TEST(ParsedAddFieldsOptimize, OptimizesTopLevelExpressions) { // Verify that the $addFields stage optimizes expressions even when they are nested. TEST(ParsedAddFieldsOptimize, ShouldOptimizeNestedExpressions) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); + addition.parse(expCtx, BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); addition.optimize(); auto expectedSerialization = Document{{"a", Document{{"b", Document{{"$const", 3}}}}}}; @@ -225,8 +243,9 @@ TEST(ParsedAddFieldsOptimize, ShouldOptimizeNestedExpressions) { // Verify that a new field is added to the end of the document. TEST(ParsedAddFieldsExecutionTest, AddsNewFieldToEndOfDocument) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("c" << 3)); + addition.parse(expCtx, BSON("c" << 3)); // There are no fields in the document. auto result = addition.applyProjection(Document{}); @@ -241,8 +260,9 @@ TEST(ParsedAddFieldsExecutionTest, AddsNewFieldToEndOfDocument) { // Verify that an existing field is replaced and stays in the same order in the document. TEST(ParsedAddFieldsExecutionTest, ReplacesFieldThatAlreadyExistsInDocument) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("c" << 3)); + addition.parse(expCtx, BSON("c" << 3)); // Specified field is the only field in the document, and is replaced. auto result = addition.applyProjection(Document{{"c", 1}}); @@ -257,8 +277,10 @@ TEST(ParsedAddFieldsExecutionTest, ReplacesFieldThatAlreadyExistsInDocument) { // Verify that replacing multiple fields preserves the original field order in the document. TEST(ParsedAddFieldsExecutionTest, ReplacesMultipleFieldsWhilePreservingInputFieldOrder) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("second" + addition.parse(expCtx, + BSON("second" << "SECOND" << "first" << "FIRST")); @@ -269,8 +291,10 @@ TEST(ParsedAddFieldsExecutionTest, ReplacesMultipleFieldsWhilePreservingInputFie // Verify that adding multiple fields adds the fields in the order specified. TEST(ParsedAddFieldsExecutionTest, AddsNewFieldsAfterExistingFieldsInOrderSpecified) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("firstComputed" + addition.parse(expCtx, + BSON("firstComputed" << "FIRST" << "secondComputed" << "SECOND")); @@ -286,8 +310,10 @@ TEST(ParsedAddFieldsExecutionTest, AddsNewFieldsAfterExistingFieldsInOrderSpecif // Verify that both adding and replacing fields at the same time follows the same rules as doing // each independently. TEST(ParsedAddFieldsExecutionTest, ReplacesAndAddsNewFieldsWithSameOrderingRulesAsSeparately) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("firstComputed" + addition.parse(expCtx, + BSON("firstComputed" << "FIRST" << "second" << "SECOND")); @@ -300,8 +326,10 @@ TEST(ParsedAddFieldsExecutionTest, ReplacesAndAddsNewFieldsWithSameOrderingRules // Verify that _id is included just like a regular field, in whatever order it appears in the // input document, when adding new fields. TEST(ParsedAddFieldsExecutionTest, IdFieldIsKeptInOrderItAppearsInInputDocument) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("newField" + addition.parse(expCtx, + BSON("newField" << "computedVal")); auto result = addition.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"a", 1}, {"newField", "computedVal"_sd}}; @@ -314,8 +342,10 @@ TEST(ParsedAddFieldsExecutionTest, IdFieldIsKeptInOrderItAppearsInInputDocument) // Verify that replacing or adding _id works just like any other field. TEST(ParsedAddFieldsExecutionTest, ShouldReplaceIdWithComputedId) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("_id" + addition.parse(expCtx, + BSON("_id" << "newId")); auto result = addition.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}}); auto expectedResult = Document{{"_id", "newId"_sd}, {"a", 1}}; @@ -336,8 +366,9 @@ TEST(ParsedAddFieldsExecutionTest, ShouldReplaceIdWithComputedId) { // Verify that adding a dotted field keeps the other fields in the subdocument. TEST(ParsedAddFieldsExecutionTest, KeepsExistingSubFieldsWhenAddingSimpleDottedFieldToSubDoc) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b" << true)); + addition.parse(expCtx, BSON("a.b" << true)); // More than one field in sub document. auto result = addition.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); @@ -362,8 +393,9 @@ TEST(ParsedAddFieldsExecutionTest, KeepsExistingSubFieldsWhenAddingSimpleDottedF // Verify that creating a dotted field creates the subdocument structure necessary. TEST(ParsedAddFieldsExecutionTest, CreatesSubDocIfDottedAddedFieldDoesNotExist) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("sub.target" << true)); + addition.parse(expCtx, BSON("sub.target" << true)); // Should add the path if it doesn't exist. auto result = addition.applyProjection(Document{}); @@ -379,8 +411,9 @@ TEST(ParsedAddFieldsExecutionTest, CreatesSubDocIfDottedAddedFieldDoesNotExist) // Verify that adding a dotted value to an array field sets the field in every element of the array. // SERVER-25200: make this agree with $set. TEST(ParsedAddFieldsExecutionTest, AppliesDottedAdditionToEachElementInArray) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b" << true)); + addition.parse(expCtx, BSON("a.b" << true)); vector<Value> nestedValues = {Value(1), Value(Document{}), @@ -404,8 +437,10 @@ TEST(ParsedAddFieldsExecutionTest, AppliesDottedAdditionToEachElementInArray) { // Verify that creation of the subdocument structure works for many layers of nesting. TEST(ParsedAddFieldsExecutionTest, CreatesNestedSubDocumentsAllTheWayToAddedField) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b.c.d" + addition.parse(expCtx, + BSON("a.b.c.d" << "computedVal")); // Should add the path if it doesn't exist. @@ -421,8 +456,10 @@ TEST(ParsedAddFieldsExecutionTest, CreatesNestedSubDocumentsAllTheWayToAddedFiel // Verify that _id is not special: we can add subfields to it as well. TEST(ParsedAddFieldsExecutionTest, AddsSubFieldsOfId) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("_id.X" << true << "_id.Z" + addition.parse(expCtx, + BSON("_id.X" << true << "_id.Z" << "NEW")); auto result = addition.applyProjection(Document{{"_id", Document{{"X", 1}, {"Y", 2}}}}); auto expectedResult = Document{{"_id", Document{{"X", true}, {"Y", 2}, {"Z", "NEW"_sd}}}}; @@ -432,10 +469,12 @@ TEST(ParsedAddFieldsExecutionTest, AddsSubFieldsOfId) { // Verify that both ways of specifying nested fields -- both dotted notation and nesting -- // can be used together in the same specification. TEST(ParsedAddFieldsExecutionTest, ShouldAllowMixedNestedAndDottedFields) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; // Include all of "a.b", "a.c", "a.d", and "a.e". // Add new computed fields "a.W", "a.X", "a.Y", and "a.Z". - addition.parse(BSON("a.b" << true << "a.c" << true << "a.W" + addition.parse(expCtx, + BSON("a.b" << true << "a.c" << true << "a.W" << "W" << "a.X" << "X" @@ -462,8 +501,10 @@ TEST(ParsedAddFieldsExecutionTest, ShouldAllowMixedNestedAndDottedFields) { // Verify that adding nested fields preserves the addition order in the spec. TEST(ParsedAddFieldsExecutionTest, AddsNestedAddedFieldsInOrderSpecified) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("b.d" + addition.parse(expCtx, + BSON("b.d" << "FIRST" << "b.c" << "SECOND")); @@ -478,8 +519,9 @@ TEST(ParsedAddFieldsExecutionTest, AddsNestedAddedFieldsInOrderSpecified) { // Verify that the metadata is kept from the original input document. TEST(ParsedAddFieldsExecutionTest, AlwaysKeepsMetadataFromOriginalDoc) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << true)); + addition.parse(expCtx, BSON("a" << true)); MutableDocument inputDocBuilder(Document{{"a", 1}}); inputDocBuilder.setRandMetaField(1.0); diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection.cpp b/src/mongo/db/pipeline/parsed_aggregation_projection.cpp index a73a9589ed8..69b196faab3 100644 --- a/src/mongo/db/pipeline/parsed_aggregation_projection.cpp +++ b/src/mongo/db/pipeline/parsed_aggregation_projection.cpp @@ -259,7 +259,7 @@ private: } // namespace std::unique_ptr<ParsedAggregationProjection> ParsedAggregationProjection::create( - const BSONObj& spec) { + const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec) { // Check that the specification was valid. Status returned is unspecific because validate() // is used by the $addFields stage as well as $project. // If there was an error, uassert with a $project-specific message. @@ -282,7 +282,7 @@ std::unique_ptr<ParsedAggregationProjection> ParsedAggregationProjection::create : static_cast<ParsedAggregationProjection*>(new ParsedExclusionProjection())); // Actually parse the specification. - parsedProject->parse(spec); + parsedProject->parse(expCtx, spec); return parsedProject; } diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection.h b/src/mongo/db/pipeline/parsed_aggregation_projection.h index a637365565e..8d92e4ab768 100644 --- a/src/mongo/db/pipeline/parsed_aggregation_projection.h +++ b/src/mongo/db/pipeline/parsed_aggregation_projection.h @@ -41,7 +41,7 @@ namespace mongo { class BSONObj; class Document; -struct ExpressionContext; +class ExpressionContext; namespace parsed_aggregation_projection { @@ -117,7 +117,8 @@ public: * * Throws a UserException if 'spec' is an invalid projection specification. */ - static std::unique_ptr<ParsedAggregationProjection> create(const BSONObj& spec); + static std::unique_ptr<ParsedAggregationProjection> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec); virtual ~ParsedAggregationProjection() = default; @@ -132,7 +133,8 @@ public: * inclusions and exclusions. 'variablesParseState' is used by any contained expressions to * track which variables are defined so that they can later be referenced at execution time. */ - virtual void parse(const BSONObj& spec) = 0; + virtual void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec) = 0; /** * Optimize any expressions contained within this projection. @@ -140,11 +142,6 @@ public: virtual void optimize() {} /** - * Inject the ExpressionContext into any expressions contained within this projection. - */ - virtual void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) {} - - /** * Add any dependencies needed by this projection or any sub-expressions to 'deps'. */ virtual DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const { diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp b/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp index 5b55071a6bf..71dd1c2378f 100644 --- a/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp @@ -36,6 +36,7 @@ #include "mongo/bson/bsonobjbuilder.h" #include "mongo/bson/json.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" @@ -53,262 +54,296 @@ BSONObj wrapInLiteral(const T& arg) { // TEST(ParsedAggregationProjectionErrors, ShouldRejectDuplicateFieldNames) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Include/exclude the same field twice. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "a" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << true << "a" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "a" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << false << "a" << false)), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << false << "b" << false))), UserException); - ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << BSON("b" << false << "b" << false))), - UserException); // Mix of include/exclude and adding a field. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1) << "a" << true)), - UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << false << "a" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1) << "a" << true)), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << false << "a" << wrapInLiteral(0))), UserException); // Adding the same field twice. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << wrapInLiteral(1) << "a" << wrapInLiteral(0))), + expCtx, BSON("a" << wrapInLiteral(1) << "a" << wrapInLiteral(0))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectDuplicateIds) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Include/exclude _id twice. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "_id" << true)), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << false << "_id" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "_id" << true)), UserException); - - // Mix of including/excluding and adding _id. ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << wrapInLiteral(1) << "_id" << true)), - UserException); - ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << false << "_id" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("_id" << false << "_id" << false)), UserException); + // Mix of including/excluding and adding _id. + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << wrapInLiteral(1) << "_id" << true)), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << false << "_id" << wrapInLiteral(0))), + UserException); + // Adding _id twice. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("_id" << wrapInLiteral(1) << "_id" << wrapInLiteral(0))), + expCtx, BSON("_id" << wrapInLiteral(1) << "_id" << wrapInLiteral(0))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectFieldsWithSharedPrefix) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Include/exclude Fields with a shared prefix. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "a.b" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << true << "a.b" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a.b" << false << "a" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a.b" << false << "a" << false)), UserException); // Mix of include/exclude and adding a shared prefix. ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1) << "a.b" << true)), - UserException); - ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a.b" << false << "a" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1) << "a.b" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a.b" << false << "a" << wrapInLiteral(0))), + UserException); // Adding a shared prefix twice. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << wrapInLiteral(1) << "a.b" << wrapInLiteral(0))), + expCtx, BSON("a" << wrapInLiteral(1) << "a.b" << wrapInLiteral(0))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a.b.c.d" << wrapInLiteral(1) << "a.b.c" << wrapInLiteral(0))), + expCtx, BSON("a.b.c.d" << wrapInLiteral(1) << "a.b.c" << wrapInLiteral(0))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectMixOfIdAndSubFieldsOfId) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Include/exclude _id twice. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "_id.x" << true)), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id.x" << false << "_id" << false)), - UserException); - - // Mix of including/excluding and adding _id. ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << wrapInLiteral(1) << "_id.x" << true)), + ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "_id.x" << true)), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id.x" << false << "_id" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("_id.x" << false << "_id" << false)), UserException); - // Adding _id twice. + // Mix of including/excluding and adding _id. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("_id" << wrapInLiteral(1) << "_id.x" << wrapInLiteral(0))), + expCtx, BSON("_id" << wrapInLiteral(1) << "_id.x" << true)), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("_id.b.c.d" << wrapInLiteral(1) << "_id.b.c" << wrapInLiteral(0))), + expCtx, BSON("_id.x" << false << "_id" << wrapInLiteral(0))), UserException); + + // Adding _id twice. + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << wrapInLiteral(1) << "_id.x" << wrapInLiteral(0))), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create( + expCtx, BSON("_id.b.c.d" << wrapInLiteral(1) << "_id.b.c" << wrapInLiteral(0))), + UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectMixOfInclusionAndExclusion) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Simple mix. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "b" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << true << "b" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "b" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << false << "b" << true)), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << BSON("b" << false << "c" << true))), + ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b" << false << "c" << true))), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << BSON("b" << false << "c" << true))), + UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << BSON("b" << false << "c" << true))), + ParsedAggregationProjection::create(expCtx, BSON("_id.b" << false << "a.c" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id.b" << false << "a.c" << true)), - UserException); // Mix while also adding a field. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << true << "b" << wrapInLiteral(1) << "c" << false)), + expCtx, BSON("a" << true << "b" << wrapInLiteral(1) << "c" << false)), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << false << "b" << wrapInLiteral(1) << "c" << true)), + expCtx, BSON("a" << false << "b" << wrapInLiteral(1) << "c" << true)), UserException); // Mixing "_id" inclusion with exclusion. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "a" << false)), - UserException); - - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "_id" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "a" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "a.b.c" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << false << "_id" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id.x" << true << "a.b.c" << false)), - UserException); -} - -TEST(ParsedAggregationProjectionType, ShouldRejectMixOfExclusionAndComputedFields) { ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << false << "b" << wrapInLiteral(1))), + ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "a.b.c" << false)), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1) << "b" << false)), + ParsedAggregationProjection::create(expCtx, BSON("_id.x" << true << "a.b.c" << false)), UserException); +} +TEST(ParsedAggregationProjectionType, ShouldRejectMixOfExclusionAndComputedFields) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a.b" << false << "a.c" << wrapInLiteral(1))), + ParsedAggregationProjection::create(expCtx, BSON("a" << false << "b" << wrapInLiteral(1))), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a.b" << wrapInLiteral(1) << "a.c" << false)), + ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1) << "b" << false)), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << false << "c" << wrapInLiteral(1)))), + expCtx, BSON("a.b" << false << "a.c" << wrapInLiteral(1))), + UserException); + + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a.b" << wrapInLiteral(1) << "a.c" << false)), + UserException); + + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << false << "c" << wrapInLiteral(1)))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << wrapInLiteral(1) << "c" << false))), + expCtx, BSON("a" << BSON("b" << wrapInLiteral(1) << "c" << false))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectDottedFieldInSubDocument) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("b.c" << true))), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("b.c" << wrapInLiteral(1)))), + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b.c" << true))), UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b.c" << wrapInLiteral(1)))), + UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectFieldNamesStartingWithADollar) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$dollar" << 0)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$dollar" << 1)), UserException); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$dollar" << 0)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$dollar" << 1)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b.$dollar" << 0)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b.$dollar" << 1)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b.$dollar" << 0)), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b.$dollar" << 1)), + UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b" << BSON("$dollar" << 0))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b" << BSON("$dollar" << 0))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b" << BSON("$dollar" << 1))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b" << BSON("$dollar" << 1))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$add" << 0)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$add" << 1)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$add" << 0)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$add" << 1)), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectTopLevelExpressions) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$add" << BSON_ARRAY(4 << 2))), + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$add" << BSON_ARRAY(4 << 2))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectExpressionWithMultipleFieldNames) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), + expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << 1 << "$add" << BSON_ARRAY(4 << 2)))), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << BSON("c" << 1 << "$add" << BSON_ARRAY(4 << 2))))), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << BSON("$add" << BSON_ARRAY(4 << 2) << "c" << 1)))), + expCtx, BSON("a" << BSON("b" << 1 << "$add" << BSON_ARRAY(4 << 2)))), UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << BSON("c" << 1 << "$add" << BSON_ARRAY(4 << 2))))), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << BSON("$add" << BSON_ARRAY(4 << 2) << "c" << 1)))), + UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectEmptyProjection) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSONObj()), UserException); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSONObj()), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectEmptyNestedObject) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "b" << BSONObj())), + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "b" << BSONObj())), + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << false << "b" << BSONObj())), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << true << "b" << BSONObj())), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a.b" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a.b" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("b" << BSONObj()))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b" << BSONObj()))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldErrorOnInvalidExpression) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), + expCtx, BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << true << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), + expCtx, BSON("a" << true << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldErrorOnInvalidFieldPath) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Empty field names. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << wrapInLiteral(2))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << false)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << wrapInLiteral(2))), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("" << true))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("" << true))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("" << false))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("" << false))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << BSON("a" << true))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << BSON("a" << true))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << BSON("a" << false))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << BSON("a" << false))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a." << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a." << false)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a." << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a." << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON(".a" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON(".a" << false)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON(".a" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON(".a" << false)), UserException); // Not testing field names with null bytes, since that is invalid BSON, and won't make it to the // $project stage without a previous error. // Field names starting with '$'. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$x" << wrapInLiteral(2))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$x" << wrapInLiteral(2))), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("c.$d" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("c.$d" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("c.$d" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("c.$d" << false)), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldNotErrorOnTwoNestedFields) { - ParsedAggregationProjection::create(BSON("a.b" << true << "a.c" << true)); - ParsedAggregationProjection::create(BSON("a.b" << true << "a" << BSON("c" << true))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ParsedAggregationProjection::create(expCtx, BSON("a.b" << true << "a.c" << true)); + ParsedAggregationProjection::create(expCtx, BSON("a.b" << true << "a" << BSON("c" << true))); } // @@ -316,102 +351,112 @@ TEST(ParsedAggregationProjectionErrors, ShouldNotErrorOnTwoNestedFields) { // TEST(ParsedAggregationProjectionType, ShouldDefaultToInclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("_id" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldDetectExclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("a" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << false)); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id.x" << false)); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id.x" << false)); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << BSON("x" << false))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << BSON("x" << false))); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("x" << BSON("_id" << false))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("x" << BSON("_id" << false))); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << false)); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << false)); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); } TEST(ParsedAggregationProjectionType, ShouldDetectInclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << false << "a" << true)); + parsedProject = + ParsedAggregationProjection::create(expCtx, BSON("_id" << false << "a" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << false << "a.b.c" << true)); + parsedProject = + ParsedAggregationProjection::create(expCtx, BSON("_id" << false << "a.b.c" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id.x" << true)); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id.x" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << BSON("x" << true))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << BSON("x" << true))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("x" << BSON("_id" << true))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("x" << BSON("_id" << true))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldTreatOnlyComputedFieldsAsAnInclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = - ParsedAggregationProjection::create(BSON("_id" << false << "a" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create( + expCtx, BSON("_id" << false << "a" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = - ParsedAggregationProjection::create(BSON("_id" << false << "a.b.c" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create( + expCtx, BSON("_id" << false << "a.b.c" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id.x" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id.x" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); parsedProject = - ParsedAggregationProjection::create(BSON("_id" << BSON("x" << wrapInLiteral(1)))); + ParsedAggregationProjection::create(expCtx, BSON("_id" << BSON("x" << wrapInLiteral(1)))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); parsedProject = - ParsedAggregationProjection::create(BSON("x" << BSON("_id" << wrapInLiteral(1)))); + ParsedAggregationProjection::create(expCtx, BSON("x" << BSON("_id" << wrapInLiteral(1)))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldAllowMixOfInclusionAndComputedFields) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); auto parsedProject = - ParsedAggregationProjection::create(BSON("a" << true << "b" << wrapInLiteral(1))); + ParsedAggregationProjection::create(expCtx, BSON("a" << true << "b" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = - ParsedAggregationProjection::create(BSON("a.b" << true << "a.c" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create( + expCtx, BSON("a.b" << true << "a.c" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); parsedProject = ParsedAggregationProjection::create( - BSON("a" << BSON("b" << true << "c" << wrapInLiteral(1)))); + expCtx, BSON("a" << BSON("b" << true << "c" << wrapInLiteral(1)))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldCoerceNumericsToBools) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); std::vector<Value> zeros = {Value(0), Value(0LL), Value(0.0), Value(Decimal128(0))}; for (auto&& zero : zeros) { - auto parsedProject = ParsedAggregationProjection::create(Document{{"a", zero}}.toBson()); + auto parsedProject = + ParsedAggregationProjection::create(expCtx, Document{{"a", zero}}.toBson()); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); } std::vector<Value> nonZeroes = { Value(1), Value(-1), Value(3), Value(1LL), Value(1.0), Value(Decimal128(1))}; for (auto&& nonZero : nonZeroes) { - auto parsedProject = ParsedAggregationProjection::create(Document{{"a", nonZero}}.toBson()); + auto parsedProject = + ParsedAggregationProjection::create(expCtx, Document{{"a", nonZero}}.toBson()); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } } diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection.cpp b/src/mongo/db/pipeline/parsed_exclusion_projection.cpp index 8226d146009..0f800f9a112 100644 --- a/src/mongo/db/pipeline/parsed_exclusion_projection.cpp +++ b/src/mongo/db/pipeline/parsed_exclusion_projection.cpp @@ -143,7 +143,10 @@ Document ParsedExclusionProjection::applyProjection(Document inputDoc) const { return _root->applyProjection(inputDoc); } -void ParsedExclusionProjection::parse(const BSONObj& spec, ExclusionNode* node, size_t depth) { +void ParsedExclusionProjection::parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, + ExclusionNode* node, + size_t depth) { for (auto elem : spec) { const auto fieldName = elem.fieldNameStringData().toString(); @@ -188,7 +191,7 @@ void ParsedExclusionProjection::parse(const BSONObj& spec, ExclusionNode* node, child = child->addOrGetChild(fullPath.fullPath()); } - parse(elem.Obj(), child, depth + 1); + parse(expCtx, elem.Obj(), child, depth + 1); break; } default: { MONGO_UNREACHABLE; } diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection.h b/src/mongo/db/pipeline/parsed_exclusion_projection.h index ea7b25ac33f..d0988d2d2cb 100644 --- a/src/mongo/db/pipeline/parsed_exclusion_projection.h +++ b/src/mongo/db/pipeline/parsed_exclusion_projection.h @@ -108,8 +108,8 @@ public: /** * Parses the projection specification given by 'spec', populating internal data structures. */ - void parse(const BSONObj& spec) final { - parse(spec, _root.get(), 0); + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec) final { + parse(expCtx, spec, _root.get(), 0); } /** @@ -134,7 +134,10 @@ private: * Traverses 'spec' and parses each field. Adds any excluded fields at this level to 'node', * and recurses on any sub-objects. */ - void parse(const BSONObj& spec, ExclusionNode* node, size_t depth); + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, + ExclusionNode* node, + size_t depth); // The ExclusionNode tree does most of the execution work once constructed. diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp b/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp index 62a23b2c729..76b458260d0 100644 --- a/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp @@ -40,6 +40,7 @@ #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/death_test.h" #include "mongo/unittest/unittest.h" @@ -57,28 +58,32 @@ DEATH_TEST(ExclusionProjection, ShouldRejectComputedField, "Invariant failure fieldName[0] != '$'") { ParsedExclusionProjection exclusion; + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Top-level expression. - exclusion.parse(BSON("a" << false << "b" << BSON("$literal" << 1))); + exclusion.parse(expCtx, BSON("a" << false << "b" << BSON("$literal" << 1))); } DEATH_TEST(ExclusionProjection, ShouldFailWhenGivenIncludedField, "Invariant failure !elem.trueValue()") { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << true)); } DEATH_TEST(ExclusionProjection, ShouldFailWhenGivenIncludedId, "Invariant failure !elem.trueValue()") { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << true << "a" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id" << true << "a" << false)); } TEST(ExclusionProjection, ShouldSerializeToEquivalentProjection) { ParsedExclusionProjection exclusion; + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); exclusion.parse( - fromjson("{a: 0, b: {c: NumberLong(0), d: 0.0}, 'x.y': false, _id: NumberInt(0)}")); + expCtx, fromjson("{a: 0, b: {c: NumberLong(0), d: 0.0}, 'x.y': false, _id: NumberInt(0)}")); // Converts numbers to bools, converts dotted paths to nested documents. Note order of excluded // fields is subject to change. @@ -107,7 +112,9 @@ TEST(ExclusionProjection, ShouldNotAddAnyDependencies) { // later. If there are no later stages, then we will finish the dependency computation // cycle without full knowledge of which fields are needed, and thus include all the fields. ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << false << "a" << false << "b.c" << false << "x.y.z" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, + BSON("_id" << false << "a" << false << "b.c" << false << "x.y.z" << false)); DepsTracker deps; exclusion.addDependencies(&deps); @@ -119,7 +126,8 @@ TEST(ExclusionProjection, ShouldNotAddAnyDependencies) { TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModified) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << false << "a" << false << "b.c" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id" << false << "a" << false << "b.c" << false)); auto modifiedPaths = exclusion.getModifiedPaths(); ASSERT(modifiedPaths.type == DocumentSource::GetModPathsReturn::Type::kFiniteSet); @@ -131,7 +139,8 @@ TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModified) { TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModifiedWhenSpecifiedAsNestedObj) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << BSON("b" << false << "c" << BSON("d" << false)))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << BSON("b" << false << "c" << BSON("d" << false)))); auto modifiedPaths = exclusion.getModifiedPaths(); ASSERT(modifiedPaths.type == DocumentSource::GetModPathsReturn::Type::kFiniteSet); @@ -146,7 +155,8 @@ TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModifiedWhenSpecifiedAsNes TEST(ExclusionProjectionExecutionTest, ShouldExcludeTopLevelField) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false)); // More than one field in document. auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}}); @@ -171,7 +181,9 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeTopLevelField) { TEST(ExclusionProjectionExecutionTest, ShouldCoerceNumericsToBools) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << Value(0) << "b" << Value(0LL) << "c" << Value(0.0) << "d" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, + BSON("a" << Value(0) << "b" << Value(0LL) << "c" << Value(0.0) << "d" << Value(Decimal128(0)))); auto result = @@ -182,7 +194,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldCoerceNumericsToBools) { TEST(ExclusionProjectionExecutionTest, ShouldPreserveOrderOfExistingFields) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("second" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("second" << false)); auto result = exclusion.applyProjection(Document{{"first", 0}, {"second", 1}, {"third", 2}}); auto expectedResult = Document{{"first", 0}, {"third", 2}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -190,7 +203,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldPreserveOrderOfExistingFields) { TEST(ExclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"b", 2}, {"_id", "ID"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -198,7 +212,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false << "_id" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false << "_id" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"b", 2}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -206,7 +221,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdAndKeepAllOtherFields) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"a", 1}, {"b", 2}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -218,7 +234,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdAndKeepAllOtherFields) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeSubFieldsOfId) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id.x" << false << "_id" << BSON("y" << false))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id.x" << false << "_id" << BSON("y" << false))); auto result = exclusion.applyProjection( Document{{"_id", Document{{"x", 1}, {"y", 2}, {"z", 3}}}, {"a", 1}}); auto expectedResult = Document{{"_id", Document{{"z", 3}}}, {"a", 1}}; @@ -227,7 +244,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeSubFieldsOfId) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeSimpleDottedFieldFromSubDoc) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a.b" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a.b" << false)); // More than one field in sub document. auto result = exclusion.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); @@ -252,7 +270,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeSimpleDottedFieldFromSubDoc) TEST(ExclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedExcludedFieldDoesNotExist) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("sub.target" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("sub.target" << false)); // Should not add the path if it doesn't exist. auto result = exclusion.applyProjection(Document{}); @@ -267,7 +286,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedExcludedFiel TEST(ExclusionProjectionExecutionTest, ShouldApplyDottedExclusionToEachElementInArray) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a.b" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a.b" << false)); std::vector<Value> nestedValues = { Value(1), @@ -290,8 +310,10 @@ TEST(ExclusionProjectionExecutionTest, ShouldApplyDottedExclusionToEachElementIn TEST(ExclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { ParsedExclusionProjection exclusion; + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Exclude all of "a.b", "a.c", "a.d", and "a.e". exclusion.parse( + expCtx, BSON("a.b" << false << "a.c" << false << "a" << BSON("d" << false << "e" << false))); auto result = exclusion.applyProjection( Document{{"a", Document{{"b", 1}, {"c", 2}, {"d", 3}, {"e", 4}, {"f", 5}}}}); @@ -301,7 +323,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { TEST(ExclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false)); MutableDocument inputDocBuilder(Document{{"_id", "ID"_sd}, {"a", 1}}); inputDocBuilder.setRandMetaField(1.0); diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection.cpp b/src/mongo/db/pipeline/parsed_inclusion_projection.cpp index abac372be7e..511d1b37fff 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection.cpp +++ b/src/mongo/db/pipeline/parsed_inclusion_projection.cpp @@ -54,16 +54,6 @@ void InclusionNode::optimize() { } } -void InclusionNode::injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { - for (auto&& expressionIt : _expressions) { - expressionIt.second->injectExpressionContext(expCtx); - } - - for (auto&& childPair : _children) { - childPair.second->injectExpressionContext(expCtx); - } -} - void InclusionNode::serialize(MutableDocument* output, bool explain) const { // Always put "_id" first if it was included (implicitly or explicitly). if (_inclusions.find("_id") != _inclusions.end()) { @@ -252,7 +242,8 @@ void InclusionNode::addPreservedPaths(std::set<std::string>* preservedPaths) con // ParsedInclusionProjection // -void ParsedInclusionProjection::parse(const BSONObj& spec, +void ParsedInclusionProjection::parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, const VariablesParseState& variablesParseState) { // It is illegal to specify a projection with no output fields. bool atLeastOneFieldInOutput = false; @@ -289,7 +280,7 @@ void ParsedInclusionProjection::parse(const BSONObj& spec, } case BSONType::Object: { // This is either an expression, or a nested specification. - if (parseObjectAsExpression(fieldName, elem.Obj(), variablesParseState)) { + if (parseObjectAsExpression(expCtx, fieldName, elem.Obj(), variablesParseState)) { // It was an expression. break; } @@ -306,13 +297,14 @@ void ParsedInclusionProjection::parse(const BSONObj& spec, // iteration too soon. Add the last path here. child = child->addOrGetChild(remainingPath.fullPath()); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); break; } default: { // This is a literal value. - _root->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + _root->addComputedField( + FieldPath(elem.fieldName()), + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } @@ -343,6 +335,7 @@ Document ParsedInclusionProjection::applyProjection(Document inputDoc, Variables } bool ParsedInclusionProjection::parseObjectAsExpression( + const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState) { @@ -351,15 +344,17 @@ bool ParsedInclusionProjection::parseObjectAsExpression( // field. invariant(objSpec.nFields() == 1); _root->addComputedField(pathToObject, - Expression::parseExpression(objSpec, variablesParseState)); + Expression::parseExpression(expCtx, objSpec, variablesParseState)); return true; } return false; } -void ParsedInclusionProjection::parseSubObject(const BSONObj& subObj, - const VariablesParseState& variablesParseState, - InclusionNode* node) { +void ParsedInclusionProjection::parseSubObject( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& subObj, + const VariablesParseState& variablesParseState, + InclusionNode* node) { for (auto elem : subObj) { invariant(elem.fieldName()[0] != '$'); // Dotted paths in a sub-object have already been disallowed in @@ -381,19 +376,20 @@ void ParsedInclusionProjection::parseSubObject(const BSONObj& subObj, // This is either an expression, or a nested specification. auto fieldName = elem.fieldNameStringData().toString(); if (parseObjectAsExpression( + expCtx, FieldPath::getFullyQualifiedPath(node->getPath(), fieldName), elem.Obj(), variablesParseState)) { break; } auto child = node->addOrGetChild(fieldName); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); break; } default: { // This is a literal value. node->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection.h b/src/mongo/db/pipeline/parsed_inclusion_projection.h index 8cd04514b53..2f812af300d 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection.h +++ b/src/mongo/db/pipeline/parsed_inclusion_projection.h @@ -119,8 +119,6 @@ public: return _pathToNode; } - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx); - /** * Recursively add all paths that are preserved by this inclusion projection. */ @@ -184,10 +182,10 @@ public: /** * Parses the projection specification given by 'spec', populating internal data structures. */ - void parse(const BSONObj& spec) final { + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec) final { VariablesIdGenerator idGenerator; VariablesParseState variablesParseState(&idGenerator); - parse(spec, variablesParseState); + parse(expCtx, spec, variablesParseState); _variables = stdx::make_unique<Variables>(idGenerator.getIdCount()); } @@ -210,10 +208,6 @@ public: _root->optimize(); } - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) final { - _root->injectExpressionContext(expCtx); - } - DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const final { _root->addDependencies(deps); return DocumentSource::EXHAUSTIVE_FIELDS; @@ -246,7 +240,9 @@ private: * Parses 'spec' to determine which fields to include, which are computed, and whether to * include '_id' or not. */ - void parse(const BSONObj& spec, const VariablesParseState& variablesParseState); + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, + const VariablesParseState& variablesParseState); /** * Attempts to parse 'objSpec' as an expression like {$add: [...]}. Adds a computed field to @@ -256,7 +252,8 @@ private: * Throws an error if it was determined to be an expression specification, but failed to parse * as a valid expression. */ - bool parseObjectAsExpression(StringData pathToObject, + bool parseObjectAsExpression(const boost::intrusive_ptr<ExpressionContext>& expCtx, + StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState); @@ -264,7 +261,8 @@ private: * Traverses 'subObj' and parses each field. Adds any included or computed fields at this level * to 'node'. */ - void parseSubObject(const BSONObj& subObj, + void parseSubObject(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& subObj, const VariablesParseState& variablesParseState, InclusionNode* node); diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp b/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp index 611f5367e9a..c2610418a40 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp @@ -38,6 +38,7 @@ #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" @@ -53,19 +54,23 @@ BSONObj wrapInLiteral(const T& arg) { TEST(InclusionProjection, ShouldThrowWhenParsingInvalidExpression) { ParsedInclusionProjection inclusion; - ASSERT_THROWS(inclusion.parse(BSON("a" << BSON("$gt" << BSON("bad" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(inclusion.parse(expCtx, + BSON("a" << BSON("$gt" << BSON("bad" << "arguments")))), UserException); } TEST(InclusionProjection, ShouldRejectProjectionWithNoOutputFields) { ParsedInclusionProjection inclusion; - ASSERT_THROWS(inclusion.parse(BSON("_id" << false)), UserException); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(inclusion.parse(expCtx, BSON("_id" << false)), UserException); } TEST(InclusionProjection, ShouldAddIncludedFieldsToDependencies) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << false << "a" << true << "x.y" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id" << false << "a" << true << "x.y" << true)); DepsTracker deps; inclusion.addDependencies(&deps); @@ -78,7 +83,8 @@ TEST(InclusionProjection, ShouldAddIncludedFieldsToDependencies) { TEST(InclusionProjection, ShouldAddIdToDependenciesIfNotSpecified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); DepsTracker deps; inclusion.addDependencies(&deps); @@ -90,7 +96,9 @@ TEST(InclusionProjection, ShouldAddIdToDependenciesIfNotSpecified) { TEST(InclusionProjection, ShouldAddDependenciesOfComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("a" << "$a" << "x" << "$z")); @@ -106,7 +114,9 @@ TEST(InclusionProjection, ShouldAddDependenciesOfComputedFields) { TEST(InclusionProjection, ShouldAddPathToDependenciesForNestedComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("x.y" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("x.y" << "$z")); DepsTracker deps; @@ -123,7 +133,8 @@ TEST(InclusionProjection, ShouldAddPathToDependenciesForNestedComputedFields) { TEST(InclusionProjection, ShouldSerializeToEquivalentProjection) { ParsedInclusionProjection inclusion; - inclusion.parse(fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); // Adds implicit "_id" inclusion, converts numbers to bools, serializes expressions. auto expectedSerialization = Document(fromjson( @@ -136,7 +147,8 @@ TEST(InclusionProjection, ShouldSerializeToEquivalentProjection) { TEST(InclusionProjection, ShouldSerializeExplicitExclusionOfId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << false << "a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id" << false << "a" << true)); // Adds implicit "_id" inclusion, converts numbers to bools, serializes expressions. auto expectedSerialization = Document{{"_id", false}, {"a", true}}; @@ -149,7 +161,8 @@ TEST(InclusionProjection, ShouldSerializeExplicitExclusionOfId) { TEST(InclusionProjection, ShouldOptimizeTopLevelExpressions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); inclusion.optimize(); @@ -162,7 +175,8 @@ TEST(InclusionProjection, ShouldOptimizeTopLevelExpressions) { TEST(InclusionProjection, ShouldOptimizeNestedExpressions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); inclusion.optimize(); @@ -176,10 +190,13 @@ TEST(InclusionProjection, ShouldOptimizeNestedExpressions) { TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON( - "a" << wrapInLiteral("computedVal") << "b.c" << wrapInLiteral("computedVal") << "d" << true - << "e.f" - << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse( + expCtx, + BSON("a" << wrapInLiteral("computedVal") << "b.c" << wrapInLiteral("computedVal") << "d" + << true + << "e.f" + << true)); auto modifiedPaths = inclusion.getModifiedPaths(); ASSERT(modifiedPaths.type == DocumentSource::GetModPathsReturn::Type::kAllExcept); @@ -195,7 +212,9 @@ TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModified) { TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModifiedWithIdExclusion) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << false << "a" << wrapInLiteral("computedVal") << "b.c" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("_id" << false << "a" << wrapInLiteral("computedVal") << "b.c" << wrapInLiteral("computedVal") << "d" << true @@ -222,7 +241,8 @@ TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModifiedWith TEST(InclusionProjectionExecutionTest, ShouldIncludeTopLevelField) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); // More than one field in document. auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}}); @@ -247,7 +267,8 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeTopLevelField) { TEST(InclusionProjectionExecutionTest, ShouldAddComputedTopLevelField) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("newField" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{{"newField", "computedVal"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -260,7 +281,8 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedTopLevelField) { TEST(InclusionProjectionExecutionTest, ShouldApplyBothInclusionsAndComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true << "newField" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true << "newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{{"a", 1}}); auto expectedResult = Document{{"a", 1}, {"newField", "computedVal"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -268,7 +290,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyBothInclusionsAndComputedField TEST(InclusionProjectionExecutionTest, ShouldIncludeFieldsInOrderOfInputDoc) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("first" << true << "second" << true << "third" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("first" << true << "second" << true << "third" << true)); auto inputDoc = Document{{"second", 1}, {"first", 0}, {"third", 2}}; auto result = inclusion.applyProjection(inputDoc); ASSERT_DOCUMENT_EQ(result, inputDoc); @@ -276,7 +299,9 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeFieldsInOrderOfInputDoc) { TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsInOrderSpecified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("firstComputed" << wrapInLiteral("FIRST") << "secondComputed" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("firstComputed" << wrapInLiteral("FIRST") << "secondComputed" << wrapInLiteral("SECOND"))); auto result = inclusion.applyProjection(Document{{"first", 0}, {"second", 1}, {"third", 2}}); auto expectedResult = Document{{"firstComputed", "FIRST"_sd}, {"secondComputed", "SECOND"_sd}}; @@ -285,7 +310,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsInOrderSpecified TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); auto result = inclusion.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}, {"b", 2}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"a", 1}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -298,7 +324,8 @@ TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeIdWithComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("newField" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"newField", "computedVal"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -306,7 +333,8 @@ TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeIdWithComputedFiel TEST(InclusionProjectionExecutionTest, ShouldIncludeIdIfExplicitlyIncluded) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true << "_id" << true << "b" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true << "_id" << true << "b" << true)); auto result = inclusion.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}, {"b", 2}, {"c", 3}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"a", 1}, {"b", 2}}; @@ -315,7 +343,8 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeIdIfExplicitlyIncluded) { TEST(InclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true << "_id" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true << "_id" << false)); auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"a", 1}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -323,7 +352,8 @@ TEST(InclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { TEST(InclusionProjectionExecutionTest, ShouldReplaceIdWithComputedId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << wrapInLiteral("newId"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id" << wrapInLiteral("newId"))); auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"_id", "newId"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -335,7 +365,8 @@ TEST(InclusionProjectionExecutionTest, ShouldReplaceIdWithComputedId) { TEST(InclusionProjectionExecutionTest, ShouldIncludeSimpleDottedFieldFromSubDoc) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << true)); // More than one field in sub document. auto result = inclusion.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); @@ -360,7 +391,8 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeSimpleDottedFieldFromSubDoc) TEST(InclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedIncludedFieldDoesNotExist) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("sub.target" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("sub.target" << true)); // Should not add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); @@ -375,7 +407,8 @@ TEST(InclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedIncludedFiel TEST(InclusionProjectionExecutionTest, ShouldApplyDottedInclusionToEachElementInArray) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << true)); vector<Value> nestedValues = {Value(1), Value(Document{}), @@ -399,7 +432,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyDottedInclusionToEachElementIn TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToSubDocument) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("sub.target" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("sub.target" << wrapInLiteral("computedVal"))); // Other fields exist in sub document, one of which is the specified field. auto result = inclusion.applyProjection(Document{{"sub", Document{{"target", 1}, {"c", 2}}}}); @@ -419,7 +453,8 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToSubDocument TEST(InclusionProjectionExecutionTest, ShouldCreateSubDocIfDottedComputedFieldDoesntExist) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("sub.target" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("sub.target" << wrapInLiteral("computedVal"))); // Should add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); @@ -433,7 +468,8 @@ TEST(InclusionProjectionExecutionTest, ShouldCreateSubDocIfDottedComputedFieldDo TEST(InclusionProjectionExecutionTest, ShouldCreateNestedSubDocumentsAllTheWayToComputedField) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b.c.d" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b.c.d" << wrapInLiteral("computedVal"))); // Should add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); @@ -448,7 +484,8 @@ TEST(InclusionProjectionExecutionTest, ShouldCreateNestedSubDocumentsAllTheWayTo TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToEachElementInArray) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << wrapInLiteral("COMPUTED"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << wrapInLiteral("COMPUTED"))); vector<Value> nestedValues = {Value(1), Value(Document{}), @@ -471,7 +508,8 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToEachElement TEST(InclusionProjectionExecutionTest, ShouldApplyInclusionsAndAdditionsToEachElementInArray) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.inc" << true << "a.comp" << wrapInLiteral("COMPUTED"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.inc" << true << "a.comp" << wrapInLiteral("COMPUTED"))); vector<Value> nestedValues = {Value(1), Value(Document{}), @@ -498,7 +536,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyInclusionsAndAdditionsToEachEl TEST(InclusionProjectionExecutionTest, ShouldAddOrIncludeSubFieldsOfId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id.X" << true << "_id.Z" << wrapInLiteral("NEW"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id.X" << true << "_id.Z" << wrapInLiteral("NEW"))); auto result = inclusion.applyProjection(Document{{"_id", Document{{"X", 1}, {"Y", 2}}}}); auto expectedResult = Document{{"_id", Document{{"X", 1}, {"Z", "NEW"_sd}}}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -506,13 +545,16 @@ TEST(InclusionProjectionExecutionTest, ShouldAddOrIncludeSubFieldsOfId) { TEST(InclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { ParsedInclusionProjection inclusion; + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Include all of "a.b", "a.c", "a.d", and "a.e". // Add new computed fields "a.W", "a.X", "a.Y", and "a.Z". - inclusion.parse(BSON( - "a.b" << true << "a.c" << true << "a.W" << wrapInLiteral("W") << "a.X" << wrapInLiteral("X") - << "a" - << BSON("d" << true << "e" << true << "Y" << wrapInLiteral("Y") << "Z" - << wrapInLiteral("Z")))); + inclusion.parse( + expCtx, + BSON("a.b" << true << "a.c" << true << "a.W" << wrapInLiteral("W") << "a.X" + << wrapInLiteral("X") + << "a" + << BSON("d" << true << "e" << true << "Y" << wrapInLiteral("Y") << "Z" + << wrapInLiteral("Z")))); auto result = inclusion.applyProjection(Document{ {"a", Document{{"b", "b"_sd}, {"c", "c"_sd}, {"d", "d"_sd}, {"e", "e"_sd}, {"f", "f"_sd}}}}); @@ -530,7 +572,9 @@ TEST(InclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { TEST(InclusionProjectionExecutionTest, ShouldApplyNestedComputedFieldsInOrderSpecified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << wrapInLiteral("FIRST") << "b.c" << wrapInLiteral("SECOND"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("a" << wrapInLiteral("FIRST") << "b.c" << wrapInLiteral("SECOND"))); auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{{"a", "FIRST"_sd}, {"b", Document{{"c", "SECOND"_sd}}}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -538,7 +582,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyNestedComputedFieldsInOrderSpe TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("b.c" << wrapInLiteral("NEW") << "a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("b.c" << wrapInLiteral("NEW") << "a" << true)); auto result = inclusion.applyProjection(Document{{"a", 1}}); auto expectedResult = Document{{"a", 1}, {"b", Document{{"c", "NEW"_sd}}}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -557,7 +602,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusio TEST(InclusionProjectionExecutionTest, ComputedFieldReplacingExistingShouldAppearAfterInclusions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("b" << wrapInLiteral("NEW") << "a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("b" << wrapInLiteral("NEW") << "a" << true)); auto result = inclusion.applyProjection(Document{{"b", 1}, {"a", 1}}); auto expectedResult = Document{{"a", 1}, {"b", "NEW"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -572,7 +618,8 @@ TEST(InclusionProjectionExecutionTest, ComputedFieldReplacingExistingShouldAppea TEST(InclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); MutableDocument inputDocBuilder(Document{{"a", 1}}); inputDocBuilder.setRandMetaField(1.0); diff --git a/src/mongo/db/pipeline/pipeline.cpp b/src/mongo/db/pipeline/pipeline.cpp index b296d3f85df..8bb745fba13 100644 --- a/src/mongo/db/pipeline/pipeline.cpp +++ b/src/mongo/db/pipeline/pipeline.cpp @@ -165,13 +165,6 @@ void Pipeline::reattachToOperationContext(OperationContext* opCtx) { } } -void Pipeline::injectExpressionContext(const intrusive_ptr<ExpressionContext>& expCtx) { - pCtx = expCtx; - for (auto&& stage : _sources) { - stage->injectExpressionContext(pCtx); - } -} - intrusive_ptr<Pipeline> Pipeline::splitForSharded() { // Create and initialize the shard spec we'll return. We start with an empty pipeline on the // shards and all work being done in the merger. Optimizations can move operations between diff --git a/src/mongo/db/pipeline/pipeline.h b/src/mongo/db/pipeline/pipeline.h index aff72fa40bc..fe4ca1de424 100644 --- a/src/mongo/db/pipeline/pipeline.h +++ b/src/mongo/db/pipeline/pipeline.h @@ -45,7 +45,7 @@ class BSONObjBuilder; class CollatorInterface; class DocumentSource; class OperationContext; -struct ExpressionContext; +class ExpressionContext; /** * A Pipeline object represents a list of DocumentSources and is responsible for optimizing the @@ -129,12 +129,6 @@ public: void optimizePipeline(); /** - * Propagates a reference to the ExpressionContext to all of the pipeline's contained stages and - * expressions. - */ - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx); - - /** * Returns any other collections involved in the pipeline in addition to the collection the * aggregation is run on. */ diff --git a/src/mongo/db/pipeline/pipeline_d.cpp b/src/mongo/db/pipeline/pipeline_d.cpp index bb681edd3f1..b2a7bee2fa8 100644 --- a/src/mongo/db/pipeline/pipeline_d.cpp +++ b/src/mongo/db/pipeline/pipeline_d.cpp @@ -188,7 +188,6 @@ public: return pipeline.getStatus(); } - pipeline.getValue()->injectExpressionContext(expCtx); pipeline.getValue()->optimizePipeline(); AutoGetCollectionForRead autoColl(expCtx->opCtx, expCtx->ns); @@ -294,7 +293,8 @@ StatusWith<std::unique_ptr<PlanExecutor>> attemptToGetExecutor( // that the user omitted. // // If pipeline has a null collator (representing the "simple" collation), we simply set the - // collation option to the original user BSON. + // collation option to the original user BSON, which is either the empty object (unspecified), + // or the specification for the "simple" collation. qr->setCollation(pExpCtx->getCollator() ? pExpCtx->getCollator()->getSpec().toBSON() : pExpCtx->collation); diff --git a/src/mongo/db/pipeline/pipeline_d.h b/src/mongo/db/pipeline/pipeline_d.h index 35d93a4cfd2..5be4893cf3b 100644 --- a/src/mongo/db/pipeline/pipeline_d.h +++ b/src/mongo/db/pipeline/pipeline_d.h @@ -38,7 +38,7 @@ namespace mongo { class Collection; class DocumentSourceCursor; class DocumentSourceSort; -struct ExpressionContext; +class ExpressionContext; class OperationContext; class Pipeline; class PlanExecutor; diff --git a/src/mongo/db/pipeline/pipeline_test.cpp b/src/mongo/db/pipeline/pipeline_test.cpp index 41cd6e567d3..dba575017ce 100644 --- a/src/mongo/db/pipeline/pipeline_test.cpp +++ b/src/mongo/db/pipeline/pipeline_test.cpp @@ -39,7 +39,7 @@ #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document_value_test_util.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/field_path.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/query/collation/collator_interface_mock.h" @@ -86,7 +86,14 @@ public: rawPipeline.push_back(stageElem.embeddedObject()); } AggregationRequest request(NamespaceString("a.collection"), rawPipeline); - intrusive_ptr<ExpressionContext> ctx = new ExpressionContext(&_opCtx, request); + intrusive_ptr<ExpressionContextForTest> ctx = + new ExpressionContextForTest(&_opCtx, request); + + // For $graphLookup and $lookup, we have to populate the resolvedNamespaces so that the + // operations will be able to have a resolved view definition. + NamespaceString lookupCollNs("a", "lookupColl"); + ctx->setResolvedNamespace(lookupCollNs, {lookupCollNs, std::vector<BSONObj>{}}); + auto outputPipe = uassertStatusOK(Pipeline::parse(request.getPipeline(), ctx)); outputPipe->optimizePipeline(); @@ -240,17 +247,17 @@ class MoveMatchBeforeSort : public Base { class LookupShouldCoalesceWithUnwindOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same'}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same'}}" "]"; @@ -259,17 +266,17 @@ class LookupShouldCoalesceWithUnwindOnAs : public Base { class LookupShouldCoalesceWithUnwindOnAsWithPreserveEmpty : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', preserveNullAndEmptyArrays: true}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: true}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', preserveNullAndEmptyArrays: true}}" "]"; @@ -278,18 +285,18 @@ class LookupShouldCoalesceWithUnwindOnAsWithPreserveEmpty : public Base { class LookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', includeArrayIndex: 'index'}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false, includeArrayIndex: " "'index'}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', includeArrayIndex: 'index'}}" "]"; @@ -298,13 +305,13 @@ class LookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base { class LookupShouldNotCoalesceWithUnwindNotOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; @@ -313,51 +320,59 @@ class LookupShouldNotCoalesceWithUnwindNotOnAs : public Base { class LookupShouldSwapWithMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'independent': 0}}]"; } string outputPipeJson() { return "[{$match: {independent: 0}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}]"; + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}]"; } }; class LookupShouldSplitMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'independent': 0, asField: {$eq: 3}}}]"; } string outputPipeJson() { return "[{$match: {independent: {$eq: 0}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {asField: {$eq: 3}}}]"; } }; class LookupShouldNotAbsorbMatchOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'asField.subfield': 0}}]"; } string outputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'asField.subfield': 0}}]"; } }; class LookupShouldAbsorbUnwindMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " "{$unwind: '$asField'}, " "{$match: {'asField.subfield': {$eq: 1}}}]"; } string outputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z', " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: 'z', " " unwinding: {preserveNullAndEmptyArrays: false}, " " matching: {subfield: {$eq: 1}}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " "{$unwind: {path: '$asField'}}, " "{$match: {'asField.subfield': {$eq: 1}}}]"; } @@ -365,14 +380,15 @@ class LookupShouldAbsorbUnwindMatch : public Base { class LookupShouldAbsorbUnwindAndSplitAndAbsorbMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: '$asField'}, " " {$match: {'asField.subfield': {$eq: 1}, independentField: {$gt: 2}}}]"; } string outputPipeJson() { return "[{$match: {independentField: {$gt: 2}}}, " " {$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'asField', " " localField: 'y', " " foreignField: 'z', " @@ -386,7 +402,8 @@ class LookupShouldAbsorbUnwindAndSplitAndAbsorbMatch : public Base { } string serializedPipeJson() { return "[{$match: {independentField: {$gt: 2}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField'}}, " " {$match: {'asField.subfield': {$eq: 1}}}]"; } @@ -397,19 +414,21 @@ class LookupShouldNotSplitIndependentAndDependentOrClauses : public Base { // the $lookup, and if any child of the $or is independent of the 'asField', then the $match // cannot be absorbed by the $lookup. string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: '$asField'}, " " {$match: {$or: [{'independent': {$gt: 4}}, " " {'asField.dependent': {$elemMatch: {a: {$eq: 1}}}}]}}]"; } string outputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z', " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: 'z', " " unwinding: {preserveNullAndEmptyArrays: false}}}, " " {$match: {$or: [{'independent': {$gt: 4}}, " " {'asField.dependent': {$elemMatch: {a: {$eq: 1}}}}]}}]"; } string serializedPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField'}}, " " {$match: {$or: [{'independent': {$gt: 4}}, " " {'asField.dependent': {$elemMatch: {a: {$eq: 1}}}}]}}]"; @@ -418,14 +437,15 @@ class LookupShouldNotSplitIndependentAndDependentOrClauses : public Base { class LookupWithMatchOnArrayIndexFieldShouldNotCoalesce : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', includeArrayIndex: 'index'}}, " " {$match: {index: 0, 'asField.value': {$gt: 0}, independent: 1}}]"; } string outputPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " " {$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'asField', " " localField: 'y', " " foreignField: 'z', " @@ -438,7 +458,8 @@ class LookupWithMatchOnArrayIndexFieldShouldNotCoalesce : public Base { } string serializedPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', includeArrayIndex: 'index'}}, " " {$match: {$and: [{index: {$eq: 0}}, {'asField.value': {$gt: 0}}]}}]"; } @@ -446,14 +467,15 @@ class LookupWithMatchOnArrayIndexFieldShouldNotCoalesce : public Base { class LookupWithUnwindPreservingNullAndEmptyArraysShouldNotCoalesce : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', preserveNullAndEmptyArrays: true}}, " " {$match: {'asField.value': {$gt: 0}, independent: 1}}]"; } string outputPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " " {$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'asField', " " localField: 'y', " " foreignField: 'z', " @@ -465,7 +487,8 @@ class LookupWithUnwindPreservingNullAndEmptyArraysShouldNotCoalesce : public Bas } string serializedPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', preserveNullAndEmptyArrays: true}}, " " {$match: {'asField.value': {$gt: 0}}}]"; } @@ -473,13 +496,13 @@ class LookupWithUnwindPreservingNullAndEmptyArraysShouldNotCoalesce : public Bas class LookupDoesNotAbsorbElemMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$unwind: '$x'}, " " {$match: {x: {$elemMatch: {a: 1}}}}]"; } string outputPipeJson() { return "[{$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'x', " " localField: 'y', " " foreignField: 'z', " @@ -491,7 +514,7 @@ class LookupDoesNotAbsorbElemMatch : public Base { " {$match: {x: {$elemMatch: {a: 1}}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$unwind: {path: '$x'}}, " " {$match: {x: {$elemMatch: {a: 1}}}}]"; } @@ -499,35 +522,35 @@ class LookupDoesNotAbsorbElemMatch : public Base { class LookupDoesSwapWithMatchOnLocalField : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$match: {y: {$eq: 3}}}]"; } string outputPipeJson() { return "[{$match: {y: {$eq: 3}}}, " - " {$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}]"; + " {$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}]"; } }; class LookupDoesSwapWithMatchOnFieldWithSameNameAsForeignField : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$match: {z: {$eq: 3}}}]"; } string outputPipeJson() { return "[{$match: {z: {$eq: 3}}}, " - " {$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}]"; + " {$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}]"; } }; class LookupDoesNotAbsorbUnwindOnSubfieldOfAsButStillMovesMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$unwind: {path: '$x.subfield'}}, " " {$match: {'independent': 2, 'x.dependent': 2}}]"; } string outputPipeJson() { return "[{$match: {'independent': {$eq: 2}}}, " - " {$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$match: {'x.dependent': {$eq: 2}}}, " " {$unwind: {path: '$x.subfield'}}]"; } @@ -639,58 +662,61 @@ class UnwindBeforeDoubleMatchShouldRepeatedlyOptimize : public Base { class GraphLookupShouldCoalesceWithUnwindOnAs : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: '$out'}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d', unwinding: {preserveNullAndEmptyArrays: " - "false}}}]"; + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d', " + " unwinding: {preserveNullAndEmptyArrays: false}}}]"; } string serializedPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out'}}]"; } }; class GraphLookupShouldCoalesceWithUnwindOnAsWithPreserveEmpty : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out', preserveNullAndEmptyArrays: true}}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d', unwinding: {preserveNullAndEmptyArrays: true}}}]"; + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d', " + " unwinding: {preserveNullAndEmptyArrays: true}}}]"; } string serializedPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out', preserveNullAndEmptyArrays: true}}]"; } }; class GraphLookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out', includeArrayIndex: 'index'}}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d', unwinding: {preserveNullAndEmptyArrays: false, " - " includeArrayIndex: 'index'}}}]"; + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d', " + " unwinding: {preserveNullAndEmptyArrays: false, " + " includeArrayIndex: 'index'}}}]"; } string serializedPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', " " startWith: '$d'}}, " " {$unwind: {path: '$out', includeArrayIndex: 'index'}}]"; } @@ -698,14 +724,14 @@ class GraphLookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base class GraphLookupShouldNotCoalesceWithUnwindNotOnAs : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: '$nottherightthing'}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$nottherightthing'}}]"; } }; @@ -713,7 +739,7 @@ class GraphLookupShouldNotCoalesceWithUnwindNotOnAs : public Base { class GraphLookupShouldSwapWithMatch : public Base { string inputPipeJson() { return "[{$graphLookup: {" - " from: 'coll2'," + " from: 'lookupColl'," " as: 'results'," " connectToField: 'to'," " connectFromField: 'from'," @@ -725,7 +751,7 @@ class GraphLookupShouldSwapWithMatch : public Base { string outputPipeJson() { return "[{$match: {independent: 'x'}}," " {$graphLookup: {" - " from: 'coll2'," + " from: 'lookupColl'," " as: 'results'," " connectToField: 'to'," " connectFromField: 'from'," @@ -863,7 +889,14 @@ public: rawPipeline.push_back(stageElem.embeddedObject()); } AggregationRequest request(NamespaceString("a.collection"), rawPipeline); - intrusive_ptr<ExpressionContext> ctx = new ExpressionContext(&_opCtx, request); + intrusive_ptr<ExpressionContextForTest> ctx = + new ExpressionContextForTest(&_opCtx, request); + + // For $graphLookup and $lookup, we have to populate the resolvedNamespaces so that the + // operations will be able to have a resolved view definition. + NamespaceString lookupCollNs("a", "lookupColl"); + ctx->setResolvedNamespace(lookupCollNs, {lookupCollNs, std::vector<BSONObj>{}}); + mergePipe = uassertStatusOK(Pipeline::parse(request.getPipeline(), ctx)); mergePipe->optimizePipeline(); @@ -1057,7 +1090,7 @@ namespace coalesceLookUpAndUnwind { class ShouldCoalesceUnwindOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same'}}" "]"; @@ -1066,14 +1099,14 @@ class ShouldCoalesceUnwindOnAs : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false}}}]"; } }; class ShouldCoalesceUnwindOnAsWithPreserveEmpty : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', preserveNullAndEmptyArrays: true}}" "]"; @@ -1082,14 +1115,14 @@ class ShouldCoalesceUnwindOnAsWithPreserveEmpty : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: true}}}]"; } }; class ShouldCoalesceUnwindOnAsWithIncludeArrayIndex : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', includeArrayIndex: 'index'}}" "]"; @@ -1098,7 +1131,7 @@ class ShouldCoalesceUnwindOnAsWithIncludeArrayIndex : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false, includeArrayIndex: " "'index'}}}]"; } @@ -1106,7 +1139,7 @@ class ShouldCoalesceUnwindOnAsWithIncludeArrayIndex : public Base { class ShouldNotCoalesceUnwindNotOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; @@ -1115,7 +1148,7 @@ class ShouldNotCoalesceUnwindNotOnAs : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; @@ -1170,14 +1203,14 @@ class LookUp : public needsPrimaryShardMergerBase { return true; } string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}]"; } string shardPipeJson() { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}]"; } }; @@ -1192,7 +1225,7 @@ TEST(PipelineInitialSource, GeoNearInitialQuery) { OperationContextNoop _opCtx; const std::vector<BSONObj> rawPipeline = { fromjson("{$geoNear: {distanceField: 'd', near: [0, 0], query: {a: 1}}}")}; - intrusive_ptr<ExpressionContext> ctx = new ExpressionContext( + intrusive_ptr<ExpressionContextForTest> ctx = new ExpressionContextForTest( &_opCtx, AggregationRequest(NamespaceString("a.collection"), rawPipeline)); auto pipe = uassertStatusOK(Pipeline::parse(rawPipeline, ctx)); ASSERT_BSONOBJ_EQ(pipe->getInitialQuery(), BSON("a" << 1)); @@ -1201,28 +1234,13 @@ TEST(PipelineInitialSource, GeoNearInitialQuery) { TEST(PipelineInitialSource, MatchInitialQuery) { OperationContextNoop _opCtx; const std::vector<BSONObj> rawPipeline = {fromjson("{$match: {'a': 4}}")}; - intrusive_ptr<ExpressionContext> ctx = new ExpressionContext( + intrusive_ptr<ExpressionContextForTest> ctx = new ExpressionContextForTest( &_opCtx, AggregationRequest(NamespaceString("a.collection"), rawPipeline)); auto pipe = uassertStatusOK(Pipeline::parse(rawPipeline, ctx)); ASSERT_BSONOBJ_EQ(pipe->getInitialQuery(), BSON("a" << 4)); } -TEST(PipelineInitialSource, ParseCollation) { - QueryTestServiceContext serviceContext; - auto opCtx = serviceContext.makeOperationContext(); - - const BSONObj inputBson = - fromjson("{pipeline: [{$match: {a: 'abc'}}], collation: {locale: 'reverse'}}"); - auto request = AggregationRequest::parseFromBSON(NamespaceString("a.collection"), inputBson); - ASSERT_OK(request.getStatus()); - - intrusive_ptr<ExpressionContext> ctx = new ExpressionContext(opCtx.get(), request.getValue()); - ASSERT(ctx->getCollator()); - CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); - ASSERT_TRUE(CollatorInterface::collatorsMatch(ctx->getCollator(), &collator)); -} - namespace Dependencies { using PipelineDependenciesTest = AggregationContextFixture; diff --git a/src/mongo/db/views/view_catalog.cpp b/src/mongo/db/views/view_catalog.cpp index 379f0364839..d8e0e842824 100644 --- a/src/mongo/db/views/view_catalog.cpp +++ b/src/mongo/db/views/view_catalog.cpp @@ -43,6 +43,7 @@ #include "mongo/db/pipeline/aggregation_request.h" #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/lite_parsed_pipeline.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/query/collation/collator_factory_interface.h" #include "mongo/db/server_options.h" @@ -156,10 +157,24 @@ Status ViewCatalog::_upsertIntoGraph(OperationContext* txn, const ViewDefinition // Performs the insert into the graph. auto doInsert = [this, &txn](const ViewDefinition& viewDef, bool needsValidation) -> Status { - // Parse the pipeline for this view to get the namespaces it references. + // Make a LiteParsedPipeline to determine the namespaces referenced by this pipeline. AggregationRequest request(viewDef.viewOn(), viewDef.pipeline()); - boost::intrusive_ptr<ExpressionContext> expCtx = new ExpressionContext(txn, request); - expCtx->setCollator(CollatorInterface::cloneCollator(viewDef.defaultCollator())); + const LiteParsedPipeline liteParsedPipeline(request); + const auto involvedNamespaces = liteParsedPipeline.getInvolvedNamespaces(); + + // Verify that this is a legitimate pipeline specification by making sure it parses + // correctly. In order to parse a pipeline we need to resolve any namespaces involved to a + // collection and a pipeline, but in this case we don't need this map to be accurate since + // we will not be evaluating the pipeline. + StringMap<ExpressionContext::ResolvedNamespace> resolvedNamespaces; + for (auto&& nss : liteParsedPipeline.getInvolvedNamespaces()) { + resolvedNamespaces[nss.coll()] = {nss, {}}; + } + boost::intrusive_ptr<ExpressionContext> expCtx = + new ExpressionContext(txn, + request, + CollatorInterface::cloneCollator(viewDef.defaultCollator()), + std::move(resolvedNamespaces)); auto pipelineStatus = Pipeline::parse(viewDef.pipeline(), expCtx); if (!pipelineStatus.isOK()) { uassert(40255, @@ -169,7 +184,7 @@ Status ViewCatalog::_upsertIntoGraph(OperationContext* txn, const ViewDefinition return pipelineStatus.getStatus(); } - std::vector<NamespaceString> refs = pipelineStatus.getValue()->getInvolvedCollections(); + std::vector<NamespaceString> refs(involvedNamespaces.begin(), involvedNamespaces.end()); refs.push_back(viewDef.viewOn()); int pipelineSize = 0; |