diff options
author | Charlie Swanson <charlie.swanson@mongodb.com> | 2016-12-13 10:15:08 -0500 |
---|---|---|
committer | Charlie Swanson <charlie.swanson@mongodb.com> | 2016-12-16 16:24:32 -0500 |
commit | 37e720678f6e468726c6cc775a5dc898d080f0f3 (patch) | |
tree | 4bd6b4932cc0ac436c0d7c949f7e37df613684d2 /src/mongo | |
parent | 0cd2bf29d5798a395a07e67ae79ede9a5cefd411 (diff) | |
download | mongo-37e720678f6e468726c6cc775a5dc898d080f0f3.tar.gz |
SERVER-25535 Remove injectExpressionContext().
These methods were formally used to propagate a new ExpressionContext to
stages, accumulators, or expressions which potentially needed to
comparisons. Originally, this was necessary since Pipeline parsing
happened outside of the collection lock and thus could not determine if
there was a default collation on the collection. This meant that the
collation could change after parsing and any operators that might
compare strings would need to know about it.
We have since moved parsing within the lock, so the collation can be
known at parse time and the ExpressionContext should not change. This
patch requires an ExpressionContext at construction time, and disallows
changing the collation on an ExpressionContext.
Diffstat (limited to 'src/mongo')
98 files changed, 1916 insertions, 1416 deletions
diff --git a/src/mongo/db/commands/pipeline_command.cpp b/src/mongo/db/commands/pipeline_command.cpp index 7fe8d54992d..bb426950b44 100644 --- a/src/mongo/db/commands/pipeline_command.cpp +++ b/src/mongo/db/commands/pipeline_command.cpp @@ -258,7 +258,6 @@ boost::intrusive_ptr<Pipeline> reparsePipeline( fassertFailedWithStatusNoTrace(40175, reparsedPipeline.getStatus()); } - reparsedPipeline.getValue()->injectExpressionContext(expCtx); reparsedPipeline.getValue()->optimizePipeline(); return reparsedPipeline.getValue(); } @@ -350,18 +349,15 @@ public: // For operations on views, this will be the underlying namespace. const NamespaceString& nss = request.getNamespaceString(); - // Set up the ExpressionContext. - intrusive_ptr<ExpressionContext> expCtx = new ExpressionContext(txn, request); - expCtx->tempDir = storageGlobalParams.dbpath + "/_tmp"; - - auto resolvedNamespaces = resolveInvolvedNamespaces(txn, request); - if (!resolvedNamespaces.isOK()) { - return appendCommandStatus(result, resolvedNamespaces.getStatus()); - } - expCtx->resolvedNamespaces = std::move(resolvedNamespaces.getValue()); + // Parse the user-specified collation, if any. + std::unique_ptr<CollatorInterface> userSpecifiedCollator = request.getCollation().isEmpty() + ? nullptr + : uassertStatusOK(CollatorFactoryInterface::get(txn->getServiceContext()) + ->makeFromBSON(request.getCollation())); boost::optional<ClientCursorPin> pin; // either this OR the exec will be non-null unique_ptr<PlanExecutor> exec; + boost::intrusive_ptr<ExpressionContext> expCtx; boost::intrusive_ptr<Pipeline> pipeline; auto curOp = CurOp::get(txn); { @@ -387,7 +383,7 @@ public: // means that no collation was specified. if (!request.getCollation().isEmpty()) { if (!CollatorInterface::collatorsMatch(ctx.getView()->defaultCollator(), - expCtx->getCollator())) { + userSpecifiedCollator.get())) { return appendCommandStatus(result, {ErrorCodes::OptionNotSupportedOnView, "Cannot override a view's default collation"}); @@ -440,14 +436,26 @@ public: return status; } + // Determine the appropriate collation to make the ExpressionContext. + // If the pipeline does not have a user-specified collation, set it from the collection - // default. + // default. Be careful to consult the original request BSON to check if a collation was + // specified, since a specification of {locale: "simple"} will result in a null + // collator. + auto collatorToUse = std::move(userSpecifiedCollator); if (request.getCollation().isEmpty() && collection && collection->getDefaultCollator()) { - invariant(!expCtx->getCollator()); - expCtx->setCollator(collection->getDefaultCollator()->clone()); + invariant(!collatorToUse); + collatorToUse = collection->getDefaultCollator()->clone(); } + expCtx.reset( + new ExpressionContext(txn, + request, + std::move(collatorToUse), + uassertStatusOK(resolveInvolvedNamespaces(txn, request)))); + expCtx->tempDir = storageGlobalParams.dbpath + "/_tmp"; + // Parse the pipeline. auto statusWithPipeline = Pipeline::parse(request.getPipeline(), expCtx); if (!statusWithPipeline.isOK()) { @@ -463,14 +471,6 @@ public: return appendCommandStatus(result, pipelineCollationStatus); } - // Propagate the ExpressionContext throughout all of the pipeline's stages and - // expressions. - pipeline->injectExpressionContext(expCtx); - - // The pipeline must be optimized after the correct collator has been set on it (by - // injecting the ExpressionContext containing the collator). This is necessary because - // optimization may make string comparisons, e.g. optimizing {$eq: [<str1>, <str2>]} to - // a constant. pipeline->optimizePipeline(); if (kDebugBuild && !expCtx->isExplain && !expCtx->inShard) { diff --git a/src/mongo/db/pipeline/accumulation_statement.cpp b/src/mongo/db/pipeline/accumulation_statement.cpp index b191a78fdf1..9ac394b0018 100644 --- a/src/mongo/db/pipeline/accumulation_statement.cpp +++ b/src/mongo/db/pipeline/accumulation_statement.cpp @@ -64,7 +64,9 @@ Accumulator::Factory AccumulationStatement::getFactory(StringData name) { } AccumulationStatement AccumulationStatement::parseAccumulationStatement( - const BSONElement& elem, const VariablesParseState& vps) { + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONElement& elem, + const VariablesParseState& vps) { auto fieldName = elem.fieldNameStringData(); uassert(40234, str::stream() << "The field '" << fieldName << "' must be an accumulator object", @@ -91,7 +93,7 @@ AccumulationStatement AccumulationStatement::parseAccumulationStatement( return {fieldName.toString(), AccumulationStatement::getFactory(accName), - Expression::parseOperand(specElem, vps)}; + Expression::parseOperand(expCtx, specElem, vps)}; } } // namespace mongo diff --git a/src/mongo/db/pipeline/accumulation_statement.h b/src/mongo/db/pipeline/accumulation_statement.h index 5a7c578ff65..92f7906ed9e 100644 --- a/src/mongo/db/pipeline/accumulation_statement.h +++ b/src/mongo/db/pipeline/accumulation_statement.h @@ -72,8 +72,10 @@ public: * * Throws a UserException if parsing fails. */ - static AccumulationStatement parseAccumulationStatement(const BSONElement& elem, - const VariablesParseState& vps); + static AccumulationStatement parseAccumulationStatement( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONElement& elem, + const VariablesParseState& vps); /** * Registers an Accumulator with a parsing function, so that when an accumulator with the given diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h index e0a2dbdb11a..28645e90a9b 100644 --- a/src/mongo/db/pipeline/accumulator.h +++ b/src/mongo/db/pipeline/accumulator.h @@ -48,9 +48,10 @@ namespace mongo { class Accumulator : public RefCountable { public: - using Factory = boost::intrusive_ptr<Accumulator> (*)(); + using Factory = boost::intrusive_ptr<Accumulator> (*)( + const boost::intrusive_ptr<ExpressionContext>& expCtx); - Accumulator() = default; + Accumulator(const boost::intrusive_ptr<ExpressionContext>& expCtx) : _expCtx(expCtx) {} /** Process input and update internal state. * merging should be true when processing outputs from getValue(true). @@ -84,26 +85,10 @@ public: return false; } - /** - * Injects the ExpressionContext so that it may be used during evaluation of the Accumulator. - * Construction of accumulators is done at parse time, but the ExpressionContext isn't finalized - * until later, at which point it is injected using this method. - */ - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { - _expCtx = expCtx; - doInjectExpressionContext(); - } - protected: /// Update subclass's internal state based on input virtual void processInternal(const Value& input, bool merging) = 0; - /** - * Accumulators which need to update their internal state when attaching to a new - * ExpressionContext should override this method. - */ - virtual void doInjectExpressionContext() {} - const boost::intrusive_ptr<ExpressionContext>& getExpressionContext() const { return _expCtx; } @@ -118,14 +103,15 @@ private: class AccumulatorAddToSet final : public Accumulator { public: - AccumulatorAddToSet(); + explicit AccumulatorAddToSet(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); bool isAssociative() const final { return true; @@ -135,26 +121,22 @@ public: return true; } - void doInjectExpressionContext() final; - private: - // We use boost::optional to defer initialization until the ExpressionContext containing the - // correct comparator is injected, since this set must use the comparator's definition of - // equality. - boost::optional<ValueUnorderedSet> _set; + ValueUnorderedSet _set; }; class AccumulatorFirst final : public Accumulator { public: - AccumulatorFirst(); + explicit AccumulatorFirst(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); private: bool _haveFirst; @@ -164,14 +146,15 @@ private: class AccumulatorLast final : public Accumulator { public: - AccumulatorLast(); + explicit AccumulatorLast(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); private: Value _last; @@ -180,14 +163,15 @@ private: class AccumulatorSum final : public Accumulator { public: - AccumulatorSum(); + explicit AccumulatorSum(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); bool isAssociative() const final { return true; @@ -211,7 +195,7 @@ public: MAX = -1, // Used to "scale" comparison. }; - explicit AccumulatorMinMax(Sense sense); + AccumulatorMinMax(const boost::intrusive_ptr<ExpressionContext>& expCtx, Sense sense); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; @@ -233,27 +217,32 @@ private: class AccumulatorMax final : public AccumulatorMinMax { public: - AccumulatorMax() : AccumulatorMinMax(MAX) {} - static boost::intrusive_ptr<Accumulator> create(); + explicit AccumulatorMax(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : AccumulatorMinMax(expCtx, MAX) {} + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); }; class AccumulatorMin final : public AccumulatorMinMax { public: - AccumulatorMin() : AccumulatorMinMax(MIN) {} - static boost::intrusive_ptr<Accumulator> create(); + explicit AccumulatorMin(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : AccumulatorMinMax(expCtx, MIN) {} + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); }; class AccumulatorPush final : public Accumulator { public: - AccumulatorPush(); + explicit AccumulatorPush(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); private: std::vector<Value> vpValue; @@ -262,14 +251,15 @@ private: class AccumulatorAvg final : public Accumulator { public: - AccumulatorAvg(); + explicit AccumulatorAvg(const boost::intrusive_ptr<ExpressionContext>& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr<Accumulator> create(); + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); private: /** @@ -287,7 +277,7 @@ private: class AccumulatorStdDev : public Accumulator { public: - explicit AccumulatorStdDev(bool isSamp); + AccumulatorStdDev(const boost::intrusive_ptr<ExpressionContext>& expCtx, bool isSamp); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; @@ -303,13 +293,17 @@ private: class AccumulatorStdDevPop final : public AccumulatorStdDev { public: - AccumulatorStdDevPop() : AccumulatorStdDev(false) {} - static boost::intrusive_ptr<Accumulator> create(); + explicit AccumulatorStdDevPop(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : AccumulatorStdDev(expCtx, false) {} + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); }; class AccumulatorStdDevSamp final : public AccumulatorStdDev { public: - AccumulatorStdDevSamp() : AccumulatorStdDev(true) {} - static boost::intrusive_ptr<Accumulator> create(); + explicit AccumulatorStdDevSamp(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : AccumulatorStdDev(expCtx, true) {} + static boost::intrusive_ptr<Accumulator> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); }; } diff --git a/src/mongo/db/pipeline/accumulator_add_to_set.cpp b/src/mongo/db/pipeline/accumulator_add_to_set.cpp index 6a54e846a85..00845b32fc8 100644 --- a/src/mongo/db/pipeline/accumulator_add_to_set.cpp +++ b/src/mongo/db/pipeline/accumulator_add_to_set.cpp @@ -48,7 +48,7 @@ const char* AccumulatorAddToSet::getOpName() const { void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { if (!merging) { if (!input.missing()) { - bool inserted = _set->insert(input).second; + bool inserted = _set.insert(input).second; if (inserted) { _memUsageBytes += input.getApproximateSize(); } @@ -62,7 +62,7 @@ void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { const vector<Value>& array = input.getArray(); for (size_t i = 0; i < array.size(); i++) { - bool inserted = _set->insert(array[i]).second; + bool inserted = _set.insert(array[i]).second; if (inserted) { _memUsageBytes += array[i].getApproximateSize(); } @@ -71,10 +71,11 @@ void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { } Value AccumulatorAddToSet::getValue(bool toBeMerged) const { - return Value(vector<Value>(_set->begin(), _set->end())); + return Value(vector<Value>(_set.begin(), _set.end())); } -AccumulatorAddToSet::AccumulatorAddToSet() { +AccumulatorAddToSet::AccumulatorAddToSet(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx), _set(expCtx->getValueComparator().makeUnorderedValueSet()) { _memUsageBytes = sizeof(*this); } @@ -83,12 +84,9 @@ void AccumulatorAddToSet::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr<Accumulator> AccumulatorAddToSet::create() { - return new AccumulatorAddToSet(); -} - -void AccumulatorAddToSet::doInjectExpressionContext() { - _set = getExpressionContext()->getValueComparator().makeUnorderedValueSet(); +intrusive_ptr<Accumulator> AccumulatorAddToSet::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorAddToSet(expCtx); } } // namespace mongo diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp index c977ba99c15..da0c08fce9c 100644 --- a/src/mongo/db/pipeline/accumulator_avg.cpp +++ b/src/mongo/db/pipeline/accumulator_avg.cpp @@ -92,8 +92,9 @@ void AccumulatorAvg::processInternal(const Value& input, bool merging) { _count++; } -intrusive_ptr<Accumulator> AccumulatorAvg::create() { - return new AccumulatorAvg(); +intrusive_ptr<Accumulator> AccumulatorAvg::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorAvg(expCtx); } Decimal128 AccumulatorAvg::_getDecimalTotal() const { @@ -120,7 +121,8 @@ Value AccumulatorAvg::getValue(bool toBeMerged) const { return Value(_nonDecimalTotal.getDouble() / static_cast<double>(_count)); } -AccumulatorAvg::AccumulatorAvg() : _isDecimal(false), _count(0) { +AccumulatorAvg::AccumulatorAvg(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx), _isDecimal(false), _count(0) { // This is a fixed size Accumulator so we never need to update this _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_first.cpp b/src/mongo/db/pipeline/accumulator_first.cpp index 68ea563ff42..3092bd0e9da 100644 --- a/src/mongo/db/pipeline/accumulator_first.cpp +++ b/src/mongo/db/pipeline/accumulator_first.cpp @@ -57,7 +57,8 @@ Value AccumulatorFirst::getValue(bool toBeMerged) const { return _first; } -AccumulatorFirst::AccumulatorFirst() : _haveFirst(false) { +AccumulatorFirst::AccumulatorFirst(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx), _haveFirst(false) { _memUsageBytes = sizeof(*this); } @@ -68,7 +69,8 @@ void AccumulatorFirst::reset() { } -intrusive_ptr<Accumulator> AccumulatorFirst::create() { - return new AccumulatorFirst(); +intrusive_ptr<Accumulator> AccumulatorFirst::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorFirst(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_last.cpp b/src/mongo/db/pipeline/accumulator_last.cpp index 64f545f7db1..5f465b1bf4e 100644 --- a/src/mongo/db/pipeline/accumulator_last.cpp +++ b/src/mongo/db/pipeline/accumulator_last.cpp @@ -53,7 +53,8 @@ Value AccumulatorLast::getValue(bool toBeMerged) const { return _last; } -AccumulatorLast::AccumulatorLast() { +AccumulatorLast::AccumulatorLast(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx) { _memUsageBytes = sizeof(*this); } @@ -62,7 +63,8 @@ void AccumulatorLast::reset() { _last = Value(); } -intrusive_ptr<Accumulator> AccumulatorLast::create() { - return new AccumulatorLast(); +intrusive_ptr<Accumulator> AccumulatorLast::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorLast(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_min_max.cpp b/src/mongo/db/pipeline/accumulator_min_max.cpp index 08db8b9e0fc..f5c4b254b88 100644 --- a/src/mongo/db/pipeline/accumulator_min_max.cpp +++ b/src/mongo/db/pipeline/accumulator_min_max.cpp @@ -68,7 +68,9 @@ Value AccumulatorMinMax::getValue(bool toBeMerged) const { return _val; } -AccumulatorMinMax::AccumulatorMinMax(Sense sense) : _sense(sense) { +AccumulatorMinMax::AccumulatorMinMax(const boost::intrusive_ptr<ExpressionContext>& expCtx, + Sense sense) + : Accumulator(expCtx), _sense(sense) { _memUsageBytes = sizeof(*this); } @@ -77,11 +79,13 @@ void AccumulatorMinMax::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr<Accumulator> AccumulatorMin::create() { - return new AccumulatorMin(); +intrusive_ptr<Accumulator> AccumulatorMin::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorMin(expCtx); } -intrusive_ptr<Accumulator> AccumulatorMax::create() { - return new AccumulatorMax(); +intrusive_ptr<Accumulator> AccumulatorMax::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorMax(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_push.cpp b/src/mongo/db/pipeline/accumulator_push.cpp index 16f722a08f3..9239d8b8a18 100644 --- a/src/mongo/db/pipeline/accumulator_push.cpp +++ b/src/mongo/db/pipeline/accumulator_push.cpp @@ -71,7 +71,8 @@ Value AccumulatorPush::getValue(bool toBeMerged) const { return Value(vpValue); } -AccumulatorPush::AccumulatorPush() { +AccumulatorPush::AccumulatorPush(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx) { _memUsageBytes = sizeof(*this); } @@ -80,7 +81,8 @@ void AccumulatorPush::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr<Accumulator> AccumulatorPush::create() { - return new AccumulatorPush(); +intrusive_ptr<Accumulator> AccumulatorPush::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorPush(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp index 28565250fe3..f45678dce27 100644 --- a/src/mongo/db/pipeline/accumulator_std_dev.cpp +++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp @@ -95,15 +95,19 @@ Value AccumulatorStdDev::getValue(bool toBeMerged) const { } } -intrusive_ptr<Accumulator> AccumulatorStdDevSamp::create() { - return new AccumulatorStdDevSamp(); +intrusive_ptr<Accumulator> AccumulatorStdDevSamp::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorStdDevSamp(expCtx); } -intrusive_ptr<Accumulator> AccumulatorStdDevPop::create() { - return new AccumulatorStdDevPop(); +intrusive_ptr<Accumulator> AccumulatorStdDevPop::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorStdDevPop(expCtx); } -AccumulatorStdDev::AccumulatorStdDev(bool isSamp) : _isSamp(isSamp), _count(0), _mean(0), _m2(0) { +AccumulatorStdDev::AccumulatorStdDev(const boost::intrusive_ptr<ExpressionContext>& expCtx, + bool isSamp) + : Accumulator(expCtx), _isSamp(isSamp), _count(0), _mean(0), _m2(0) { // This is a fixed size Accumulator so we never need to update this _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp index 13aa12fbfca..bbbf596d4a8 100644 --- a/src/mongo/db/pipeline/accumulator_sum.cpp +++ b/src/mongo/db/pipeline/accumulator_sum.cpp @@ -84,8 +84,9 @@ void AccumulatorSum::processInternal(const Value& input, bool merging) { } } -intrusive_ptr<Accumulator> AccumulatorSum::create() { - return new AccumulatorSum(); +intrusive_ptr<Accumulator> AccumulatorSum::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new AccumulatorSum(expCtx); } Value AccumulatorSum::getValue(bool toBeMerged) const { @@ -133,7 +134,8 @@ Value AccumulatorSum::getValue(bool toBeMerged) const { } } -AccumulatorSum::AccumulatorSum() { +AccumulatorSum::AccumulatorSum(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Accumulator(expCtx) { // This is a fixed size Accumulator so we never need to update this. _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp index 29ab135ebb4..3fcb1d70ad7 100644 --- a/src/mongo/db/pipeline/accumulator_test.cpp +++ b/src/mongo/db/pipeline/accumulator_test.cpp @@ -32,7 +32,7 @@ #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/collation/collator_interface_mock.h" #include "mongo/dbtests/dbtests.h" #include "mongo/stdx/memory.h" @@ -57,8 +57,7 @@ static void assertExpectedResults( try { // Asserts that result equals expected result when not sharded. { - boost::intrusive_ptr<Accumulator> accum = factory(); - accum->injectExpressionContext(expCtx); + boost::intrusive_ptr<Accumulator> accum(factory(expCtx)); for (auto&& val : op.first) { accum->process(val, false); } @@ -69,10 +68,8 @@ static void assertExpectedResults( // Asserts that result equals expected result when all input is on one shard. { - boost::intrusive_ptr<Accumulator> accum = factory(); - accum->injectExpressionContext(expCtx); - boost::intrusive_ptr<Accumulator> shard = factory(); - shard->injectExpressionContext(expCtx); + boost::intrusive_ptr<Accumulator> accum(factory(expCtx)); + boost::intrusive_ptr<Accumulator> shard(factory(expCtx)); for (auto&& val : op.first) { shard->process(val, false); } @@ -84,11 +81,9 @@ static void assertExpectedResults( // Asserts that result equals expected result when each input is on a separate shard. { - boost::intrusive_ptr<Accumulator> accum = factory(); - accum->injectExpressionContext(expCtx); + boost::intrusive_ptr<Accumulator> accum(factory(expCtx)); for (auto&& val : op.first) { - boost::intrusive_ptr<Accumulator> shard = factory(); - shard->injectExpressionContext(expCtx); + boost::intrusive_ptr<Accumulator> shard(factory(expCtx)); shard->process(val, false); accum->process(shard->getValue(true), true); } @@ -104,7 +99,7 @@ static void assertExpectedResults( } TEST(Accumulators, Avg) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$avg", expCtx, @@ -160,7 +155,7 @@ TEST(Accumulators, Avg) { } TEST(Accumulators, First) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$first", expCtx, @@ -179,7 +174,7 @@ TEST(Accumulators, First) { } TEST(Accumulators, Last) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$last", expCtx, @@ -198,7 +193,7 @@ TEST(Accumulators, Last) { } TEST(Accumulators, Min) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$min", expCtx, @@ -217,14 +212,14 @@ TEST(Accumulators, Min) { } TEST(Accumulators, MinRespectsCollation) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); expCtx->setCollator( stdx::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kReverseString)); assertExpectedResults("$min", expCtx, {{{Value("abc"_sd), Value("cba"_sd)}, Value("cba"_sd)}}); } TEST(Accumulators, Max) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$max", expCtx, @@ -243,14 +238,14 @@ TEST(Accumulators, Max) { } TEST(Accumulators, MaxRespectsCollation) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); expCtx->setCollator( stdx::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kReverseString)); assertExpectedResults("$max", expCtx, {{{Value("abc"_sd), Value("cba"_sd)}, Value("abc"_sd)}}); } TEST(Accumulators, Sum) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest()); assertExpectedResults( "$sum", expCtx, @@ -340,7 +335,7 @@ TEST(Accumulators, Sum) { } TEST(Accumulators, AddToSetRespectsCollation) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); expCtx->setCollator( stdx::make_unique<CollatorInterfaceMock>(CollatorInterfaceMock::MockType::kAlwaysEqual)); assertExpectedResults("$addToSet", diff --git a/src/mongo/db/pipeline/aggregation_context_fixture.h b/src/mongo/db/pipeline/aggregation_context_fixture.h index 75353c46e7f..786f50a9b5a 100644 --- a/src/mongo/db/pipeline/aggregation_context_fixture.h +++ b/src/mongo/db/pipeline/aggregation_context_fixture.h @@ -32,7 +32,7 @@ #include <memory> #include "mongo/db/client.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/query_test_service_context.h" #include "mongo/db/service_context_noop.h" #include "mongo/stdx/memory.h" @@ -48,16 +48,16 @@ public: AggregationContextFixture() : _queryServiceContext(stdx::make_unique<QueryTestServiceContext>()), _opCtx(_queryServiceContext->makeOperationContext()), - _expCtx(new ExpressionContext( + _expCtx(new ExpressionContextForTest( _opCtx.get(), AggregationRequest(NamespaceString("unittests.pipeline_test"), {}))) {} - boost::intrusive_ptr<ExpressionContext> getExpCtx() { + boost::intrusive_ptr<ExpressionContextForTest> getExpCtx() { return _expCtx.get(); } private: std::unique_ptr<QueryTestServiceContext> _queryServiceContext; ServiceContext::UniqueOperationContext _opCtx; - boost::intrusive_ptr<ExpressionContext> _expCtx; + boost::intrusive_ptr<ExpressionContextForTest> _expCtx; }; } // namespace mongo diff --git a/src/mongo/db/pipeline/document_source.cpp b/src/mongo/db/pipeline/document_source.cpp index b83ec851ebe..440dc762175 100644 --- a/src/mongo/db/pipeline/document_source.cpp +++ b/src/mongo/db/pipeline/document_source.cpp @@ -61,7 +61,7 @@ void DocumentSource::registerParser(string name, Parser parser) { } vector<intrusive_ptr<DocumentSource>> DocumentSource::parse( - const intrusive_ptr<ExpressionContext> expCtx, BSONObj stageObj) { + const intrusive_ptr<ExpressionContext>& expCtx, BSONObj stageObj) { uassert(16435, "A pipeline stage specification object must contain exactly one field.", stageObj.nFields() == 1); diff --git a/src/mongo/db/pipeline/document_source.h b/src/mongo/db/pipeline/document_source.h index 81ba3fc61a6..bf4c65fa18b 100644 --- a/src/mongo/db/pipeline/document_source.h +++ b/src/mongo/db/pipeline/document_source.h @@ -269,22 +269,10 @@ public: virtual void reattachToOperationContext(OperationContext* opCtx) {} /** - * Injects a new ExpressionContext into this DocumentSource and propagates the ExpressionContext - * to all child expressions, accumulators, etc. - * - * Stages which require work to propagate the ExpressionContext to their private execution - * machinery should override doInjectExpressionContext(). - */ - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { - pExpCtx = expCtx; - doInjectExpressionContext(); - } - - /** * Create a DocumentSource pipeline stage from 'stageObj'. */ static std::vector<boost::intrusive_ptr<DocumentSource>> parse( - const boost::intrusive_ptr<ExpressionContext> expCtx, BSONObj stageObj); + const boost::intrusive_ptr<ExpressionContext>& expCtx, BSONObj stageObj); /** * Registers a DocumentSource with a parsing function, so that when a stage with the given name @@ -444,15 +432,6 @@ protected: explicit DocumentSource(const boost::intrusive_ptr<ExpressionContext>& pExpCtx); /** - * DocumentSources which need to update their internal state when attaching to a new - * ExpressionContext should override this method. - * - * Any stage subclassing from DocumentSource should override this method if it contains - * expressions or accumulators which need to attach to the newly injected ExpressionContext. - */ - virtual void doInjectExpressionContext() {} - - /** * Attempt to perform an optimization with the following source in the pipeline. 'container' * refers to the entire pipeline, and 'itr' points to this stage within the pipeline. The caller * must guarantee that std::next(itr) != container->end(). diff --git a/src/mongo/db/pipeline/document_source_add_fields.cpp b/src/mongo/db/pipeline/document_source_add_fields.cpp index 0f7366538f9..445feaa1a51 100644 --- a/src/mongo/db/pipeline/document_source_add_fields.cpp +++ b/src/mongo/db/pipeline/document_source_add_fields.cpp @@ -49,8 +49,7 @@ intrusive_ptr<DocumentSource> DocumentSourceAddFields::create( BSONObj addFieldsSpec, const intrusive_ptr<ExpressionContext>& expCtx) { intrusive_ptr<DocumentSourceSingleDocumentTransformation> addFields( new DocumentSourceSingleDocumentTransformation( - expCtx, ParsedAddFields::create(addFieldsSpec), "$addFields")); - addFields->injectExpressionContext(expCtx); + expCtx, ParsedAddFields::create(expCtx, addFieldsSpec), "$addFields")); return addFields; } diff --git a/src/mongo/db/pipeline/document_source_add_fields.h b/src/mongo/db/pipeline/document_source_add_fields.h index b1ce0415ab3..77fe4d9d4e9 100644 --- a/src/mongo/db/pipeline/document_source_add_fields.h +++ b/src/mongo/db/pipeline/document_source_add_fields.h @@ -51,6 +51,8 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx); private: + // It is illegal to construct a DocumentSourceAddFields directly, use create() or + // createFromBson() instead. DocumentSourceAddFields() = default; }; diff --git a/src/mongo/db/pipeline/document_source_bucket.cpp b/src/mongo/db/pipeline/document_source_bucket.cpp index 05152b7e313..7c671ea2998 100644 --- a/src/mongo/db/pipeline/document_source_bucket.cpp +++ b/src/mongo/db/pipeline/document_source_bucket.cpp @@ -43,9 +43,11 @@ REGISTER_MULTI_STAGE_ALIAS(bucket, DocumentSourceBucket::createFromBson); namespace { -intrusive_ptr<ExpressionConstant> getExpressionConstant(BSONElement expressionElem, - VariablesParseState vps) { - auto expr = Expression::parseOperand(expressionElem, vps)->optimize(); +intrusive_ptr<ExpressionConstant> getExpressionConstant( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expressionElem, + VariablesParseState vps) { + auto expr = Expression::parseOperand(expCtx, expressionElem, vps)->optimize(); return dynamic_cast<ExpressionConstant*>(expr.get()); } } // namespace @@ -95,7 +97,7 @@ vector<intrusive_ptr<DocumentSource>> DocumentSourceBucket::createFromBson( argument.type() == BSONType::Array); for (auto&& boundaryElem : argument.embeddedObject()) { - auto exprConst = getExpressionConstant(boundaryElem, vps); + auto exprConst = getExpressionConstant(pExpCtx, boundaryElem, vps); uassert(40191, str::stream() << "The $bucket 'boundaries' field must be an array of " "constant values, but found value: " @@ -144,7 +146,7 @@ vector<intrusive_ptr<DocumentSource>> DocumentSourceBucket::createFromBson( } else if ("default" == argName) { // If there is a default, make sure that it parses to a constant expression then add // default to switch. - auto exprConst = getExpressionConstant(argument, vps); + auto exprConst = getExpressionConstant(pExpCtx, argument, vps); uassert(40195, str::stream() << "The $bucket 'default' field must be a constant expression, but found: " diff --git a/src/mongo/db/pipeline/document_source_bucket.h b/src/mongo/db/pipeline/document_source_bucket.h index cd7fe31b14e..aa83ee07cc2 100644 --- a/src/mongo/db/pipeline/document_source_bucket.h +++ b/src/mongo/db/pipeline/document_source_bucket.h @@ -44,6 +44,7 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); private: + // It is illegal to construct a DocumentSourceBucket directly, use createFromBson() instead. DocumentSourceBucket() = default; }; diff --git a/src/mongo/db/pipeline/document_source_bucket_auto.cpp b/src/mongo/db/pipeline/document_source_bucket_auto.cpp index c9a09354434..b02e8d0226f 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto.cpp +++ b/src/mongo/db/pipeline/document_source_bucket_auto.cpp @@ -196,7 +196,8 @@ void DocumentSourceBucketAuto::populateBuckets() { } // Initialize the current bucket. - Bucket currentBucket(currentValue.first, currentValue.first, _accumulatorFactories); + Bucket currentBucket( + pExpCtx, currentValue.first, currentValue.first, _accumulatorFactories); // Add the first value into the current bucket. addDocumentToBucket(currentValue, currentBucket); @@ -267,13 +268,14 @@ void DocumentSourceBucketAuto::populateBuckets() { } } -DocumentSourceBucketAuto::Bucket::Bucket(Value min, +DocumentSourceBucketAuto::Bucket::Bucket(const boost::intrusive_ptr<ExpressionContext>& expCtx, + Value min, Value max, vector<Accumulator::Factory> accumulatorFactories) : _min(min), _max(max) { _accums.reserve(accumulatorFactories.size()); for (auto&& factory : accumulatorFactories) { - _accums.push_back(factory()); + _accums.push_back(factory(expCtx)); } } @@ -344,7 +346,7 @@ Value DocumentSourceBucketAuto::serialize(bool explain) const { const size_t nOutputFields = _fieldNames.size(); MutableDocument outputSpec(nOutputFields); for (size_t i = 0; i < nOutputFields; i++) { - intrusive_ptr<Accumulator> accum = _accumulatorFactories[i](); + intrusive_ptr<Accumulator> accum = _accumulatorFactories[i](pExpCtx); outputSpec[_fieldNames[i]] = Value{Document{{accum->getOpName(), _expressions[i]->serialize(explain)}}}; } @@ -405,14 +407,16 @@ DocumentSourceBucketAuto::DocumentSourceBucketAuto( namespace { -boost::intrusive_ptr<Expression> parseGroupByExpression(const BSONElement& groupByField, - const VariablesParseState& vps) { +boost::intrusive_ptr<Expression> parseGroupByExpression( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONElement& groupByField, + const VariablesParseState& vps) { if (groupByField.type() == BSONType::Object && groupByField.embeddedObject().firstElementFieldName()[0] == '$') { - return Expression::parseObject(groupByField.embeddedObject(), vps); + return Expression::parseObject(expCtx, groupByField.embeddedObject(), vps); } else if (groupByField.type() == BSONType::String && groupByField.valueStringData()[0] == '$') { - return ExpressionFieldPath::parse(groupByField.str(), vps); + return ExpressionFieldPath::parse(expCtx, groupByField.str(), vps); } else { uasserted( 40239, @@ -440,7 +444,7 @@ intrusive_ptr<DocumentSource> DocumentSourceBucketAuto::createFromBson( for (auto&& argument : elem.Obj()) { const auto argName = argument.fieldNameStringData(); if ("groupBy" == argName) { - groupByExpression = parseGroupByExpression(argument, vps); + groupByExpression = parseGroupByExpression(pExpCtx, argument, vps); } else if ("buckets" == argName) { Value bucketsValue = Value(argument); @@ -467,7 +471,7 @@ intrusive_ptr<DocumentSource> DocumentSourceBucketAuto::createFromBson( for (auto&& outputField : argument.embeddedObject()) { accumulationStatements.push_back( - AccumulationStatement::parseAccumulationStatement(outputField, vps)); + AccumulationStatement::parseAccumulationStatement(pExpCtx, outputField, vps)); } } else if ("granularity" == argName) { uassert(40261, @@ -475,7 +479,7 @@ intrusive_ptr<DocumentSource> DocumentSourceBucketAuto::createFromBson( << "The $bucketAuto 'granularity' field must be a string, but found type: " << typeName(argument.type()), argument.type() == BSONType::String); - granularityRounder = GranularityRounder::getGranularityRounder(argument.str()); + granularityRounder = GranularityRounder::getGranularityRounder(pExpCtx, argument.str()); } else { uasserted(40245, str::stream() << "Unrecognized option to $bucketAuto: " << argName); } diff --git a/src/mongo/db/pipeline/document_source_bucket_auto.h b/src/mongo/db/pipeline/document_source_bucket_auto.h index 9279bcd52b9..61f7da30c71 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto.h +++ b/src/mongo/db/pipeline/document_source_bucket_auto.h @@ -93,7 +93,10 @@ private: // struct for holding information about a bucket. struct Bucket { - Bucket(Value min, Value max, std::vector<Accumulator::Factory> accumulatorFactories); + Bucket(const boost::intrusive_ptr<ExpressionContext>& expCtx, + Value min, + Value max, + std::vector<Accumulator::Factory> accumulatorFactories); Value _min; Value _max; std::vector<boost::intrusive_ptr<Accumulator>> _accums; diff --git a/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp b/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp index 530948d2d3d..2d71b676024 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp +++ b/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp @@ -355,7 +355,7 @@ TEST_F(BucketAutoTests, ShouldBeAbleToCorrectlySpillToDisk) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, @@ -397,7 +397,7 @@ TEST_F(BucketAutoTests, ShouldBeAbleToPauseLoadingWhileSpilled) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, @@ -643,7 +643,7 @@ void assertCannotSpillToDisk(const boost::intrusive_ptr<ExpressionContext>& expC VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, @@ -685,7 +685,7 @@ TEST_F(BucketAutoTests, ShouldCorrectlyTrackMemoryUsageBetweenPauses) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, diff --git a/src/mongo/db/pipeline/document_source_bucket_test.cpp b/src/mongo/db/pipeline/document_source_bucket_test.cpp index b93916f4347..bdfb3635bb7 100644 --- a/src/mongo/db/pipeline/document_source_bucket_test.cpp +++ b/src/mongo/db/pipeline/document_source_bucket_test.cpp @@ -156,9 +156,6 @@ class InvalidBucketSpec : public AggregationContextFixture { public: vector<intrusive_ptr<DocumentSource>> createBucket(BSONObj bucketSpec) { auto sources = DocumentSourceBucket::createFromBson(bucketSpec.firstElement(), getExpCtx()); - for (auto&& source : sources) { - source->injectExpressionContext(getExpCtx()); - } return sources; } }; diff --git a/src/mongo/db/pipeline/document_source_count.h b/src/mongo/db/pipeline/document_source_count.h index e1817fbfca3..50fbccdf41b 100644 --- a/src/mongo/db/pipeline/document_source_count.h +++ b/src/mongo/db/pipeline/document_source_count.h @@ -44,6 +44,7 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); private: + // It is illegal to construct a DocumentSourceCount directly, use createFromBson() instead. DocumentSourceCount() = default; }; diff --git a/src/mongo/db/pipeline/document_source_cursor.cpp b/src/mongo/db/pipeline/document_source_cursor.cpp index 5c5be2aef1f..2def98d25bd 100644 --- a/src/mongo/db/pipeline/document_source_cursor.cpp +++ b/src/mongo/db/pipeline/document_source_cursor.cpp @@ -221,12 +221,6 @@ Value DocumentSourceCursor::serialize(bool explain) const { return Value(DOC(getSourceName() << out.freezeToValue())); } -void DocumentSourceCursor::doInjectExpressionContext() { - if (_limit) { - _limit->injectExpressionContext(pExpCtx); - } -} - void DocumentSourceCursor::detachFromOperationContext() { if (_exec) { _exec->detachFromOperationContext(); @@ -259,7 +253,6 @@ intrusive_ptr<DocumentSourceCursor> DocumentSourceCursor::create( const intrusive_ptr<ExpressionContext>& pExpCtx) { intrusive_ptr<DocumentSourceCursor> source( new DocumentSourceCursor(ns, std::move(exec), pExpCtx)); - source->injectExpressionContext(pExpCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_cursor.h b/src/mongo/db/pipeline/document_source_cursor.h index ebbf9eff64d..799834f1487 100644 --- a/src/mongo/db/pipeline/document_source_cursor.h +++ b/src/mongo/db/pipeline/document_source_cursor.h @@ -132,9 +132,6 @@ public: const PlanSummaryStats& getPlanSummaryStats() const; -protected: - void doInjectExpressionContext() final; - private: DocumentSourceCursor(const std::string& ns, std::unique_ptr<PlanExecutor> exec, diff --git a/src/mongo/db/pipeline/document_source_facet.cpp b/src/mongo/db/pipeline/document_source_facet.cpp index 7b8cd82de0f..9b03b233051 100644 --- a/src/mongo/db/pipeline/document_source_facet.cpp +++ b/src/mongo/db/pipeline/document_source_facet.cpp @@ -203,12 +203,6 @@ intrusive_ptr<DocumentSource> DocumentSourceFacet::optimize() { return this; } -void DocumentSourceFacet::doInjectExpressionContext() { - for (auto&& facet : _facets) { - facet.pipeline->injectExpressionContext(pExpCtx); - } -} - void DocumentSourceFacet::doInjectMongodInterface(std::shared_ptr<MongodInterface> mongod) { for (auto&& facet : _facets) { for (auto&& stage : facet.pipeline->getSources()) { diff --git a/src/mongo/db/pipeline/document_source_facet.h b/src/mongo/db/pipeline/document_source_facet.h index b0a9a0867ee..921b2e87609 100644 --- a/src/mongo/db/pipeline/document_source_facet.h +++ b/src/mongo/db/pipeline/document_source_facet.h @@ -43,7 +43,7 @@ namespace mongo { class BSONElement; class TeeBuffer; class DocumentSourceTeeConsumer; -struct ExpressionContext; +class ExpressionContext; class NamespaceString; /** @@ -98,11 +98,6 @@ public: boost::intrusive_ptr<DocumentSource> optimize() final; /** - * Injects the expression context into inner pipelines. - */ - void doInjectExpressionContext() final; - - /** * Takes a union of all sub-pipelines, and adds them to 'deps'. */ GetDepsReturn getDependencies(DepsTracker* deps) const final; diff --git a/src/mongo/db/pipeline/document_source_geo_near.cpp b/src/mongo/db/pipeline/document_source_geo_near.cpp index 8c4a2ca3ff5..17e9997bbf5 100644 --- a/src/mongo/db/pipeline/document_source_geo_near.cpp +++ b/src/mongo/db/pipeline/document_source_geo_near.cpp @@ -178,7 +178,6 @@ void DocumentSourceGeoNear::runCommand() { intrusive_ptr<DocumentSourceGeoNear> DocumentSourceGeoNear::create( const intrusive_ptr<ExpressionContext>& pCtx) { intrusive_ptr<DocumentSourceGeoNear> source(new DocumentSourceGeoNear(pCtx)); - source->injectExpressionContext(pCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_graph_lookup.cpp b/src/mongo/db/pipeline/document_source_graph_lookup.cpp index 7743e9c6058..03515560057 100644 --- a/src/mongo/db/pipeline/document_source_graph_lookup.cpp +++ b/src/mongo/db/pipeline/document_source_graph_lookup.cpp @@ -164,7 +164,7 @@ DocumentSource::GetNextResult DocumentSourceGraphLookUp::getNextUnwound() { void DocumentSourceGraphLookUp::dispose() { _cache.clear(); - _frontier->clear(); + _frontier.clear(); _visited.clear(); pSource->dispose(); } @@ -180,7 +180,7 @@ void DocumentSourceGraphLookUp::doBreadthFirstSearch() { auto matchStage = makeMatchStageFromFrontier(&cached); ValueUnorderedSet queried = pExpCtx->getValueComparator().makeUnorderedValueSet(); - _frontier->swap(queried); + _frontier.swap(queried); _frontierUsageBytes = 0; // Process cached values, populating '_frontier' for the next iteration of search. @@ -219,7 +219,7 @@ void DocumentSourceGraphLookUp::doBreadthFirstSearch() { } while (shouldPerformAnotherQuery && depth < std::numeric_limits<long long>::max() && (!_maxDepth || depth <= *_maxDepth)); - _frontier->clear(); + _frontier.clear(); _frontierUsageBytes = 0; } @@ -264,12 +264,12 @@ bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long lon Value recurseOn = Value(elem); if (recurseOn.isArray()) { for (auto&& subElem : recurseOn.getArray()) { - _frontier->insert(subElem); + _frontier.insert(subElem); _frontierUsageBytes += subElem.getApproximateSize(); } } else if (!recurseOn.missing()) { // Don't recurse on a missing value. - _frontier->insert(recurseOn); + _frontier.insert(recurseOn); _frontierUsageBytes += recurseOn.getApproximateSize(); } } @@ -304,13 +304,13 @@ void DocumentSourceGraphLookUp::addToCache(const BSONObj& result, boost::optional<BSONObj> DocumentSourceGraphLookUp::makeMatchStageFromFrontier(BSONObjSet* cached) { // Add any cached values to 'cached' and remove them from '_frontier'. - for (auto it = _frontier->begin(); it != _frontier->end();) { + for (auto it = _frontier.begin(); it != _frontier.end();) { if (auto entry = _cache[*it]) { for (auto&& obj : *entry) { cached->insert(obj); } size_t valueSize = it->getApproximateSize(); - it = _frontier->erase(it); + it = _frontier.erase(it); // If the cached value increased in size while in the cache, we don't want to underflow // '_frontierUsageBytes'. @@ -340,7 +340,7 @@ boost::optional<BSONObj> DocumentSourceGraphLookUp::makeMatchStageFromFrontier(B BSONObjBuilder subObj(connectToObj.subobjStart(_connectToField.fullPath())); { BSONArrayBuilder in(subObj.subarrayStart("$in")); - for (auto&& value : *_frontier) { + for (auto&& value : _frontier) { in << value; } } @@ -349,7 +349,7 @@ boost::optional<BSONObj> DocumentSourceGraphLookUp::makeMatchStageFromFrontier(B } } - return _frontier->empty() ? boost::none : boost::optional<BSONObj>(match.obj()); + return _frontier.empty() ? boost::none : boost::optional<BSONObj>(match.obj()); } void DocumentSourceGraphLookUp::performSearch() { @@ -363,11 +363,11 @@ void DocumentSourceGraphLookUp::performSearch() { // If _startWith evaluates to an array, treat each value as a separate starting point. if (startingValue.isArray()) { for (auto value : startingValue.getArray()) { - _frontier->insert(value); + _frontier.insert(value); _frontierUsageBytes += value.getApproximateSize(); } } else { - _frontier->insert(startingValue); + _frontier.insert(startingValue); _frontierUsageBytes += startingValue.getApproximateSize(); } @@ -460,24 +460,6 @@ void DocumentSourceGraphLookUp::serializeToArray(std::vector<Value>& array, bool } } -void DocumentSourceGraphLookUp::doInjectExpressionContext() { - _startWith->injectExpressionContext(pExpCtx); - - auto it = pExpCtx->resolvedNamespaces.find(_from.coll()); - invariant(it != pExpCtx->resolvedNamespaces.end()); - const auto& resolvedNamespace = it->second; - _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); - _fromPipeline = resolvedNamespace.pipeline; - - // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage - // we'll eventually construct from the input document. - _fromPipeline.reserve(_fromPipeline.size() + 1); - _fromPipeline.push_back(BSONObj()); - - _frontier = pExpCtx->getValueComparator().makeUnorderedValueSet(); - _cache.setValueComparator(pExpCtx->getValueComparator()); -} - void DocumentSourceGraphLookUp::doDetachFromOperationContext() { _fromExpCtx->opCtx = nullptr; } @@ -506,9 +488,19 @@ DocumentSourceGraphLookUp::DocumentSourceGraphLookUp( _additionalFilter(additionalFilter), _depthField(depthField), _maxDepth(maxDepth), + _frontier(pExpCtx->getValueComparator().makeUnorderedValueSet()), _visited(ValueComparator::kInstance.makeUnorderedValueMap<BSONObj>()), - _cache(expCtx->getValueComparator()), - _unwind(unwindSrc) {} + _cache(pExpCtx->getValueComparator()), + _unwind(unwindSrc) { + const auto& resolvedNamespace = pExpCtx->getResolvedNamespace(_from); + _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); + _fromPipeline = resolvedNamespace.pipeline; + + // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage + // we'll eventually construct from the input document. + _fromPipeline.reserve(_fromPipeline.size() + 1); + _fromPipeline.push_back(BSONObj()); +} intrusive_ptr<DocumentSourceGraphLookUp> DocumentSourceGraphLookUp::create( const intrusive_ptr<ExpressionContext>& expCtx, @@ -533,8 +525,6 @@ intrusive_ptr<DocumentSourceGraphLookUp> DocumentSourceGraphLookUp::create( maxDepth, unwindSrc)); source->_variables.reset(new Variables()); - - source->injectExpressionContext(expCtx); return source; } @@ -556,7 +546,7 @@ intrusive_ptr<DocumentSource> DocumentSourceGraphLookUp::createFromBson( const auto argName = argument.fieldNameStringData(); if (argName == "startWith") { - startWith = Expression::parseOperand(argument, vps); + startWith = Expression::parseOperand(expCtx, argument, vps); continue; } else if (argName == "maxDepth") { uassert(40100, diff --git a/src/mongo/db/pipeline/document_source_graph_lookup.h b/src/mongo/db/pipeline/document_source_graph_lookup.h index 01473bf44fc..be12637d404 100644 --- a/src/mongo/db/pipeline/document_source_graph_lookup.h +++ b/src/mongo/db/pipeline/document_source_graph_lookup.h @@ -94,9 +94,6 @@ public: static boost::intrusive_ptr<DocumentSource> createFromBson( BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); -protected: - void doInjectExpressionContext() final; - private: DocumentSourceGraphLookUp( const boost::intrusive_ptr<ExpressionContext>& expCtx, @@ -188,9 +185,7 @@ private: size_t _frontierUsageBytes = 0; // Only used during the breadth-first search, tracks the set of values on the current frontier. - // We use boost::optional to defer initialization until the ExpressionContext containing the - // correct comparator is injected. - boost::optional<ValueUnorderedSet> _frontier; + ValueUnorderedSet _frontier; // Tracks nodes that have been discovered for a given input. Keys are the '_id' value of the // document from the foreign collection, value is the document itself. The keys are compared diff --git a/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp b/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp index 6350a8c39e1..53e3e6ea1ac 100644 --- a/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp +++ b/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp @@ -70,7 +70,6 @@ public: } pipeline.getValue()->addInitialSource(DocumentSourceMock::create(_results)); - pipeline.getValue()->injectExpressionContext(expCtx); pipeline.getValue()->optimizePipeline(); return pipeline; @@ -90,17 +89,18 @@ TEST_F(DocumentSourceGraphLookUpTest, std::deque<DocumentSource::GetNextResult> fromContents{Document{{"to", 0}}}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - boost::none); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + boost::none); graphLookupStage->setSource(inputMock.get()); graphLookupStage->injectMongodInterface( std::make_shared<MockMongodImplementation>(std::move(fromContents))); @@ -119,17 +119,18 @@ TEST_F(DocumentSourceGraphLookUpTest, Document{{"_id", "a"_sd}, {"to", 0}, {"from", 1}}, Document{{"to", 1}}}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - boost::none); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + boost::none); graphLookupStage->setSource(inputMock.get()); graphLookupStage->injectMongodInterface( std::make_shared<MockMongodImplementation>(std::move(fromContents))); @@ -147,18 +148,19 @@ TEST_F(DocumentSourceGraphLookUpTest, std::deque<DocumentSource::GetNextResult> fromContents{Document{{"to", 0}}}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); auto unwindStage = DocumentSourceUnwind::create(expCtx, "results", false, boost::none); - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - unwindStage); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + unwindStage); graphLookupStage->injectMongodInterface( std::make_shared<MockMongodImplementation>(std::move(fromContents))); graphLookupStage->setSource(inputMock.get()); @@ -190,17 +192,18 @@ TEST_F(DocumentSourceGraphLookUpTest, Document(to1), Document(to2), Document(to0from1), Document(to0from2)}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - boost::none); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + boost::none); graphLookupStage->setSource(inputMock.get()); graphLookupStage->injectMongodInterface( std::make_shared<MockMongodImplementation>(std::move(fromContents))); @@ -254,14 +257,14 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePauses) { Document{{"_id", "a"_sd}, {"to", 0}, {"from", 1}}, Document{{"_id", "b"_sd}, {"to", 1}}}; NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, fromNs, "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -322,7 +325,7 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePausesWhileUnwinding) { Document{{"_id", "a"_sd}, {"to", 0}, {"from", 1}}, Document{{"_id", "b"_sd}, {"to", 1}}}; NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); const bool preserveNullAndEmptyArrays = false; const boost::optional<std::string> includeArrayIndex = boost::none; @@ -335,7 +338,7 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePausesWhileUnwinding) { "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -387,14 +390,14 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePausesWhileUnwinding) { TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportAsFieldIsModified) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, fromNs, "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -409,7 +412,7 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportAsFieldIsModified) TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportFieldsModifiedByAbsorbedUnwind) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); auto unwindStage = DocumentSourceUnwind::create(expCtx, "results", false, std::string("arrIndex")); auto graphLookupStage = @@ -418,7 +421,7 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportFieldsModifiedByAbs "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -437,7 +440,7 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupWithComparisonExpressionForStar auto inputMock = DocumentSourceMock::create(Document({{"_id", 0}, {"a", 1}, {"b", 2}})); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); std::deque<DocumentSource::GetNextResult> fromContents{Document{{"_id", 0}, {"to", true}}, Document{{"_id", 1}, {"to", false}}}; @@ -447,9 +450,10 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupWithComparisonExpressionForStar "results", "from", "to", - ExpressionCompare::create(ExpressionCompare::GT, - ExpressionFieldPath::create("a"), - ExpressionFieldPath::create("b")), + ExpressionCompare::create(expCtx, + ExpressionCompare::GT, + ExpressionFieldPath::create(expCtx, "a"), + ExpressionFieldPath::create(expCtx, "b")), boost::none, boost::none, boost::none, diff --git a/src/mongo/db/pipeline/document_source_group.cpp b/src/mongo/db/pipeline/document_source_group.cpp index cf91c8eff5f..a70a65c06b9 100644 --- a/src/mongo/db/pipeline/document_source_group.cpp +++ b/src/mongo/db/pipeline/document_source_group.cpp @@ -198,23 +198,6 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::optimize() { return this; } -void DocumentSourceGroup::doInjectExpressionContext() { - // Groups map must respect new comparator. - _groups = pExpCtx->getValueComparator().makeUnorderedValueMap<Accumulators>(); - - for (auto&& idExpr : _idExpressions) { - idExpr->injectExpressionContext(pExpCtx); - } - - for (auto&& expr : vpExpression) { - expr->injectExpressionContext(pExpCtx); - } - - for (auto&& accum : _currentAccumulators) { - accum->injectExpressionContext(pExpCtx); - } -} - Value DocumentSourceGroup::serialize(bool explain) const { MutableDocument insides; @@ -235,7 +218,7 @@ Value DocumentSourceGroup::serialize(bool explain) const { // Add the remaining fields. const size_t n = vFieldName.size(); for (size_t i = 0; i < n; ++i) { - intrusive_ptr<Accumulator> accum = vpAccumulatorFactory[i](); + intrusive_ptr<Accumulator> accum = vpAccumulatorFactory[i](pExpCtx); insides[vFieldName[i]] = Value(DOC(accum->getOpName() << vpExpression[i]->serialize(explain))); } @@ -280,7 +263,6 @@ intrusive_ptr<DocumentSourceGroup> DocumentSourceGroup::create( groupStage->addAccumulator(statement); } groupStage->_variables = stdx::make_unique<Variables>(numVariables); - groupStage->injectExpressionContext(pExpCtx); return groupStage; } @@ -292,6 +274,7 @@ DocumentSourceGroup::DocumentSourceGroup(const intrusive_ptr<ExpressionContext>& _inputSort(BSONObj()), _streaming(false), _initialized(false), + _groups(pExpCtx->getValueComparator().makeUnorderedValueMap<Accumulators>()), _spilled(false), _extSortAllowed(pExpCtx->extSortAllowed && !pExpCtx->inRouter) {} @@ -303,7 +286,7 @@ void DocumentSourceGroup::addAccumulator(AccumulationStatement accumulationState namespace { -intrusive_ptr<Expression> parseIdExpression(const intrusive_ptr<ExpressionContext> expCtx, +intrusive_ptr<Expression> parseIdExpression(const intrusive_ptr<ExpressionContext>& expCtx, BSONElement groupField, const VariablesParseState& vps) { if (groupField.type() == Object && !groupField.Obj().isEmpty()) { @@ -312,18 +295,18 @@ intrusive_ptr<Expression> parseIdExpression(const intrusive_ptr<ExpressionContex const BSONObj idKeyObj = groupField.Obj(); if (idKeyObj.firstElementFieldName()[0] == '$') { // grouping on a $op expression - return Expression::parseObject(idKeyObj, vps); + return Expression::parseObject(expCtx, idKeyObj, vps); } else { for (auto&& field : idKeyObj) { uassert(17390, "$group does not support inclusion-style expressions", !field.isNumber() && field.type() != Bool); } - return ExpressionObject::parse(idKeyObj, vps); + return ExpressionObject::parse(expCtx, idKeyObj, vps); } } else if (groupField.type() == String && groupField.valuestr()[0] == '$') { // grouping on a field path. - return ExpressionFieldPath::parse(groupField.str(), vps); + return ExpressionFieldPath::parse(expCtx, groupField.str(), vps); } else { // constant id - single group return ExpressionConstant::create(expCtx, Value(groupField)); @@ -377,7 +360,7 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::createFromBson( } else { // Any other field will be treated as an accumulator specification. pGroup->addAccumulator( - AccumulationStatement::parseAccumulationStatement(groupField, vps)); + AccumulationStatement::parseAccumulationStatement(pExpCtx, groupField, vps)); } } @@ -493,8 +476,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { // Set up accumulators. _currentAccumulators.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { - _currentAccumulators.push_back(vpAccumulatorFactory[i]()); - _currentAccumulators.back()->injectExpressionContext(pExpCtx); + _currentAccumulators.push_back(vpAccumulatorFactory[i](pExpCtx)); } // We only need to load the first document. @@ -543,8 +525,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { // Add the accumulators group.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { - group.push_back(vpAccumulatorFactory[i]()); - group.back()->injectExpressionContext(pExpCtx); + group.push_back(vpAccumulatorFactory[i](pExpCtx)); } } else { for (size_t i = 0; i < numAccumulators; i++) { @@ -599,8 +580,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { // prepare current to accumulate data _currentAccumulators.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { - _currentAccumulators.push_back(vpAccumulatorFactory[i]()); - _currentAccumulators.back()->injectExpressionContext(pExpCtx); + _currentAccumulators.push_back(vpAccumulatorFactory[i](pExpCtx)); } verify(_sorterIterator->more()); // we put data in, we should get something out. @@ -876,7 +856,7 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::getMergeSource() { VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); /* the merger will use the same grouping key */ - pMerger->setIdExpression(ExpressionFieldPath::parse("$$ROOT._id", vps)); + pMerger->setIdExpression(ExpressionFieldPath::parse(pExpCtx, "$$ROOT._id", vps)); const size_t n = vFieldName.size(); for (size_t i = 0; i < n; ++i) { @@ -888,13 +868,13 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::getMergeSource() { expression or constant. Here, we accumulate the output of the same name from the prior group. */ - pMerger->addAccumulator({vFieldName[i], - vpAccumulatorFactory[i], - ExpressionFieldPath::parse("$$ROOT." + vFieldName[i], vps)}); + pMerger->addAccumulator( + {vFieldName[i], + vpAccumulatorFactory[i], + ExpressionFieldPath::parse(pExpCtx, "$$ROOT." + vFieldName[i], vps)}); } pMerger->_variables.reset(new Variables(idGenerator.getIdCount())); - pMerger->injectExpressionContext(pExpCtx); return pMerger; } diff --git a/src/mongo/db/pipeline/document_source_group.h b/src/mongo/db/pipeline/document_source_group.h index 76d9282ebb9..e79ce631817 100644 --- a/src/mongo/db/pipeline/document_source_group.h +++ b/src/mongo/db/pipeline/document_source_group.h @@ -96,9 +96,6 @@ public: boost::intrusive_ptr<DocumentSource> getShardSource() final; boost::intrusive_ptr<DocumentSource> getMergeSource() final; -protected: - void doInjectExpressionContext() final; - private: explicit DocumentSourceGroup(const boost::intrusive_ptr<ExpressionContext>& pExpCtx, size_t maxMemoryUsageBytes = kDefaultMaxMemoryUsageBytes); diff --git a/src/mongo/db/pipeline/document_source_group_test.cpp b/src/mongo/db/pipeline/document_source_group_test.cpp index 599e282e6b9..25b10ddcb5b 100644 --- a/src/mongo/db/pipeline/document_source_group_test.cpp +++ b/src/mongo/db/pipeline/document_source_group_test.cpp @@ -45,7 +45,7 @@ #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/expression.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value_comparator.h" #include "mongo/db/query/query_test_service_context.h" #include "mongo/dbtests/dbtests.h" @@ -110,8 +110,8 @@ TEST_F(DocumentSourceGroupTest, ShouldBeAbleToPauseLoadingWhileSpilled) { VariablesParseState vps(&idGen); AccumulationStatement pushStatement{"spaceHog", AccumulationStatement::getFactory("$push"), - ExpressionFieldPath::parse("$largeStr", vps)}; - auto groupByExpression = ExpressionFieldPath::parse("$_id", vps); + ExpressionFieldPath::parse(expCtx, "$largeStr", vps)}; + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", vps); auto group = DocumentSourceGroup::create( expCtx, groupByExpression, {pushStatement}, idGen.getIdCount(), maxMemoryUsageBytes); @@ -150,8 +150,8 @@ TEST_F(DocumentSourceGroupTest, ShouldErrorIfNotAllowedToSpillToDiskAndResultSet VariablesParseState vps(&idGen); AccumulationStatement pushStatement{"spaceHog", AccumulationStatement::getFactory("$push"), - ExpressionFieldPath::parse("$largeStr", vps)}; - auto groupByExpression = ExpressionFieldPath::parse("$_id", vps); + ExpressionFieldPath::parse(expCtx, "$largeStr", vps)}; + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", vps); auto group = DocumentSourceGroup::create( expCtx, groupByExpression, {pushStatement}, idGen.getIdCount(), maxMemoryUsageBytes); @@ -173,8 +173,8 @@ TEST_F(DocumentSourceGroupTest, ShouldCorrectlyTrackMemoryUsageBetweenPauses) { VariablesParseState vps(&idGen); AccumulationStatement pushStatement{"spaceHog", AccumulationStatement::getFactory("$push"), - ExpressionFieldPath::parse("$largeStr", vps)}; - auto groupByExpression = ExpressionFieldPath::parse("$_id", vps); + ExpressionFieldPath::parse(expCtx, "$largeStr", vps)}; + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", vps); auto group = DocumentSourceGroup::create( expCtx, groupByExpression, {pushStatement}, idGen.getIdCount(), maxMemoryUsageBytes); @@ -204,7 +204,8 @@ public: Base() : _queryServiceContext(stdx::make_unique<QueryTestServiceContext>()), _opCtx(_queryServiceContext->makeOperationContext()), - _ctx(new ExpressionContext(_opCtx.get(), AggregationRequest(NamespaceString(ns), {}))), + _ctx(new ExpressionContextForTest(_opCtx.get(), + AggregationRequest(NamespaceString(ns), {}))), _tempDir("DocumentSourceGroupTest") {} protected: @@ -212,15 +213,14 @@ protected: BSONObj namedSpec = BSON("$group" << spec); BSONElement specElement = namedSpec.firstElement(); - intrusive_ptr<ExpressionContext> expressionContext = - new ExpressionContext(_opCtx.get(), AggregationRequest(NamespaceString(ns), {})); + intrusive_ptr<ExpressionContextForTest> expressionContext = + new ExpressionContextForTest(_opCtx.get(), AggregationRequest(NamespaceString(ns), {})); expressionContext->inShard = inShard; expressionContext->inRouter = inRouter; // Won't spill to disk properly if it needs to. expressionContext->tempDir = _tempDir.path(); _group = DocumentSourceGroup::createFromBson(specElement, expressionContext); - _group->injectExpressionContext(expressionContext); assertRoundTrips(_group); } DocumentSourceGroup* group() { @@ -234,7 +234,7 @@ protected: ASSERT(source->getNext().isEOF()); } - intrusive_ptr<ExpressionContext> ctx() const { + intrusive_ptr<ExpressionContextForTest> ctx() const { return _ctx; } @@ -251,7 +251,7 @@ private: } std::unique_ptr<QueryTestServiceContext> _queryServiceContext; ServiceContext::UniqueOperationContext _opCtx; - intrusive_ptr<ExpressionContext> _ctx; + intrusive_ptr<ExpressionContextForTest> _ctx; intrusive_ptr<DocumentSource> _group; TempDir _tempDir; }; diff --git a/src/mongo/db/pipeline/document_source_limit.cpp b/src/mongo/db/pipeline/document_source_limit.cpp index 634d635b22d..797c6ed6652 100644 --- a/src/mongo/db/pipeline/document_source_limit.cpp +++ b/src/mongo/db/pipeline/document_source_limit.cpp @@ -93,7 +93,6 @@ intrusive_ptr<DocumentSourceLimit> DocumentSourceLimit::create( const intrusive_ptr<ExpressionContext>& pExpCtx, long long limit) { uassert(15958, "the limit must be positive", limit > 0); intrusive_ptr<DocumentSourceLimit> source(new DocumentSourceLimit(pExpCtx, limit)); - source->injectExpressionContext(pExpCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_lookup.cpp b/src/mongo/db/pipeline/document_source_lookup.cpp index 5371542505c..a0923db276e 100644 --- a/src/mongo/db/pipeline/document_source_lookup.cpp +++ b/src/mongo/db/pipeline/document_source_lookup.cpp @@ -54,7 +54,16 @@ DocumentSourceLookUp::DocumentSourceLookUp(NamespaceString fromNs, _as(std::move(as)), _localField(std::move(localField)), _foreignField(foreignField), - _foreignFieldFieldName(std::move(foreignField)) {} + _foreignFieldFieldName(std::move(foreignField)) { + const auto& resolvedNamespace = pExpCtx->getResolvedNamespace(_fromNs); + _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); + _fromPipeline = resolvedNamespace.pipeline; + + // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage + // we'll eventually construct from the input document. + _fromPipeline.reserve(_fromPipeline.size() + 1); + _fromPipeline.push_back(BSONObj()); +} std::unique_ptr<LiteParsedDocumentSourceOneForeignCollection> DocumentSourceLookUp::liteParse( const AggregationRequest& request, const BSONElement& spec) { @@ -454,19 +463,6 @@ DocumentSource::GetDepsReturn DocumentSourceLookUp::getDependencies(DepsTracker* return SEE_NEXT; } -void DocumentSourceLookUp::doInjectExpressionContext() { - auto it = pExpCtx->resolvedNamespaces.find(_fromNs.coll()); - invariant(it != pExpCtx->resolvedNamespaces.end()); - const auto& resolvedNamespace = it->second; - _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); - _fromPipeline = resolvedNamespace.pipeline; - - // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage - // we'll eventually construct from the input document. - _fromPipeline.reserve(_fromPipeline.size() + 1); - _fromPipeline.push_back(BSONObj()); -} - void DocumentSourceLookUp::doDetachFromOperationContext() { if (_pipeline) { // We have a pipeline we're going to be executing across multiple calls to getNext(), so we diff --git a/src/mongo/db/pipeline/document_source_lookup.h b/src/mongo/db/pipeline/document_source_lookup.h index 13118828b65..e179abfe73e 100644 --- a/src/mongo/db/pipeline/document_source_lookup.h +++ b/src/mongo/db/pipeline/document_source_lookup.h @@ -113,9 +113,6 @@ public: _handlingUnwind = true; } -protected: - void doInjectExpressionContext() final; - private: DocumentSourceLookUp(NamespaceString fromNs, std::string as, diff --git a/src/mongo/db/pipeline/document_source_lookup_test.cpp b/src/mongo/db/pipeline/document_source_lookup_test.cpp index 0f54e3741c4..9da414ec6f6 100644 --- a/src/mongo/db/pipeline/document_source_lookup_test.cpp +++ b/src/mongo/db/pipeline/document_source_lookup_test.cpp @@ -54,6 +54,10 @@ using std::vector; using DocumentSourceLookUpTest = AggregationContextFixture; TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnAsField) { + auto expCtx = getExpCtx(); + NamespaceString fromNs("test", "a"); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); + intrusive_ptr<DocumentSourceMock> source = DocumentSourceMock::create(); source->sorts = {BSON("a" << 1 << "d.e" << 1 << "c" << 1)}; auto lookup = DocumentSourceLookUp::createFromBson(Document{{"$lookup", @@ -63,7 +67,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnAsField) { {"as", "d.e"_sd}}}} .toBson() .firstElement(), - getExpCtx()); + expCtx); lookup->setSource(source.get()); BSONObjSet outputSort = lookup->getOutputSorts(); @@ -73,6 +77,10 @@ TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnAsField) { } TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnSuffixOfAsField) { + auto expCtx = getExpCtx(); + NamespaceString fromNs("test", "a"); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); + intrusive_ptr<DocumentSourceMock> source = DocumentSourceMock::create(); source->sorts = {BSON("a" << 1 << "d.e" << 1 << "c" << 1)}; auto lookup = DocumentSourceLookUp::createFromBson(Document{{"$lookup", @@ -82,7 +90,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnSuffixOfAsField) { {"as", "d"_sd}}}} .toBson() .firstElement(), - getExpCtx()); + expCtx); lookup->setSource(source.get()); BSONObjSet outputSort = lookup->getOutputSorts(); @@ -158,7 +166,6 @@ public: } pipeline.getValue()->addInitialSource(DocumentSourceMock::create(_mockResults)); - pipeline.getValue()->injectExpressionContext(expCtx); pipeline.getValue()->optimizePipeline(); return pipeline; @@ -171,7 +178,7 @@ private: TEST_F(DocumentSourceLookUpTest, ShouldPropagatePauses) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", @@ -191,7 +198,6 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePauses) { DocumentSource::GetNextResult::makePauseExecution()}); lookup->setSource(mockLocalSource.get()); - lookup->injectExpressionContext(expCtx); // Mock out the foreign collection. deque<DocumentSource::GetNextResult> mockForeignContents{Document{{"_id", 0}}, @@ -222,7 +228,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePauses) { TEST_F(DocumentSourceLookUpTest, ShouldPropagatePausesWhileUnwinding) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", @@ -247,8 +253,6 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePausesWhileUnwinding) { DocumentSource::GetNextResult::makePauseExecution()}); lookup->setSource(mockLocalSource.get()); - lookup->injectExpressionContext(expCtx); - // Mock out the foreign collection. deque<DocumentSource::GetNextResult> mockForeignContents{Document{{"_id", 0}}, Document{{"_id", 1}}}; @@ -276,7 +280,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePausesWhileUnwinding) { TEST_F(DocumentSourceLookUpTest, LookupReportsAsFieldIsModified) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", @@ -297,7 +301,7 @@ TEST_F(DocumentSourceLookUpTest, LookupReportsAsFieldIsModified) { TEST_F(DocumentSourceLookUpTest, LookupReportsFieldsModifiedByAbsorbedUnwind) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector<BSONObj>{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector<BSONObj>{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", diff --git a/src/mongo/db/pipeline/document_source_match.cpp b/src/mongo/db/pipeline/document_source_match.cpp index 18cb22db71d..494d6589099 100644 --- a/src/mongo/db/pipeline/document_source_match.cpp +++ b/src/mongo/db/pipeline/document_source_match.cpp @@ -401,7 +401,7 @@ DocumentSourceMatch::splitSourceBy(const std::set<std::string>& fields) { boost::intrusive_ptr<DocumentSourceMatch> DocumentSourceMatch::descendMatchOnPath( MatchExpression* matchExpr, const std::string& descendOn, - intrusive_ptr<ExpressionContext> expCtx) { + const intrusive_ptr<ExpressionContext>& expCtx) { expression::mapOver(matchExpr, [&descendOn](MatchExpression* node, std::string path) -> void { // Cannot call this method on a $match including a $elemMatch. invariant(node->matchType() != MatchExpression::ELEM_MATCH_OBJECT && @@ -453,7 +453,6 @@ intrusive_ptr<DocumentSourceMatch> DocumentSourceMatch::create( BSONObj filter, const intrusive_ptr<ExpressionContext>& expCtx) { uassertNoDisallowedClauses(filter); intrusive_ptr<DocumentSourceMatch> match(new DocumentSourceMatch(filter, expCtx)); - match->injectExpressionContext(expCtx); return match; } @@ -490,10 +489,6 @@ void DocumentSourceMatch::addDependencies(DepsTracker* deps) const { }); } -void DocumentSourceMatch::doInjectExpressionContext() { - _expression->setCollator(pExpCtx->getCollator()); -} - DocumentSourceMatch::DocumentSourceMatch(const BSONObj& query, const intrusive_ptr<ExpressionContext>& pExpCtx) : DocumentSource(pExpCtx), _predicate(query.getOwned()), _isTextQuery(isTextQuery(query)) { diff --git a/src/mongo/db/pipeline/document_source_match.h b/src/mongo/db/pipeline/document_source_match.h index e7f3d4227d9..badfe1ed8b9 100644 --- a/src/mongo/db/pipeline/document_source_match.h +++ b/src/mongo/db/pipeline/document_source_match.h @@ -140,9 +140,7 @@ public: static boost::intrusive_ptr<DocumentSourceMatch> descendMatchOnPath( MatchExpression* matchExpr, const std::string& path, - boost::intrusive_ptr<ExpressionContext> expCtx); - - void doInjectExpressionContext(); + const boost::intrusive_ptr<ExpressionContext>& expCtx); private: DocumentSourceMatch(const BSONObj& query, diff --git a/src/mongo/db/pipeline/document_source_merge_cursors.cpp b/src/mongo/db/pipeline/document_source_merge_cursors.cpp index 866259c663e..b03db6560df 100644 --- a/src/mongo/db/pipeline/document_source_merge_cursors.cpp +++ b/src/mongo/db/pipeline/document_source_merge_cursors.cpp @@ -57,7 +57,6 @@ intrusive_ptr<DocumentSource> DocumentSourceMergeCursors::create( const intrusive_ptr<ExpressionContext>& pExpCtx) { intrusive_ptr<DocumentSourceMergeCursors> source( new DocumentSourceMergeCursors(std::move(cursorDescriptors), pExpCtx)); - source->injectExpressionContext(pExpCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_mock.cpp b/src/mongo/db/pipeline/document_source_mock.cpp index 686f4ef6f2e..3918bbd5ad7 100644 --- a/src/mongo/db/pipeline/document_source_mock.cpp +++ b/src/mongo/db/pipeline/document_source_mock.cpp @@ -31,6 +31,8 @@ #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" namespace mongo { @@ -38,7 +40,7 @@ using boost::intrusive_ptr; using std::deque; DocumentSourceMock::DocumentSourceMock(deque<GetNextResult> results) - : DocumentSource(nullptr), + : DocumentSource(new ExpressionContextForTest()), queue(std::move(results)), sorts(SimpleBSONObjComparator::kInstance.makeBSONObjSet()) {} diff --git a/src/mongo/db/pipeline/document_source_mock.h b/src/mongo/db/pipeline/document_source_mock.h index 7a600483661..7b966f5c418 100644 --- a/src/mongo/db/pipeline/document_source_mock.h +++ b/src/mongo/db/pipeline/document_source_mock.h @@ -79,10 +79,6 @@ public: return this; } - void doInjectExpressionContext() override { - isExpCtxInjected = true; - } - // Return documents from front of queue. std::deque<GetNextResult> queue; diff --git a/src/mongo/db/pipeline/document_source_project.cpp b/src/mongo/db/pipeline/document_source_project.cpp index fd44d534dde..a0ee096f36e 100644 --- a/src/mongo/db/pipeline/document_source_project.cpp +++ b/src/mongo/db/pipeline/document_source_project.cpp @@ -49,8 +49,7 @@ REGISTER_DOCUMENT_SOURCE(project, intrusive_ptr<DocumentSource> DocumentSourceProject::create( BSONObj projectSpec, const intrusive_ptr<ExpressionContext>& expCtx) { intrusive_ptr<DocumentSource> project(new DocumentSourceSingleDocumentTransformation( - expCtx, ParsedAggregationProjection::create(projectSpec), "$project")); - project->injectExpressionContext(expCtx); + expCtx, ParsedAggregationProjection::create(expCtx, projectSpec), "$project")); return project; } diff --git a/src/mongo/db/pipeline/document_source_project.h b/src/mongo/db/pipeline/document_source_project.h index 869dce512b1..4ed9db432d8 100644 --- a/src/mongo/db/pipeline/document_source_project.h +++ b/src/mongo/db/pipeline/document_source_project.h @@ -53,6 +53,8 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); private: + // It is illegal to construct a DocumentSourceProject directly, use create() or createFromBson() + // instead. DocumentSourceProject() = default; }; diff --git a/src/mongo/db/pipeline/document_source_redact.cpp b/src/mongo/db/pipeline/document_source_redact.cpp index 86b1c00960d..0096017a2f8 100644 --- a/src/mongo/db/pipeline/document_source_redact.cpp +++ b/src/mongo/db/pipeline/document_source_redact.cpp @@ -163,10 +163,6 @@ intrusive_ptr<DocumentSource> DocumentSourceRedact::optimize() { return this; } -void DocumentSourceRedact::doInjectExpressionContext() { - _expression->injectExpressionContext(pExpCtx); -} - Value DocumentSourceRedact::serialize(bool explain) const { return Value(DOC(getSourceName() << _expression.get()->serialize(explain))); } @@ -179,7 +175,7 @@ intrusive_ptr<DocumentSource> DocumentSourceRedact::createFromBson( Variables::Id decendId = vps.defineVariable("DESCEND"); Variables::Id pruneId = vps.defineVariable("PRUNE"); Variables::Id keepId = vps.defineVariable("KEEP"); - intrusive_ptr<Expression> expression = Expression::parseOperand(elem, vps); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, elem, vps); intrusive_ptr<DocumentSourceRedact> source = new DocumentSourceRedact(expCtx, expression); // TODO figure out how much of this belongs in constructor and how much here. diff --git a/src/mongo/db/pipeline/document_source_redact.h b/src/mongo/db/pipeline/document_source_redact.h index f056411f765..b9cf6242d42 100644 --- a/src/mongo/db/pipeline/document_source_redact.h +++ b/src/mongo/db/pipeline/document_source_redact.h @@ -47,8 +47,6 @@ public: Pipeline::SourceContainer::iterator doOptimizeAt(Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) final; - void doInjectExpressionContext() final; - static boost::intrusive_ptr<DocumentSource> createFromBson( BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx); diff --git a/src/mongo/db/pipeline/document_source_replace_root.cpp b/src/mongo/db/pipeline/document_source_replace_root.cpp index 6f1844ed281..1ae82d5cf3a 100644 --- a/src/mongo/db/pipeline/document_source_replace_root.cpp +++ b/src/mongo/db/pipeline/document_source_replace_root.cpp @@ -87,17 +87,14 @@ public: return DocumentSource::EXHAUSTIVE_FIELDS; } - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& pExpCtx) final { - _newRoot->injectExpressionContext(pExpCtx); - } - DocumentSource::GetModPathsReturn getModifiedPaths() const final { // Replaces the entire root, so all paths are modified. return {DocumentSource::GetModPathsReturn::Type::kAllPaths, std::set<std::string>{}}; } // Create the replaceRoot transformer. Uasserts on invalid input. - static std::unique_ptr<ReplaceRootTransformation> create(const BSONElement& spec) { + static std::unique_ptr<ReplaceRootTransformation> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONElement& spec) { // Confirm that the stage was called with an object. uassert(40229, @@ -108,12 +105,12 @@ public: // Create the pointer, parse the stage, and return. std::unique_ptr<ReplaceRootTransformation> parsedReplaceRoot = stdx::make_unique<ReplaceRootTransformation>(); - parsedReplaceRoot->parse(spec); + parsedReplaceRoot->parse(expCtx, spec); return parsedReplaceRoot; } // Check for valid replaceRoot options, and populate internal state variables. - void parse(const BSONElement& spec) { + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONElement& spec) { // We need a VariablesParseState in order to parse the 'newRoot' expression. VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); @@ -124,7 +121,7 @@ public: if (argName == "newRoot") { // Allows for field path, object, and other expressions. - _newRoot = Expression::parseOperand(argument, vps); + _newRoot = Expression::parseOperand(expCtx, argument, vps); } else { uasserted(40230, str::stream() << "unrecognized option to $replaceRoot stage: " << argName @@ -150,7 +147,7 @@ intrusive_ptr<DocumentSource> DocumentSourceReplaceRoot::createFromBson( BSONElement elem, const intrusive_ptr<ExpressionContext>& pExpCtx) { return new DocumentSourceSingleDocumentTransformation( - pExpCtx, ReplaceRootTransformation::create(elem), "$replaceRoot"); + pExpCtx, ReplaceRootTransformation::create(pExpCtx, elem), "$replaceRoot"); } } // namespace mongo diff --git a/src/mongo/db/pipeline/document_source_replace_root.h b/src/mongo/db/pipeline/document_source_replace_root.h index f3313ec5f7a..4656a6cdc2d 100644 --- a/src/mongo/db/pipeline/document_source_replace_root.h +++ b/src/mongo/db/pipeline/document_source_replace_root.h @@ -48,6 +48,8 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); private: + // It is illegal to construct a DocumentSourceReplaceRoot directly, use createFromBson() + // instead. DocumentSourceReplaceRoot() = default; }; diff --git a/src/mongo/db/pipeline/document_source_sample.cpp b/src/mongo/db/pipeline/document_source_sample.cpp index a436d3e8a53..c7d98db7260 100644 --- a/src/mongo/db/pipeline/document_source_sample.cpp +++ b/src/mongo/db/pipeline/document_source_sample.cpp @@ -87,10 +87,6 @@ Value DocumentSourceSample::serialize(bool explain) const { return Value(DOC(getSourceName() << DOC("size" << _size))); } -void DocumentSourceSample::doInjectExpressionContext() { - _sortStage->injectExpressionContext(pExpCtx); -} - namespace { const BSONObj randSortSpec = BSON("$rand" << BSON("$meta" << "randVal")); diff --git a/src/mongo/db/pipeline/document_source_sample.h b/src/mongo/db/pipeline/document_source_sample.h index ed0eec9c180..28f1dd97b05 100644 --- a/src/mongo/db/pipeline/document_source_sample.h +++ b/src/mongo/db/pipeline/document_source_sample.h @@ -50,8 +50,6 @@ public: return _size; } - void doInjectExpressionContext() final; - static boost::intrusive_ptr<DocumentSource> createFromBson( BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx); diff --git a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp index f4b2299c9e8..9fc3b5bf105 100644 --- a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp +++ b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp @@ -52,6 +52,7 @@ DocumentSourceSampleFromRandomCursor::DocumentSourceSampleFromRandomCursor( : DocumentSource(pExpCtx), _size(size), _idField(std::move(idField)), + _seenDocs(pExpCtx->getValueComparator().makeUnorderedValueSet()), _nDocsInColl(nDocsInCollection) {} const char* DocumentSourceSampleFromRandomCursor::getSourceName() const { @@ -76,7 +77,7 @@ double smallestFromSampleOfUniform(PseudoRandom* prng, size_t N) { DocumentSource::GetNextResult DocumentSourceSampleFromRandomCursor::getNext() { pExpCtx->checkForInterrupt(); - if (_seenDocs->size() >= static_cast<size_t>(_size)) + if (_seenDocs.size() >= static_cast<size_t>(_size)) return GetNextResult::makeEOF(); auto nextResult = getNextNonDuplicateDocument(); @@ -114,7 +115,7 @@ DocumentSource::GetNextResult DocumentSourceSampleFromRandomCursor::getNextNonDu << nextInput.getDocument().toString(), !idField.missing()); - if (_seenDocs->insert(std::move(idField)).second) { + if (_seenDocs.insert(std::move(idField)).second) { return nextInput; } LOG(1) << "$sample encountered duplicate document: " @@ -147,10 +148,6 @@ DocumentSource::GetDepsReturn DocumentSourceSampleFromRandomCursor::getDependenc return SEE_NEXT; } -void DocumentSourceSampleFromRandomCursor::doInjectExpressionContext() { - _seenDocs = pExpCtx->getValueComparator().makeUnorderedValueSet(); -} - intrusive_ptr<DocumentSourceSampleFromRandomCursor> DocumentSourceSampleFromRandomCursor::create( const intrusive_ptr<ExpressionContext>& expCtx, long long size, @@ -158,7 +155,6 @@ intrusive_ptr<DocumentSourceSampleFromRandomCursor> DocumentSourceSampleFromRand long long nDocsInCollection) { intrusive_ptr<DocumentSourceSampleFromRandomCursor> source( new DocumentSourceSampleFromRandomCursor(expCtx, size, idField, nDocsInCollection)); - source->injectExpressionContext(expCtx); return source; } } // mongo diff --git a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h index 6e8a2ae1d39..0d0ac39ca49 100644 --- a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h +++ b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h @@ -44,8 +44,6 @@ public: Value serialize(bool explain = false) const final; GetDepsReturn getDependencies(DepsTracker* deps) const final; - void doInjectExpressionContext() final; - static boost::intrusive_ptr<DocumentSourceSampleFromRandomCursor> create( const boost::intrusive_ptr<ExpressionContext>& expCtx, long long size, @@ -71,9 +69,8 @@ private: std::string _idField; // Keeps track of the documents that have been returned, since a random cursor is allowed to - // return duplicates. We use boost::optional to defer initialization until the ExpressionContext - // containing the correct comparator is injected. - boost::optional<ValueUnorderedSet> _seenDocs; + // return duplicates. + ValueUnorderedSet _seenDocs; // The approximate number of documents in the collection (includes orphans). const long long _nDocsInColl; diff --git a/src/mongo/db/pipeline/document_source_single_document_transformation.cpp b/src/mongo/db/pipeline/document_source_single_document_transformation.cpp index 67dd7338395..35204d272a4 100644 --- a/src/mongo/db/pipeline/document_source_single_document_transformation.cpp +++ b/src/mongo/db/pipeline/document_source_single_document_transformation.cpp @@ -98,10 +98,6 @@ DocumentSource::GetDepsReturn DocumentSourceSingleDocumentTransformation::getDep return _parsedTransform->addDependencies(deps); } -void DocumentSourceSingleDocumentTransformation::doInjectExpressionContext() { - _parsedTransform->injectExpressionContext(pExpCtx); -} - DocumentSource::GetModPathsReturn DocumentSourceSingleDocumentTransformation::getModifiedPaths() const { return _parsedTransform->getModifiedPaths(); diff --git a/src/mongo/db/pipeline/document_source_single_document_transformation.h b/src/mongo/db/pipeline/document_source_single_document_transformation.h index b2db5b0c3ae..1251cb454a4 100644 --- a/src/mongo/db/pipeline/document_source_single_document_transformation.h +++ b/src/mongo/db/pipeline/document_source_single_document_transformation.h @@ -57,8 +57,6 @@ public: virtual void optimize() = 0; virtual Document serialize(bool explain) const = 0; virtual DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const = 0; - virtual void injectExpressionContext( - const boost::intrusive_ptr<ExpressionContext>& pExpCtx) = 0; virtual GetModPathsReturn getModifiedPaths() const = 0; }; @@ -75,7 +73,6 @@ public: Value serialize(bool explain) const final; Pipeline::SourceContainer::iterator doOptimizeAt(Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) final; - void doInjectExpressionContext() final; DocumentSource::GetDepsReturn getDependencies(DepsTracker* deps) const final; GetModPathsReturn getModifiedPaths() const final; diff --git a/src/mongo/db/pipeline/document_source_skip.cpp b/src/mongo/db/pipeline/document_source_skip.cpp index 27bb6d1d2a6..ff528f0634e 100644 --- a/src/mongo/db/pipeline/document_source_skip.cpp +++ b/src/mongo/db/pipeline/document_source_skip.cpp @@ -96,7 +96,6 @@ Pipeline::SourceContainer::iterator DocumentSourceSkip::doOptimizeAt( intrusive_ptr<DocumentSourceSkip> DocumentSourceSkip::create( const intrusive_ptr<ExpressionContext>& pExpCtx, long long nToSkip) { intrusive_ptr<DocumentSourceSkip> skip(new DocumentSourceSkip(pExpCtx, nToSkip)); - skip->injectExpressionContext(pExpCtx); return skip; } diff --git a/src/mongo/db/pipeline/document_source_sort.cpp b/src/mongo/db/pipeline/document_source_sort.cpp index 4f4f17da4f2..6554a892016 100644 --- a/src/mongo/db/pipeline/document_source_sort.cpp +++ b/src/mongo/db/pipeline/document_source_sort.cpp @@ -113,7 +113,7 @@ long long DocumentSourceSort::getLimit() const { void DocumentSourceSort::addKey(StringData fieldPath, bool ascending) { VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - vSortKey.push_back(ExpressionFieldPath::parse("$$ROOT." + fieldPath.toString(), vps)); + vSortKey.push_back(ExpressionFieldPath::parse(pExpCtx, "$$ROOT." + fieldPath.toString(), vps)); vAscending.push_back(ascending); } @@ -176,7 +176,6 @@ intrusive_ptr<DocumentSourceSort> DocumentSourceSort::create( uint64_t maxMemoryUsageBytes) { intrusive_ptr<DocumentSourceSort> pSort(new DocumentSourceSort(pExpCtx)); pSort->_maxMemoryUsageBytes = maxMemoryUsageBytes; - pSort->injectExpressionContext(pExpCtx); pSort->_sort = sortOrder.getOwned(); for (auto&& keyField : sortOrder) { @@ -197,7 +196,7 @@ intrusive_ptr<DocumentSourceSort> DocumentSourceSort::create( VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - pSort->vSortKey.push_back(ExpressionMeta::parse(metaDoc.firstElement(), vps)); + pSort->vSortKey.push_back(ExpressionMeta::parse(pExpCtx, metaDoc.firstElement(), vps)); // If sorting by textScore, sort highest scores first. If sorting by randVal, order // doesn't matter, so just always use descending. diff --git a/src/mongo/db/pipeline/document_source_sort_by_count.h b/src/mongo/db/pipeline/document_source_sort_by_count.h index 2444d048259..69bf8d2c5e0 100644 --- a/src/mongo/db/pipeline/document_source_sort_by_count.h +++ b/src/mongo/db/pipeline/document_source_sort_by_count.h @@ -44,6 +44,8 @@ public: BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); private: + // It is illegal to construct a DocumentSourceSortByCount directly, use createFromBson() + // instead. DocumentSourceSortByCount() = default; }; diff --git a/src/mongo/db/pipeline/document_source_tee_consumer.h b/src/mongo/db/pipeline/document_source_tee_consumer.h index 928ac8cba30..bfdbfbda1cb 100644 --- a/src/mongo/db/pipeline/document_source_tee_consumer.h +++ b/src/mongo/db/pipeline/document_source_tee_consumer.h @@ -38,7 +38,7 @@ namespace mongo { class Document; -struct ExpressionContext; +class ExpressionContext; class Value; /** diff --git a/src/mongo/db/pipeline/document_source_unwind.cpp b/src/mongo/db/pipeline/document_source_unwind.cpp index 368b29afbd4..683069597e0 100644 --- a/src/mongo/db/pipeline/document_source_unwind.cpp +++ b/src/mongo/db/pipeline/document_source_unwind.cpp @@ -183,7 +183,6 @@ intrusive_ptr<DocumentSourceUnwind> DocumentSourceUnwind::create( FieldPath(unwindPath), preserveNullAndEmptyArrays, indexPath ? FieldPath(*indexPath) : boost::optional<FieldPath>())); - source->injectExpressionContext(expCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_unwind_test.cpp b/src/mongo/db/pipeline/document_source_unwind_test.cpp index 48df219770a..8dcf1ed9a32 100644 --- a/src/mongo/db/pipeline/document_source_unwind_test.cpp +++ b/src/mongo/db/pipeline/document_source_unwind_test.cpp @@ -42,7 +42,7 @@ #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document_source_unwind.h" #include "mongo/db/pipeline/document_value_test_util.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value_comparator.h" #include "mongo/db/query/query_test_service_context.h" #include "mongo/db/service_context.h" @@ -70,7 +70,8 @@ public: CheckResultsBase() : _queryServiceContext(stdx::make_unique<QueryTestServiceContext>()), _opCtx(_queryServiceContext->makeOperationContext()), - _ctx(new ExpressionContext(_opCtx.get(), AggregationRequest(NamespaceString(ns), {}))) {} + _ctx(new ExpressionContextForTest(_opCtx.get(), + AggregationRequest(NamespaceString(ns), {}))) {} virtual ~CheckResultsBase() {} @@ -141,7 +142,7 @@ protected: return expectedIndexedResultSetString(); } - intrusive_ptr<ExpressionContext> ctx() const { + intrusive_ptr<ExpressionContextForTest> ctx() const { return _ctx; } @@ -248,7 +249,7 @@ private: unique_ptr<QueryTestServiceContext> _queryServiceContext; ServiceContext::UniqueOperationContext _opCtx; - intrusive_ptr<ExpressionContext> _ctx; + intrusive_ptr<ExpressionContextForTest> _ctx; intrusive_ptr<DocumentSourceUnwind> _unwind; }; diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index e9de0eb3472..d9a087e08c1 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -189,17 +189,20 @@ string Expression::removeFieldPrefix(const string& prefixedField) { return string(pPrefixedField + 1); } -intrusive_ptr<Expression> Expression::parseObject(BSONObj obj, const VariablesParseState& vps) { +intrusive_ptr<Expression> Expression::parseObject( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps) { if (obj.isEmpty()) { - return ExpressionObject::create({}); + return ExpressionObject::create(expCtx, {}); } if (obj.firstElementFieldName()[0] == '$') { // Assume this is an expression like {$add: [...]}. - return parseExpression(obj, vps); + return parseExpression(expCtx, obj, vps); } - return ExpressionObject::parse(obj, vps); + return ExpressionObject::parse(expCtx, obj, vps); } namespace { @@ -214,7 +217,10 @@ void Expression::registerExpression(string key, Parser parser) { parserMap[key] = parser; } -intrusive_ptr<Expression> Expression::parseExpression(BSONObj obj, const VariablesParseState& vps) { +intrusive_ptr<Expression> Expression::parseExpression( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps) { uassert(15983, str::stream() << "An object representing an expression must have exactly one " "field: " @@ -227,36 +233,40 @@ intrusive_ptr<Expression> Expression::parseExpression(BSONObj obj, const Variabl uassert(ErrorCodes::InvalidPipelineOperator, str::stream() << "Unrecognized expression '" << opName << "'", op != parserMap.end()); - return op->second(obj.firstElement(), vps); + return op->second(expCtx, obj.firstElement(), vps); } -Expression::ExpressionVector ExpressionNary::parseArguments(BSONElement exprElement, - const VariablesParseState& vps) { +Expression::ExpressionVector ExpressionNary::parseArguments( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement exprElement, + const VariablesParseState& vps) { ExpressionVector out; if (exprElement.type() == Array) { BSONForEach(elem, exprElement.Obj()) { - out.push_back(Expression::parseOperand(elem, vps)); + out.push_back(Expression::parseOperand(expCtx, elem, vps)); } } else { // Assume it's an operand that accepts a single argument. - out.push_back(Expression::parseOperand(exprElement, vps)); + out.push_back(Expression::parseOperand(expCtx, exprElement, vps)); } return out; } -intrusive_ptr<Expression> Expression::parseOperand(BSONElement exprElement, - const VariablesParseState& vps) { +intrusive_ptr<Expression> Expression::parseOperand( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement exprElement, + const VariablesParseState& vps) { BSONType type = exprElement.type(); if (type == String && exprElement.valuestr()[0] == '$') { /* if we got here, this is a field path expression */ - return ExpressionFieldPath::parse(exprElement.str(), vps); + return ExpressionFieldPath::parse(expCtx, exprElement.str(), vps); } else if (type == Object) { - return Expression::parseObject(exprElement.Obj(), vps); + return Expression::parseObject(expCtx, exprElement.Obj(), vps); } else if (type == Array) { - return ExpressionArray::parse(exprElement, vps); + return ExpressionArray::parse(expCtx, exprElement, vps); } else { - return ExpressionConstant::parse(exprElement, vps); + return ExpressionConstant::parse(expCtx, exprElement, vps); } } @@ -615,13 +625,13 @@ const char* ExpressionCeil::getOpName() const { intrusive_ptr<ExpressionCoerceToBool> ExpressionCoerceToBool::create( const intrusive_ptr<ExpressionContext>& expCtx, const intrusive_ptr<Expression>& pExpression) { - intrusive_ptr<ExpressionCoerceToBool> pNew(new ExpressionCoerceToBool(pExpression)); - pNew->injectExpressionContext(expCtx); + intrusive_ptr<ExpressionCoerceToBool> pNew(new ExpressionCoerceToBool(expCtx, pExpression)); return pNew; } -ExpressionCoerceToBool::ExpressionCoerceToBool(const intrusive_ptr<Expression>& pTheExpression) - : Expression(), pExpression(pTheExpression) {} +ExpressionCoerceToBool::ExpressionCoerceToBool(const intrusive_ptr<ExpressionContext>& expCtx, + const intrusive_ptr<Expression>& pTheExpression) + : Expression(expCtx), pExpression(pTheExpression) {} intrusive_ptr<Expression> ExpressionCoerceToBool::optimize() { /* optimize the operand */ @@ -656,65 +666,68 @@ Value ExpressionCoerceToBool::serialize(bool explain) const { return Value(DOC(name << DOC_ARRAY(pExpression->serialize(explain)))); } -void ExpressionCoerceToBool::doInjectExpressionContext() { - // Inject our ExpressionContext into the operand. - pExpression->injectExpressionContext(getExpressionContext()); -} - /* ----------------------- ExpressionCompare --------------------------- */ REGISTER_EXPRESSION(cmp, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::CMP)); REGISTER_EXPRESSION(eq, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::EQ)); REGISTER_EXPRESSION(gt, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::GT)); REGISTER_EXPRESSION(gte, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::GTE)); REGISTER_EXPRESSION(lt, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::LT)); REGISTER_EXPRESSION(lte, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::LTE)); REGISTER_EXPRESSION(ne, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::NE)); -intrusive_ptr<Expression> ExpressionCompare::parse(BSONElement bsonExpr, - const VariablesParseState& vps, - CmpOp op) { - intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(op); - ExpressionVector args = parseArguments(bsonExpr, vps); +intrusive_ptr<Expression> ExpressionCompare::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps, + CmpOp op) { + intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(expCtx, op); + ExpressionVector args = parseArguments(expCtx, bsonExpr, vps); expr->validateArguments(args); expr->vpOperand = args; return expr; } -ExpressionCompare::ExpressionCompare(CmpOp theCmpOp) : cmpOp(theCmpOp) {} - boost::intrusive_ptr<ExpressionCompare> ExpressionCompare::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, CmpOp cmpOp, const boost::intrusive_ptr<Expression>& exprLeft, const boost::intrusive_ptr<Expression>& exprRight) { - boost::intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(cmpOp); + boost::intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(expCtx, cmpOp); expr->vpOperand = {exprLeft, exprRight}; return expr; } @@ -828,23 +841,26 @@ Value ExpressionCond::evaluateInternal(Variables* vars) const { return vpOperand[idx]->evaluateInternal(vars); } -intrusive_ptr<Expression> ExpressionCond::parse(BSONElement expr, const VariablesParseState& vps) { +intrusive_ptr<Expression> ExpressionCond::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps) { if (expr.type() != Object) { - return Base::parse(expr, vps); + return Base::parse(expCtx, expr, vps); } verify(str::equals(expr.fieldName(), "$cond")); - intrusive_ptr<ExpressionCond> ret = new ExpressionCond(); + intrusive_ptr<ExpressionCond> ret = new ExpressionCond(expCtx); ret->vpOperand.resize(3); const BSONObj args = expr.embeddedObject(); BSONForEach(arg, args) { if (str::equals(arg.fieldName(), "if")) { - ret->vpOperand[0] = parseOperand(arg, vps); + ret->vpOperand[0] = parseOperand(expCtx, arg, vps); } else if (str::equals(arg.fieldName(), "then")) { - ret->vpOperand[1] = parseOperand(arg, vps); + ret->vpOperand[1] = parseOperand(expCtx, arg, vps); } else if (str::equals(arg.fieldName(), "else")) { - ret->vpOperand[2] = parseOperand(arg, vps); + ret->vpOperand[2] = parseOperand(expCtx, arg, vps); } else { uasserted(17083, str::stream() << "Unrecognized parameter to $cond: " << arg.fieldName()); @@ -865,20 +881,23 @@ const char* ExpressionCond::getOpName() const { /* ---------------------- ExpressionConstant --------------------------- */ -intrusive_ptr<Expression> ExpressionConstant::parse(BSONElement exprElement, - const VariablesParseState& vps) { - return new ExpressionConstant(Value(exprElement)); +intrusive_ptr<Expression> ExpressionConstant::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement exprElement, + const VariablesParseState& vps) { + return new ExpressionConstant(expCtx, Value(exprElement)); } intrusive_ptr<ExpressionConstant> ExpressionConstant::create( const intrusive_ptr<ExpressionContext>& expCtx, const Value& pValue) { - intrusive_ptr<ExpressionConstant> pEC(new ExpressionConstant(pValue)); - pEC->injectExpressionContext(expCtx); + intrusive_ptr<ExpressionConstant> pEC(new ExpressionConstant(expCtx, pValue)); return pEC; } -ExpressionConstant::ExpressionConstant(const Value& pTheValue) : pValue(pTheValue) {} +ExpressionConstant::ExpressionConstant(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const Value& pTheValue) + : Expression(expCtx), pValue(pTheValue) {} intrusive_ptr<Expression> ExpressionConstant::optimize() { @@ -907,8 +926,10 @@ const char* ExpressionConstant::getOpName() const { /* ---------------------- ExpressionDateToString ----------------------- */ REGISTER_EXPRESSION(dateToString, ExpressionDateToString::parse); -intrusive_ptr<Expression> ExpressionDateToString::parse(BSONElement expr, - const VariablesParseState& vps) { +intrusive_ptr<Expression> ExpressionDateToString::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps) { verify(str::equals(expr.fieldName(), "$dateToString")); uassert(18629, "$dateToString only supports an object as its argument", expr.type() == Object); @@ -939,11 +960,14 @@ intrusive_ptr<Expression> ExpressionDateToString::parse(BSONElement expr, validateFormat(format); - return new ExpressionDateToString(format, parseOperand(dateElem, vps)); + return new ExpressionDateToString(expCtx, format, parseOperand(expCtx, dateElem, vps)); } -ExpressionDateToString::ExpressionDateToString(const string& format, intrusive_ptr<Expression> date) - : _format(format), _date(date) {} +ExpressionDateToString::ExpressionDateToString( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const string& format, + intrusive_ptr<Expression> date) + : Expression(expCtx), _format(format), _date(date) {} intrusive_ptr<Expression> ExpressionDateToString::optimize() { _date = _date->optimize(); @@ -1103,10 +1127,6 @@ void ExpressionDateToString::addDependencies(DepsTracker* deps) const { _date->addDependencies(deps); } -void ExpressionDateToString::doInjectExpressionContext() { - _date->injectExpressionContext(getExpressionContext()); -} - /* ---------------------- ExpressionDayOfMonth ------------------------- */ Value ExpressionDayOfMonth::evaluateInternal(Variables* vars) const { @@ -1198,16 +1218,20 @@ const char* ExpressionExp::getOpName() const { /* ---------------------- ExpressionObject --------------------------- */ -ExpressionObject::ExpressionObject(vector<pair<string, intrusive_ptr<Expression>>>&& expressions) - : _expressions(std::move(expressions)) {} +ExpressionObject::ExpressionObject(const boost::intrusive_ptr<ExpressionContext>& expCtx, + vector<pair<string, intrusive_ptr<Expression>>>&& expressions) + : Expression(expCtx), _expressions(std::move(expressions)) {} intrusive_ptr<ExpressionObject> ExpressionObject::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, vector<pair<string, intrusive_ptr<Expression>>>&& expressions) { - return new ExpressionObject(std::move(expressions)); + return new ExpressionObject(expCtx, std::move(expressions)); } -intrusive_ptr<ExpressionObject> ExpressionObject::parse(BSONObj obj, - const VariablesParseState& vps) { +intrusive_ptr<ExpressionObject> ExpressionObject::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps) { // Make sure we don't have any duplicate field names. stdx::unordered_set<string> specifiedFields; @@ -1223,10 +1247,10 @@ intrusive_ptr<ExpressionObject> ExpressionObject::parse(BSONObj obj, << obj.toString(), specifiedFields.find(fieldName) == specifiedFields.end()); specifiedFields.insert(fieldName); - expressions.emplace_back(fieldName, parseOperand(elem, vps)); + expressions.emplace_back(fieldName, parseOperand(expCtx, elem, vps)); } - return new ExpressionObject{std::move(expressions)}; + return new ExpressionObject{expCtx, std::move(expressions)}; } intrusive_ptr<Expression> ExpressionObject::optimize() { @@ -1258,22 +1282,19 @@ Value ExpressionObject::serialize(bool explain) const { return outputDoc.freezeToValue(); } -void ExpressionObject::doInjectExpressionContext() { - for (auto&& pair : _expressions) { - pair.second->injectExpressionContext(getExpressionContext()); - } -} - /* --------------------- ExpressionFieldPath --------------------------- */ // this is the old deprecated version only used by tests not using variables -intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::create(const string& fieldPath) { - return new ExpressionFieldPath("CURRENT." + fieldPath, Variables::ROOT_ID); +intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const string& fieldPath) { + return new ExpressionFieldPath(expCtx, "CURRENT." + fieldPath, Variables::ROOT_ID); } // this is the new version that supports every syntax -intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::parse(const string& raw, - const VariablesParseState& vps) { +intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const string& raw, + const VariablesParseState& vps) { uassert(16873, str::stream() << "FieldPath '" << raw << "' doesn't start with $", raw.c_str()[0] == '$'); // c_str()[0] is always a valid reference. @@ -1287,16 +1308,18 @@ intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::parse(const string& raw, const StringData fieldPath = rawSD.substr(2); // strip off $$ const StringData varName = fieldPath.substr(0, fieldPath.find('.')); Variables::uassertValidNameForUserRead(varName); - return new ExpressionFieldPath(fieldPath.toString(), vps.getVariable(varName)); + return new ExpressionFieldPath(expCtx, fieldPath.toString(), vps.getVariable(varName)); } else { - return new ExpressionFieldPath("CURRENT." + raw.substr(1), // strip the "$" prefix + return new ExpressionFieldPath(expCtx, + "CURRENT." + raw.substr(1), // strip the "$" prefix vps.getVariable("CURRENT")); } } - -ExpressionFieldPath::ExpressionFieldPath(const string& theFieldPath, Variables::Id variable) - : _fieldPath(theFieldPath), _variable(variable) {} +ExpressionFieldPath::ExpressionFieldPath(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const string& theFieldPath, + Variables::Id variable) + : Expression(expCtx), _fieldPath(theFieldPath), _variable(variable) {} intrusive_ptr<Expression> ExpressionFieldPath::optimize() { /* nothing can be done for these */ @@ -1384,8 +1407,10 @@ Value ExpressionFieldPath::serialize(bool explain) const { /* ------------------------- ExpressionFilter ----------------------------- */ REGISTER_EXPRESSION(filter, ExpressionFilter::parse); -intrusive_ptr<Expression> ExpressionFilter::parse(BSONElement expr, - const VariablesParseState& vpsIn) { +intrusive_ptr<Expression> ExpressionFilter::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { verify(str::equals(expr.fieldName(), "$filter")); uassert(28646, "$filter only supports an object as its argument", expr.type() == Object); @@ -1411,7 +1436,7 @@ intrusive_ptr<Expression> ExpressionFilter::parse(BSONElement expr, uassert(28650, "Missing 'cond' parameter to $filter", !condElem.eoo()); // Parse "input", only has outer variables. - intrusive_ptr<Expression> input = parseOperand(inputElem, vpsIn); + intrusive_ptr<Expression> input = parseOperand(expCtx, inputElem, vpsIn); // Parse "as". VariablesParseState vpsSub(vpsIn); // vpsSub gets our variable, vpsIn doesn't. @@ -1423,16 +1448,19 @@ intrusive_ptr<Expression> ExpressionFilter::parse(BSONElement expr, Variables::Id varId = vpsSub.defineVariable(varName); // Parse "cond", has access to "as" variable. - intrusive_ptr<Expression> cond = parseOperand(condElem, vpsSub); + intrusive_ptr<Expression> cond = parseOperand(expCtx, condElem, vpsSub); - return new ExpressionFilter(std::move(varName), varId, std::move(input), std::move(cond)); + return new ExpressionFilter( + expCtx, std::move(varName), varId, std::move(input), std::move(cond)); } -ExpressionFilter::ExpressionFilter(string varName, +ExpressionFilter::ExpressionFilter(const boost::intrusive_ptr<ExpressionContext>& expCtx, + string varName, Variables::Id varId, intrusive_ptr<Expression> input, intrusive_ptr<Expression> filter) - : _varName(std::move(varName)), + : Expression(expCtx), + _varName(std::move(varName)), _varId(varId), _input(std::move(input)), _filter(std::move(filter)) {} @@ -1483,11 +1511,6 @@ void ExpressionFilter::addDependencies(DepsTracker* deps) const { _filter->addDependencies(deps); } -void ExpressionFilter::doInjectExpressionContext() { - _input->injectExpressionContext(getExpressionContext()); - _filter->injectExpressionContext(getExpressionContext()); -} - /* ------------------------- ExpressionFloor -------------------------- */ Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const { @@ -1512,7 +1535,10 @@ const char* ExpressionFloor::getOpName() const { /* ------------------------- ExpressionLet ----------------------------- */ REGISTER_EXPRESSION(let, ExpressionLet::parse); -intrusive_ptr<Expression> ExpressionLet::parse(BSONElement expr, const VariablesParseState& vpsIn) { +intrusive_ptr<Expression> ExpressionLet::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { verify(str::equals(expr.fieldName(), "$let")); uassert(16874, "$let only supports an object as its argument", expr.type() == Object); @@ -1543,17 +1569,20 @@ intrusive_ptr<Expression> ExpressionLet::parse(BSONElement expr, const Variables Variables::uassertValidNameForUserWrite(varName); Variables::Id id = vpsSub.defineVariable(varName); - vars[id] = NameAndExpression(varName, parseOperand(varElem, vpsIn)); // only has outer vars + vars[id] = NameAndExpression(varName, + parseOperand(expCtx, varElem, vpsIn)); // only has outer vars } // parse "in" - intrusive_ptr<Expression> subExpression = parseOperand(inElem, vpsSub); // has our vars + intrusive_ptr<Expression> subExpression = parseOperand(expCtx, inElem, vpsSub); // has our vars - return new ExpressionLet(vars, subExpression); + return new ExpressionLet(expCtx, vars, subExpression); } -ExpressionLet::ExpressionLet(const VariableMap& vars, intrusive_ptr<Expression> subExpression) - : _variables(vars), _subExpression(subExpression) {} +ExpressionLet::ExpressionLet(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const VariableMap& vars, + intrusive_ptr<Expression> subExpression) + : Expression(expCtx), _variables(vars), _subExpression(subExpression) {} intrusive_ptr<Expression> ExpressionLet::optimize() { if (_variables.empty()) { @@ -1603,18 +1632,14 @@ void ExpressionLet::addDependencies(DepsTracker* deps) const { _subExpression->addDependencies(deps); } -void ExpressionLet::doInjectExpressionContext() { - _subExpression->injectExpressionContext(getExpressionContext()); - for (auto&& variable : _variables) { - variable.second.expression->injectExpressionContext(getExpressionContext()); - } -} - /* ------------------------- ExpressionMap ----------------------------- */ REGISTER_EXPRESSION(map, ExpressionMap::parse); -intrusive_ptr<Expression> ExpressionMap::parse(BSONElement expr, const VariablesParseState& vpsIn) { +intrusive_ptr<Expression> ExpressionMap::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { verify(str::equals(expr.fieldName(), "$map")); uassert(16878, "$map only supports an object as its argument", expr.type() == Object); @@ -1641,7 +1666,8 @@ intrusive_ptr<Expression> ExpressionMap::parse(BSONElement expr, const Variables uassert(16882, "Missing 'in' parameter to $map", !inElem.eoo()); // parse "input" - intrusive_ptr<Expression> input = parseOperand(inputElem, vpsIn); // only has outer vars + intrusive_ptr<Expression> input = + parseOperand(expCtx, inputElem, vpsIn); // only has outer vars // parse "as" VariablesParseState vpsSub(vpsIn); // vpsSub gets our vars, vpsIn doesn't. @@ -1653,16 +1679,18 @@ intrusive_ptr<Expression> ExpressionMap::parse(BSONElement expr, const Variables Variables::Id varId = vpsSub.defineVariable(varName); // parse "in" - intrusive_ptr<Expression> in = parseOperand(inElem, vpsSub); // has access to map variable + intrusive_ptr<Expression> in = + parseOperand(expCtx, inElem, vpsSub); // has access to map variable - return new ExpressionMap(varName, varId, input, in); + return new ExpressionMap(expCtx, varName, varId, input, in); } -ExpressionMap::ExpressionMap(const string& varName, +ExpressionMap::ExpressionMap(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const string& varName, Variables::Id varId, intrusive_ptr<Expression> input, intrusive_ptr<Expression> each) - : _varName(varName), _varId(varId), _input(input), _each(each) {} + : Expression(expCtx), _varName(varName), _varId(varId), _input(input), _each(each) {} intrusive_ptr<Expression> ExpressionMap::optimize() { // TODO handle when _input is constant @@ -1711,27 +1739,26 @@ void ExpressionMap::addDependencies(DepsTracker* deps) const { _each->addDependencies(deps); } -void ExpressionMap::doInjectExpressionContext() { - _input->injectExpressionContext(getExpressionContext()); - _each->injectExpressionContext(getExpressionContext()); -} - /* ------------------------- ExpressionMeta ----------------------------- */ REGISTER_EXPRESSION(meta, ExpressionMeta::parse); -intrusive_ptr<Expression> ExpressionMeta::parse(BSONElement expr, - const VariablesParseState& vpsIn) { +intrusive_ptr<Expression> ExpressionMeta::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { uassert(17307, "$meta only supports string arguments", expr.type() == String); if (expr.valueStringData() == "textScore") { - return new ExpressionMeta(MetaType::TEXT_SCORE); + return new ExpressionMeta(expCtx, MetaType::TEXT_SCORE); } else if (expr.valueStringData() == "randVal") { - return new ExpressionMeta(MetaType::RAND_VAL); + return new ExpressionMeta(expCtx, MetaType::RAND_VAL); } else { uasserted(17308, "Unsupported argument to $meta: " + expr.String()); } } -ExpressionMeta::ExpressionMeta(MetaType metaType) : _metaType(metaType) {} +ExpressionMeta::ExpressionMeta(const boost::intrusive_ptr<ExpressionContext>& expCtx, + MetaType metaType) + : Expression(expCtx), _metaType(metaType) {} Value ExpressionMeta::serialize(bool explain) const { switch (_metaType) { @@ -2399,12 +2426,6 @@ Value ExpressionNary::serialize(bool explain) const { return Value(DOC(getOpName() << array)); } -void ExpressionNary::doInjectExpressionContext() { - for (auto&& operand : vpOperand) { - operand->injectExpressionContext(getExpressionContext()); - } -} - /* ------------------------- ExpressionNot ----------------------------- */ Value ExpressionNot::evaluateInternal(Variables* vars) const { @@ -2491,8 +2512,9 @@ const char* ExpressionOr::getOpName() const { /* ----------------------- ExpressionPow ---------------------------- */ -intrusive_ptr<Expression> ExpressionPow::create(Value base, Value exp) { - intrusive_ptr<ExpressionPow> expr(new ExpressionPow()); +intrusive_ptr<Expression> ExpressionPow::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, Value base, Value exp) { + intrusive_ptr<ExpressionPow> expr(new ExpressionPow(expCtx)); expr->vpOperand.push_back( ExpressionConstant::create(expr->getExpressionContext(), std::move(base))); expr->vpOperand.push_back( @@ -2723,14 +2745,16 @@ const char* ExpressionRange::getOpName() const { /* ------------------------ ExpressionReduce ------------------------------ */ REGISTER_EXPRESSION(reduce, ExpressionReduce::parse); -intrusive_ptr<Expression> ExpressionReduce::parse(BSONElement expr, - const VariablesParseState& vps) { +intrusive_ptr<Expression> ExpressionReduce::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps) { uassert(40075, str::stream() << "$reduce requires an object as an argument, found: " << typeName(expr.type()), expr.type() == Object); - intrusive_ptr<ExpressionReduce> reduce(new ExpressionReduce()); + intrusive_ptr<ExpressionReduce> reduce(new ExpressionReduce(expCtx)); // vpsSub is used only to parse 'in', which must have access to $$this and $$value. VariablesParseState vpsSub(vps); @@ -2741,11 +2765,11 @@ intrusive_ptr<Expression> ExpressionReduce::parse(BSONElement expr, auto field = elem.fieldNameStringData(); if (field == "input") { - reduce->_input = parseOperand(elem, vps); + reduce->_input = parseOperand(expCtx, elem, vps); } else if (field == "initialValue") { - reduce->_initial = parseOperand(elem, vps); + reduce->_initial = parseOperand(expCtx, elem, vps); } else if (field == "in") { - reduce->_in = parseOperand(elem, vpsSub); + reduce->_in = parseOperand(expCtx, elem, vpsSub); } else { uasserted(40076, str::stream() << "$reduce found an unknown argument: " << field); } @@ -2802,12 +2826,6 @@ Value ExpressionReduce::serialize(bool explain) const { {"in", _in->serialize(explain)}}}}); } -void ExpressionReduce::doInjectExpressionContext() { - _input->injectExpressionContext(getExpressionContext()); - _initial->injectExpressionContext(getExpressionContext()); - _in->injectExpressionContext(getExpressionContext()); -} - /* ------------------------ ExpressionReverseArray ------------------------ */ Value ExpressionReverseArray::evaluateInternal(Variables* vars) const { @@ -3031,8 +3049,10 @@ Value ExpressionSetIsSubset::evaluateInternal(Variables* vars) const { */ class ExpressionSetIsSubset::Optimized : public ExpressionSetIsSubset { public: - Optimized(const ValueSet& cachedRhsSet, const ExpressionVector& operands) - : _cachedRhsSet(cachedRhsSet) { + Optimized(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const ValueSet& cachedRhsSet, + const ExpressionVector& operands) + : ExpressionSetIsSubset(expCtx), _cachedRhsSet(cachedRhsSet) { vpOperand = operands; } @@ -3068,9 +3088,10 @@ intrusive_ptr<Expression> ExpressionSetIsSubset::optimize() { << typeName(rhs.getType()), rhs.isArray()); - intrusive_ptr<Expression> optimizedWithConstant(new Optimized( - arrayToSet(rhs, getExpressionContext()->getValueComparator()), vpOperand)); - optimizedWithConstant->injectExpressionContext(getExpressionContext()); + intrusive_ptr<Expression> optimizedWithConstant( + new Optimized(this->getExpressionContext(), + arrayToSet(rhs, getExpressionContext()->getValueComparator()), + vpOperand)); return optimizedWithConstant; } return optimized; @@ -3579,14 +3600,16 @@ Value ExpressionSwitch::evaluateInternal(Variables* vars) const { return _default->evaluateInternal(vars); } -boost::intrusive_ptr<Expression> ExpressionSwitch::parse(BSONElement expr, - const VariablesParseState& vps) { +boost::intrusive_ptr<Expression> ExpressionSwitch::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps) { uassert(40060, str::stream() << "$switch requires an object as an argument, found: " << typeName(expr.type()), expr.type() == Object); - intrusive_ptr<ExpressionSwitch> expression(new ExpressionSwitch()); + intrusive_ptr<ExpressionSwitch> expression(new ExpressionSwitch(expCtx)); for (auto&& elem : expr.Obj()) { auto field = elem.fieldNameStringData(); @@ -3610,9 +3633,9 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::parse(BSONElement expr, auto branchField = branchElement.fieldNameStringData(); if (branchField == "case") { - branchExpression.first = parseOperand(branchElement, vps); + branchExpression.first = parseOperand(expCtx, branchElement, vps); } else if (branchField == "then") { - branchExpression.second = parseOperand(branchElement, vps); + branchExpression.second = parseOperand(expCtx, branchElement, vps); } else { uasserted(40063, str::stream() << "$switch found an unknown argument to a branch: " @@ -3631,7 +3654,7 @@ boost::intrusive_ptr<Expression> ExpressionSwitch::parse(BSONElement expr, } } else if (field == "default") { // Optional, arbitrary expression. - expression->_default = parseOperand(elem, vps); + expression->_default = parseOperand(expCtx, elem, vps); } else { uasserted(40067, str::stream() << "$switch found an unknown argument: " << field); } @@ -3686,17 +3709,6 @@ Value ExpressionSwitch::serialize(bool explain) const { return Value(Document{{"$switch", Document{{"branches", Value(serializedBranches)}}}}); } -void ExpressionSwitch::doInjectExpressionContext() { - if (_default) { - _default->injectExpressionContext(getExpressionContext()); - } - - for (auto&& pair : _branches) { - pair.first->injectExpressionContext(getExpressionContext()); - pair.second->injectExpressionContext(getExpressionContext()); - } -} - /* ------------------------- ExpressionToLower ----------------------------- */ Value ExpressionToLower::evaluateInternal(Variables* vars) const { @@ -3956,13 +3968,16 @@ const char* ExpressionYear::getOpName() const { /* -------------------------- ExpressionZip ------------------------------ */ REGISTER_EXPRESSION(zip, ExpressionZip::parse); -intrusive_ptr<Expression> ExpressionZip::parse(BSONElement expr, const VariablesParseState& vps) { +intrusive_ptr<Expression> ExpressionZip::parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps) { uassert(34460, str::stream() << "$zip only supports an object as an argument, found " << typeName(expr.type()), expr.type() == Object); - intrusive_ptr<ExpressionZip> newZip(new ExpressionZip()); + intrusive_ptr<ExpressionZip> newZip(new ExpressionZip(expCtx)); for (auto&& elem : expr.Obj()) { const auto field = elem.fieldNameStringData(); @@ -3972,7 +3987,7 @@ intrusive_ptr<Expression> ExpressionZip::parse(BSONElement expr, const Variables << typeName(elem.type()), elem.type() == Array); for (auto&& subExpr : elem.Array()) { - newZip->_inputs.push_back(parseOperand(subExpr, vps)); + newZip->_inputs.push_back(parseOperand(expCtx, subExpr, vps)); } } else if (field == "defaults") { uassert(34462, @@ -3980,7 +3995,7 @@ intrusive_ptr<Expression> ExpressionZip::parse(BSONElement expr, const Variables << typeName(elem.type()), elem.type() == Array); for (auto&& subExpr : elem.Array()) { - newZip->_defaults.push_back(parseOperand(subExpr, vps)); + newZip->_defaults.push_back(parseOperand(expCtx, subExpr, vps)); } } else if (field == "useLongestLength") { uassert(34463, @@ -4123,14 +4138,4 @@ void ExpressionZip::addDependencies(DepsTracker* deps) const { }); } -void ExpressionZip::doInjectExpressionContext() { - for (auto&& expr : _inputs) { - expr->injectExpressionContext(getExpressionContext()); - } - - for (auto&& expr : _defaults) { - expr->injectExpressionContext(getExpressionContext()); - } -} - } // namespace mongo diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 00e771784c2..ef9a3ae3f19 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -172,8 +172,8 @@ private: class Expression : public IntrusiveCounterUnsigned { public: - using Parser = - stdx::function<boost::intrusive_ptr<Expression>(BSONElement, const VariablesParseState&)>; + using Parser = stdx::function<boost::intrusive_ptr<Expression>( + const boost::intrusive_ptr<ExpressionContext>&, BSONElement, const VariablesParseState&)>; virtual ~Expression(){}; @@ -235,8 +235,10 @@ public: * Calls parseExpression() on any sub-document (including possibly the entire document) which * consists of a single field name starting with a '$'. */ - static boost::intrusive_ptr<Expression> parseObject(BSONObj obj, - const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parseObject( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps); /** * Parses a BSONObj which has already been determined to be a functional expression. @@ -244,8 +246,10 @@ public: * Throws an error if 'obj' does not contain exactly one field, or if that field's name does not * match a registered expression name. */ - static boost::intrusive_ptr<Expression> parseExpression(BSONObj obj, - const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parseExpression( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps); /** * Parses a BSONElement which is an argument to an Expression. @@ -254,8 +258,10 @@ public: * parseObject(), ExpressionFieldPath::parse(), ExpressionArray::parse(), or * ExpressionConstant::parse() as necessary. */ - static boost::intrusive_ptr<Expression> parseOperand(BSONElement exprElement, - const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parseOperand( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement exprElement, + const VariablesParseState& vps); /* Produce a field path std::string with the field prefix removed. @@ -282,24 +288,10 @@ public: */ static void registerExpression(std::string key, Parser parser); - /** - * Injects the ExpressionContext so that it may be used during evaluation of the Expression. - * Construction of expressions is done at parse time, but the ExpressionContext isn't finalized - * until later, at which point it is injected using this method. - */ - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { - _expCtx = expCtx; - doInjectExpressionContext(); - } - protected: - typedef std::vector<boost::intrusive_ptr<Expression>> ExpressionVector; + Expression(const boost::intrusive_ptr<ExpressionContext>& expCtx) : _expCtx(expCtx) {} - /** - * Expressions which need to update their internal state when attaching to a new - * ExpressionContext should override this method. - */ - virtual void doInjectExpressionContext() {} + typedef std::vector<boost::intrusive_ptr<Expression>> ExpressionVector; const boost::intrusive_ptr<ExpressionContext>& getExpressionContext() const { return _expCtx; @@ -344,12 +336,13 @@ public: /// Allow subclasses the opportunity to validate arguments at parse time. virtual void validateArguments(const ExpressionVector& args) const {} - static ExpressionVector parseArguments(BSONElement bsonExpr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static ExpressionVector parseArguments(const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps); protected: - ExpressionNary() {} + explicit ExpressionNary(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Expression(expCtx) {} ExpressionVector vpOperand; }; @@ -358,19 +351,29 @@ protected: template <typename SubClass> class ExpressionNaryBase : public ExpressionNary { public: - static boost::intrusive_ptr<Expression> parse(BSONElement bsonExpr, - const VariablesParseState& vps) { - boost::intrusive_ptr<ExpressionNaryBase> expr = new SubClass(); - ExpressionVector args = parseArguments(bsonExpr, vps); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps) { + boost::intrusive_ptr<ExpressionNaryBase> expr = new SubClass(expCtx); + ExpressionVector args = parseArguments(expCtx, bsonExpr, vps); expr->validateArguments(args); expr->vpOperand = args; return expr; } + +protected: + explicit ExpressionNaryBase(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionNary(expCtx) {} }; /// Inherit from this class if your expression takes a variable number of arguments. template <typename SubClass> -class ExpressionVariadic : public ExpressionNaryBase<SubClass> {}; +class ExpressionVariadic : public ExpressionNaryBase<SubClass> { +public: + explicit ExpressionVariadic(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionNaryBase<SubClass>(expCtx) {} +}; /** * Inherit from this class if your expression can take a range of arguments, e.g. if it has some @@ -379,6 +382,9 @@ class ExpressionVariadic : public ExpressionNaryBase<SubClass> {}; template <typename SubClass, int MinArgs, int MaxArgs> class ExpressionRangedArity : public ExpressionNaryBase<SubClass> { public: + explicit ExpressionRangedArity(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionNaryBase<SubClass>(expCtx) {} + void validateArguments(const Expression::ExpressionVector& args) const override { uassert(28667, mongoutils::str::stream() << "Expression " << this->getOpName() @@ -397,6 +403,9 @@ public: template <typename SubClass, int NArgs> class ExpressionFixedArity : public ExpressionNaryBase<SubClass> { public: + explicit ExpressionFixedArity(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionNaryBase<SubClass>(expCtx) {} + void validateArguments(const Expression::ExpressionVector& args) const override { uassert(16020, mongoutils::str::stream() << "Expression " << this->getOpName() << " takes exactly " @@ -416,9 +425,11 @@ template <typename Accumulator> class ExpressionFromAccumulator : public ExpressionVariadic<ExpressionFromAccumulator<Accumulator>> { public: + explicit ExpressionFromAccumulator(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionFromAccumulator<Accumulator>>(expCtx) {} + Value evaluateInternal(Variables* vars) const final { - Accumulator accum; - accum.injectExpressionContext(this->getExpressionContext()); + Accumulator accum(this->getExpressionContext()); const size_t n = this->vpOperand.size(); // If a single array arg is given, loop through it passing each member to the accumulator. // If a single, non-array arg is given, pass it directly to the accumulator. @@ -446,15 +457,15 @@ public: if (this->vpOperand.size() == 1) { return false; } - return Accumulator().isAssociative(); + return Accumulator(this->getExpressionContext()).isAssociative(); } bool isCommutative() const final { - return Accumulator().isCommutative(); + return Accumulator(this->getExpressionContext()).isCommutative(); } const char* getOpName() const final { - return Accumulator().getOpName(); + return Accumulator(this->getExpressionContext()).getOpName(); } }; @@ -464,6 +475,9 @@ public: template <typename SubClass> class ExpressionSingleNumericArg : public ExpressionFixedArity<SubClass, 1> { public: + explicit ExpressionSingleNumericArg(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<SubClass, 1>(expCtx) {} + virtual ~ExpressionSingleNumericArg() {} Value evaluateInternal(Variables* vars) const final { @@ -484,6 +498,10 @@ public: class ExpressionAbs final : public ExpressionSingleNumericArg<ExpressionAbs> { +public: + explicit ExpressionAbs(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionAbs>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -491,6 +509,9 @@ class ExpressionAbs final : public ExpressionSingleNumericArg<ExpressionAbs> { class ExpressionAdd final : public ExpressionVariadic<ExpressionAdd> { public: + explicit ExpressionAdd(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionAdd>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -506,6 +527,9 @@ public: class ExpressionAllElementsTrue final : public ExpressionFixedArity<ExpressionAllElementsTrue, 1> { public: + explicit ExpressionAllElementsTrue(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionAllElementsTrue, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -513,6 +537,9 @@ public: class ExpressionAnd final : public ExpressionVariadic<ExpressionAnd> { public: + explicit ExpressionAnd(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionAnd>(expCtx) {} + boost::intrusive_ptr<Expression> optimize() final; Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -529,6 +556,9 @@ public: class ExpressionAnyElementTrue final : public ExpressionFixedArity<ExpressionAnyElementTrue, 1> { public: + explicit ExpressionAnyElementTrue(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionAnyElementTrue, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -536,7 +566,9 @@ public: class ExpressionArray final : public ExpressionVariadic<ExpressionArray> { public: - // virtuals from ExpressionNary + explicit ExpressionArray(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionArray>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; Value serialize(bool explain) const final; const char* getOpName() const final; @@ -545,6 +577,9 @@ public: class ExpressionArrayElemAt final : public ExpressionFixedArity<ExpressionArrayElemAt, 2> { public: + explicit ExpressionArrayElemAt(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionArrayElemAt, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -552,6 +587,9 @@ public: class ExpressionCeil final : public ExpressionSingleNumericArg<ExpressionCeil> { public: + explicit ExpressionCeil(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionCeil>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -568,10 +606,9 @@ public: const boost::intrusive_ptr<ExpressionContext>& expCtx, const boost::intrusive_ptr<Expression>& pExpression); - void doInjectExpressionContext() final; - private: - explicit ExpressionCoerceToBool(const boost::intrusive_ptr<Expression>& pExpression); + ExpressionCoerceToBool(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const boost::intrusive_ptr<Expression>& pExpression); boost::intrusive_ptr<Expression> pExpression; }; @@ -593,16 +630,20 @@ public: CMP = 6, // return -1, 0, 1 for a < b, a == b, a > b }; + ExpressionCompare(const boost::intrusive_ptr<ExpressionContext>& expCtx, CmpOp cmpOp) + : ExpressionFixedArity<ExpressionCompare, 2>(expCtx), cmpOp(cmpOp) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; - static boost::intrusive_ptr<Expression> parse(BSONElement bsonExpr, - const VariablesParseState& vps, - CmpOp cmpOp); - - explicit ExpressionCompare(CmpOp cmpOp); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps, + CmpOp cmpOp); static boost::intrusive_ptr<ExpressionCompare> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, CmpOp cmpOp, const boost::intrusive_ptr<Expression>& exprLeft, const boost::intrusive_ptr<Expression>& exprRight); @@ -614,6 +655,9 @@ private: class ExpressionConcat final : public ExpressionVariadic<ExpressionConcat> { public: + explicit ExpressionConcat(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionConcat>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -625,6 +669,9 @@ public: class ExpressionConcatArrays final : public ExpressionVariadic<ExpressionConcatArrays> { public: + explicit ExpressionConcatArrays(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionConcatArrays>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -635,13 +682,19 @@ public: class ExpressionCond final : public ExpressionFixedArity<ExpressionCond, 3> { - typedef ExpressionFixedArity<ExpressionCond, 3> Base; - public: + explicit ExpressionCond(const boost::intrusive_ptr<ExpressionContext>& expCtx) : Base(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); + +private: + typedef ExpressionFixedArity<ExpressionCond, 3> Base; }; @@ -657,8 +710,10 @@ public: static boost::intrusive_ptr<ExpressionConstant> create( const boost::intrusive_ptr<ExpressionContext>& expCtx, const Value& pValue); - static boost::intrusive_ptr<Expression> parse(BSONElement bsonExpr, - const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps); /* Get the constant value represented by this Expression. @@ -670,7 +725,7 @@ public: } private: - explicit ExpressionConstant(const Value& pValue); + ExpressionConstant(const boost::intrusive_ptr<ExpressionContext>& expCtx, const Value& pValue); Value pValue; }; @@ -682,12 +737,14 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: - ExpressionDateToString(const std::string& format, // the format string + ExpressionDateToString(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const std::string& format, // the format string boost::intrusive_ptr<Expression> date); // the date to format // Will uassert on invalid data @@ -705,6 +762,9 @@ private: class ExpressionDayOfMonth final : public ExpressionFixedArity<ExpressionDayOfMonth, 1> { public: + explicit ExpressionDayOfMonth(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionDayOfMonth, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -716,6 +776,9 @@ public: class ExpressionDayOfWeek final : public ExpressionFixedArity<ExpressionDayOfWeek, 1> { public: + explicit ExpressionDayOfWeek(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionDayOfWeek, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -728,6 +791,9 @@ public: class ExpressionDayOfYear final : public ExpressionFixedArity<ExpressionDayOfYear, 1> { public: + explicit ExpressionDayOfYear(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionDayOfYear, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -740,12 +806,19 @@ public: class ExpressionDivide final : public ExpressionFixedArity<ExpressionDivide, 2> { public: + explicit ExpressionDivide(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionDivide, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionExp final : public ExpressionSingleNumericArg<ExpressionExp> { +public: + explicit ExpressionExp(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionExp>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -771,18 +844,23 @@ public: indicator @returns the newly created field path expression */ - static boost::intrusive_ptr<ExpressionFieldPath> create(const std::string& fieldPath); + static boost::intrusive_ptr<ExpressionFieldPath> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const std::string& fieldPath); /// Like create(), but works with the raw std::string from the user with the "$" prefixes. - static boost::intrusive_ptr<ExpressionFieldPath> parse(const std::string& raw, - const VariablesParseState& vps); + static boost::intrusive_ptr<ExpressionFieldPath> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const std::string& raw, + const VariablesParseState& vps); const FieldPath& getFieldPath() const { return _fieldPath; } private: - ExpressionFieldPath(const std::string& fieldPath, Variables::Id variable); + ExpressionFieldPath(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const std::string& fieldPath, + Variables::Id variable); /* Internal implementation of evaluateInternal(), used recursively. @@ -814,12 +892,14 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: - ExpressionFilter(std::string varName, + ExpressionFilter(const boost::intrusive_ptr<ExpressionContext>& expCtx, + std::string varName, Variables::Id varId, boost::intrusive_ptr<Expression> input, boost::intrusive_ptr<Expression> filter); @@ -837,6 +917,9 @@ private: class ExpressionFloor final : public ExpressionSingleNumericArg<ExpressionFloor> { public: + explicit ExpressionFloor(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionFloor>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -844,6 +927,9 @@ public: class ExpressionHour final : public ExpressionFixedArity<ExpressionHour, 1> { public: + explicit ExpressionHour(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionHour, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -855,6 +941,9 @@ public: class ExpressionIfNull final : public ExpressionFixedArity<ExpressionIfNull, 2> { public: + explicit ExpressionIfNull(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIfNull, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -862,6 +951,9 @@ public: class ExpressionIn final : public ExpressionFixedArity<ExpressionIn, 2> { public: + explicit ExpressionIn(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIn, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -869,6 +961,9 @@ public: class ExpressionIndexOfArray final : public ExpressionRangedArity<ExpressionIndexOfArray, 2, 4> { public: + explicit ExpressionIndexOfArray(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionRangedArity<ExpressionIndexOfArray, 2, 4>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -876,6 +971,9 @@ public: class ExpressionIndexOfBytes final : public ExpressionRangedArity<ExpressionIndexOfBytes, 2, 4> { public: + explicit ExpressionIndexOfBytes(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionRangedArity<ExpressionIndexOfBytes, 2, 4>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -886,6 +984,9 @@ public: */ class ExpressionIndexOfCP final : public ExpressionRangedArity<ExpressionIndexOfCP, 2, 4> { public: + explicit ExpressionIndexOfCP(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionRangedArity<ExpressionIndexOfCP, 2, 4>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -898,9 +999,10 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); struct NameAndExpression { NameAndExpression() {} @@ -914,23 +1016,37 @@ public: typedef std::map<Variables::Id, NameAndExpression> VariableMap; private: - ExpressionLet(const VariableMap& vars, boost::intrusive_ptr<Expression> subExpression); + ExpressionLet(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const VariableMap& vars, + boost::intrusive_ptr<Expression> subExpression); VariableMap _variables; boost::intrusive_ptr<Expression> _subExpression; }; class ExpressionLn final : public ExpressionSingleNumericArg<ExpressionLn> { +public: + explicit ExpressionLn(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionLn>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; class ExpressionLog final : public ExpressionFixedArity<ExpressionLog, 2> { +public: + explicit ExpressionLog(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionLog, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionLog10 final : public ExpressionSingleNumericArg<ExpressionLog10> { +public: + explicit ExpressionLog10(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionLog10>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -942,12 +1058,14 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: ExpressionMap( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const std::string& varName, // name of variable to set Variables::Id varId, // id of variable to set boost::intrusive_ptr<Expression> input, // yields array to iterate @@ -965,7 +1083,10 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: enum MetaType { @@ -973,13 +1094,16 @@ private: RAND_VAL, }; - ExpressionMeta(MetaType metaType); + ExpressionMeta(const boost::intrusive_ptr<ExpressionContext>& expCtx, MetaType metaType); MetaType _metaType; }; class ExpressionMillisecond final : public ExpressionFixedArity<ExpressionMillisecond, 1> { public: + explicit ExpressionMillisecond(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionMillisecond, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -989,6 +1113,9 @@ public: class ExpressionMinute final : public ExpressionFixedArity<ExpressionMinute, 1> { public: + explicit ExpressionMinute(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionMinute, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1000,6 +1127,9 @@ public: class ExpressionMod final : public ExpressionFixedArity<ExpressionMod, 2> { public: + explicit ExpressionMod(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionMod, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1007,6 +1137,9 @@ public: class ExpressionMultiply final : public ExpressionVariadic<ExpressionMultiply> { public: + explicit ExpressionMultiply(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionMultiply>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1022,6 +1155,9 @@ public: class ExpressionMonth final : public ExpressionFixedArity<ExpressionMonth, 1> { public: + explicit ExpressionMonth(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionMonth, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1034,6 +1170,9 @@ public: class ExpressionNot final : public ExpressionFixedArity<ExpressionNot, 1> { public: + explicit ExpressionNot(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionNot, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1055,13 +1194,16 @@ public: Value serialize(bool explain) const final; static boost::intrusive_ptr<ExpressionObject> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, std::vector<std::pair<std::string, boost::intrusive_ptr<Expression>>>&& expressions); /** * Parses and constructs an ExpressionObject from 'obj'. */ - static boost::intrusive_ptr<ExpressionObject> parse(BSONObj obj, - const VariablesParseState& vps); + static boost::intrusive_ptr<ExpressionObject> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONObj obj, + const VariablesParseState& vps); /** * This ExpressionObject must outlive the returned vector. @@ -1071,10 +1213,9 @@ public: return _expressions; } - void doInjectExpressionContext() final; - private: ExpressionObject( + const boost::intrusive_ptr<ExpressionContext>& expCtx, std::vector<std::pair<std::string, boost::intrusive_ptr<Expression>>>&& expressions); // The mapping from field name to expression within this object. This needs to respect the order @@ -1085,6 +1226,9 @@ private: class ExpressionOr final : public ExpressionVariadic<ExpressionOr> { public: + explicit ExpressionOr(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionOr>(expCtx) {} + boost::intrusive_ptr<Expression> optimize() final; Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1100,7 +1244,11 @@ public: class ExpressionPow final : public ExpressionFixedArity<ExpressionPow, 2> { public: - static boost::intrusive_ptr<Expression> create(Value base, Value exp); + explicit ExpressionPow(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionPow, 2>(expCtx) {} + + static boost::intrusive_ptr<Expression> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, Value base, Value exp); private: Value evaluateInternal(Variables* vars) const final; @@ -1109,6 +1257,10 @@ private: class ExpressionRange final : public ExpressionRangedArity<ExpressionRange, 2, 3> { +public: + explicit ExpressionRange(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionRangedArity<ExpressionRange, 2, 3>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1116,15 +1268,18 @@ class ExpressionRange final : public ExpressionRangedArity<ExpressionRange, 2, 3 class ExpressionReduce final : public Expression { public: + explicit ExpressionReduce(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Expression(expCtx) {} + void addDependencies(DepsTracker* deps) const final; Value evaluateInternal(Variables* vars) const final; boost::intrusive_ptr<Expression> optimize() final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, - const VariablesParseState& vpsIn); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn); Value serialize(bool explain) const final; - void doInjectExpressionContext() final; - private: boost::intrusive_ptr<Expression> _input; boost::intrusive_ptr<Expression> _initial; @@ -1137,6 +1292,9 @@ private: class ExpressionSecond final : public ExpressionFixedArity<ExpressionSecond, 1> { public: + explicit ExpressionSecond(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSecond, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1148,6 +1306,9 @@ public: class ExpressionSetDifference final : public ExpressionFixedArity<ExpressionSetDifference, 2> { public: + explicit ExpressionSetDifference(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSetDifference, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1155,6 +1316,9 @@ public: class ExpressionSetEquals final : public ExpressionVariadic<ExpressionSetEquals> { public: + explicit ExpressionSetEquals(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionSetEquals>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; void validateArguments(const ExpressionVector& args) const final; @@ -1163,6 +1327,9 @@ public: class ExpressionSetIntersection final : public ExpressionVariadic<ExpressionSetIntersection> { public: + explicit ExpressionSetIntersection(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionSetIntersection>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1179,6 +1346,9 @@ public: // Not final, inherited from for optimizations. class ExpressionSetIsSubset : public ExpressionFixedArity<ExpressionSetIsSubset, 2> { public: + explicit ExpressionSetIsSubset(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSetIsSubset, 2>(expCtx) {} + boost::intrusive_ptr<Expression> optimize() override; Value evaluateInternal(Variables* vars) const override; const char* getOpName() const final; @@ -1190,7 +1360,9 @@ private: class ExpressionSetUnion final : public ExpressionVariadic<ExpressionSetUnion> { public: - // intrusive_ptr<Expression> optimize() final; + explicit ExpressionSetUnion(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionVariadic<ExpressionSetUnion>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1206,6 +1378,9 @@ public: class ExpressionSize final : public ExpressionFixedArity<ExpressionSize, 1> { public: + explicit ExpressionSize(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSize, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1213,6 +1388,9 @@ public: class ExpressionReverseArray final : public ExpressionFixedArity<ExpressionReverseArray, 1> { public: + explicit ExpressionReverseArray(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionReverseArray, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1220,6 +1398,9 @@ public: class ExpressionSlice final : public ExpressionRangedArity<ExpressionSlice, 2, 3> { public: + explicit ExpressionSlice(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionRangedArity<ExpressionSlice, 2, 3>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1227,6 +1408,9 @@ public: class ExpressionIsArray final : public ExpressionFixedArity<ExpressionIsArray, 1> { public: + explicit ExpressionIsArray(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIsArray, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1234,12 +1418,19 @@ public: class ExpressionSplit final : public ExpressionFixedArity<ExpressionSplit, 2> { public: + explicit ExpressionSplit(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSplit, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionSqrt final : public ExpressionSingleNumericArg<ExpressionSqrt> { +public: + explicit ExpressionSqrt(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionSqrt>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -1247,6 +1438,9 @@ class ExpressionSqrt final : public ExpressionSingleNumericArg<ExpressionSqrt> { class ExpressionStrcasecmp final : public ExpressionFixedArity<ExpressionStrcasecmp, 2> { public: + explicit ExpressionStrcasecmp(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionStrcasecmp, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1254,6 +1448,9 @@ public: class ExpressionSubstrBytes : public ExpressionFixedArity<ExpressionSubstrBytes, 3> { public: + explicit ExpressionSubstrBytes(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSubstrBytes, 3>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const; }; @@ -1261,18 +1458,29 @@ public: class ExpressionSubstrCP final : public ExpressionFixedArity<ExpressionSubstrCP, 3> { public: + explicit ExpressionSubstrCP(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSubstrCP, 3>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionStrLenBytes final : public ExpressionFixedArity<ExpressionStrLenBytes, 1> { +public: + explicit ExpressionStrLenBytes(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionStrLenBytes, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionStrLenCP final : public ExpressionFixedArity<ExpressionStrLenCP, 1> { +public: + explicit ExpressionStrLenCP(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionStrLenCP, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1280,6 +1488,9 @@ class ExpressionStrLenCP final : public ExpressionFixedArity<ExpressionStrLenCP, class ExpressionSubtract final : public ExpressionFixedArity<ExpressionSubtract, 2> { public: + explicit ExpressionSubtract(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionSubtract, 2>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1287,15 +1498,18 @@ public: class ExpressionSwitch final : public Expression { public: + explicit ExpressionSwitch(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Expression(expCtx) {} + void addDependencies(DepsTracker* deps) const final; Value evaluateInternal(Variables* vars) const final; boost::intrusive_ptr<Expression> optimize() final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, - const VariablesParseState& vpsIn); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn); Value serialize(bool explain) const final; - void doInjectExpressionContext() final; - private: using ExpressionPair = std::pair<boost::intrusive_ptr<Expression>, boost::intrusive_ptr<Expression>>; @@ -1307,6 +1521,9 @@ private: class ExpressionToLower final : public ExpressionFixedArity<ExpressionToLower, 1> { public: + explicit ExpressionToLower(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionToLower, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1314,6 +1531,9 @@ public: class ExpressionToUpper final : public ExpressionFixedArity<ExpressionToUpper, 1> { public: + explicit ExpressionToUpper(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionToUpper, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1321,6 +1541,9 @@ public: class ExpressionTrunc final : public ExpressionSingleNumericArg<ExpressionTrunc> { public: + explicit ExpressionTrunc(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionSingleNumericArg<ExpressionTrunc>(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -1328,6 +1551,9 @@ public: class ExpressionType final : public ExpressionFixedArity<ExpressionType, 1> { public: + explicit ExpressionType(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionType, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1335,6 +1561,9 @@ public: class ExpressionWeek final : public ExpressionFixedArity<ExpressionWeek, 1> { public: + explicit ExpressionWeek(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionWeek, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1344,6 +1573,9 @@ public: class ExpressionIsoWeekYear final : public ExpressionFixedArity<ExpressionIsoWeekYear, 1> { public: + explicit ExpressionIsoWeekYear(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIsoWeekYear, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1353,6 +1585,9 @@ public: class ExpressionIsoDayOfWeek final : public ExpressionFixedArity<ExpressionIsoDayOfWeek, 1> { public: + explicit ExpressionIsoDayOfWeek(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIsoDayOfWeek, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1362,6 +1597,9 @@ public: class ExpressionIsoWeek final : public ExpressionFixedArity<ExpressionIsoWeek, 1> { public: + explicit ExpressionIsoWeek(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionIsoWeek, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1371,6 +1609,9 @@ public: class ExpressionYear final : public ExpressionFixedArity<ExpressionYear, 1> { public: + explicit ExpressionYear(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : ExpressionFixedArity<ExpressionYear, 1>(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1383,15 +1624,18 @@ public: class ExpressionZip final : public Expression { public: + explicit ExpressionZip(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : Expression(expCtx) {} + void addDependencies(DepsTracker* deps) const final; Value evaluateInternal(Variables* vars) const final; boost::intrusive_ptr<Expression> optimize() final; - static boost::intrusive_ptr<Expression> parse(BSONElement expr, - const VariablesParseState& vpsIn); + static boost::intrusive_ptr<Expression> parse( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn); Value serialize(bool explain) const final; - void doInjectExpressionContext() final; - private: bool _useLongestLength = false; ExpressionVector _inputs; diff --git a/src/mongo/db/pipeline/expression_context.cpp b/src/mongo/db/pipeline/expression_context.cpp index 865bf08ec6c..5bccef201ae 100644 --- a/src/mongo/db/pipeline/expression_context.cpp +++ b/src/mongo/db/pipeline/expression_context.cpp @@ -39,27 +39,27 @@ ExpressionContext::ResolvedNamespace::ResolvedNamespace(NamespaceString ns, std::vector<BSONObj> pipeline) : ns(std::move(ns)), pipeline(std::move(pipeline)) {} -ExpressionContext::ExpressionContext(OperationContext* opCtx, const AggregationRequest& request) +ExpressionContext::ExpressionContext(OperationContext* opCtx, + const AggregationRequest& request, + std::unique_ptr<CollatorInterface> collator, + StringMap<ResolvedNamespace> resolvedNamespaces) : isExplain(request.isExplain()), inShard(request.isFromRouter()), extSortAllowed(request.shouldAllowDiskUse()), bypassDocumentValidation(request.shouldBypassDocumentValidation()), ns(request.getNamespaceString()), opCtx(opCtx), - collation(request.getCollation()) { - if (!collation.isEmpty()) { - auto statusWithCollator = - CollatorFactoryInterface::get(opCtx->getServiceContext())->makeFromBSON(collation); - uassertStatusOK(statusWithCollator.getStatus()); - setCollator(std::move(statusWithCollator.getValue())); - } -} + collation(request.getCollation()), + _collator(std::move(collator)), + _documentComparator(_collator.get()), + _valueComparator(_collator.get()), + _resolvedNamespaces(std::move(resolvedNamespaces)) {} void ExpressionContext::checkForInterrupt() { // This check could be expensive, at least in relative terms, so don't check every time. - if (--interruptCounter == 0) { + if (--_interruptCounter == 0) { opCtx->checkForInterrupt(); - interruptCounter = kInterruptCheckPeriod; + _interruptCounter = kInterruptCheckPeriod; } } @@ -90,9 +90,9 @@ intrusive_ptr<ExpressionContext> ExpressionContext::copyWith(NamespaceString ns) expCtx->setCollator(_collator->clone()); } - expCtx->resolvedNamespaces = resolvedNamespaces; + expCtx->_resolvedNamespaces = _resolvedNamespaces; - // Note that we intentionally skip copying the value of 'interruptCounter' because 'expCtx' is + // Note that we intentionally skip copying the value of '_interruptCounter' because 'expCtx' is // intended to be used for executing a separate aggregation pipeline. return expCtx; diff --git a/src/mongo/db/pipeline/expression_context.h b/src/mongo/db/pipeline/expression_context.h index 67114fa52b0..af6d7d199c9 100644 --- a/src/mongo/db/pipeline/expression_context.h +++ b/src/mongo/db/pipeline/expression_context.h @@ -45,7 +45,7 @@ namespace mongo { -struct ExpressionContext : public IntrusiveCounterUnsigned { +class ExpressionContext : public RefCountable { public: struct ResolvedNamespace { ResolvedNamespace() = default; @@ -55,9 +55,14 @@ public: std::vector<BSONObj> pipeline; }; - ExpressionContext() = default; - - ExpressionContext(OperationContext* opCtx, const AggregationRequest& request); + /** + * Constructs an ExpressionContext to be used for Pipeline parsing and evaluation. + * 'resolvedNamespaces' maps collection names (not full namespaces) to ResolvedNamespaces. + */ + ExpressionContext(OperationContext* opCtx, + const AggregationRequest& request, + std::unique_ptr<CollatorInterface> collator, + StringMap<ExpressionContext::ResolvedNamespace> resolvedNamespaces); /** * Used by a pipeline to check for interrupts so that killOp() works. Throws a UserAssertion if @@ -65,8 +70,6 @@ public: */ void checkForInterrupt(); - void setCollator(std::unique_ptr<CollatorInterface> coll); - const CollatorInterface* getCollator() const { return _collator.get(); } @@ -85,6 +88,16 @@ public: */ boost::intrusive_ptr<ExpressionContext> copyWith(NamespaceString ns) const; + /** + * Returns the ResolvedNamespace corresponding to 'nss'. It is an error to call this method on a + * namespace not involved in the pipeline. + */ + const ResolvedNamespace& getResolvedNamespace(const NamespaceString& nss) const { + auto it = _resolvedNamespaces.find(nss.coll()); + invariant(it != _resolvedNamespaces.end()); + return it->second; + }; + bool isExplain = false; bool inShard = false; bool inRouter = false; @@ -100,19 +113,34 @@ public: // collation. BSONObj collation; - StringMap<ResolvedNamespace> resolvedNamespaces; - +protected: static const int kInterruptCheckPeriod = 128; - int interruptCounter = kInterruptCheckPeriod; // when 0, check interruptStatus -private: - // Collator used to compare elements. 'collator' is initialized from 'collation', except in the - // case where 'collation' is empty and there is a collection default collation. + /** + * Should only be used by 'ExpressionContextForTest'. + */ + ExpressionContext() = default; + + /** + * Sets '_collator' and resets '_documentComparator' and '_valueComparator'. + * + * Use with caution - it is illegal to change the collation once a Pipeline has been parsed with + * this ExpressionContext. + */ + void setCollator(std::unique_ptr<CollatorInterface> collator); + + // Collator used for comparisons. std::unique_ptr<CollatorInterface> _collator; // Used for all comparisons of Document/Value during execution of the aggregation operation. + // Must not be changed after parsing a Pipeline with this ExpressionContext. DocumentComparator _documentComparator; ValueComparator _valueComparator; + + // A map from namespace to the resolved namespace, in case any views are involved. + StringMap<ResolvedNamespace> _resolvedNamespaces; + + int _interruptCounter = kInterruptCheckPeriod; }; } // namespace mongo diff --git a/src/mongo/db/pipeline/expression_context_for_test.h b/src/mongo/db/pipeline/expression_context_for_test.h new file mode 100644 index 00000000000..d093cfe2bfa --- /dev/null +++ b/src/mongo/db/pipeline/expression_context_for_test.h @@ -0,0 +1,64 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects + * for all of the code used other than as permitted herein. If you modify + * file(s) with this exception, you may extend this exception to your + * version of the file(s), but you are not obligated to do so. If you do not + * wish to do so, delete this exception statement from your version. If you + * delete this exception statement from all source files in the program, + * then also delete it in the license file. + */ + +#pragma once + +#include "mongo/db/pipeline/expression_context.h" + +namespace mongo { + +/** + * An ExpressionContext that can have state like the collation and resolved namespace map + * manipulated after construction. In contrast, a regular ExpressionContext requires the collation + * and resolved namespaces to be provided on construction and does not allow them to be subsequently + * mutated. + */ +class ExpressionContextForTest : public ExpressionContext { +public: + ExpressionContextForTest() = default; + + ExpressionContextForTest(OperationContext* txn, const AggregationRequest& request) + : ExpressionContext(txn, request, nullptr, {}) {} + + /** + * Changes the collation used by this ExpressionContext. Must not be changed after parsing a + * Pipeline with this ExpressionContext. + */ + void setCollator(std::unique_ptr<CollatorInterface> collator) { + ExpressionContext::setCollator(std::move(collator)); + } + + /** + * Sets the resolved definition for an involved namespace. + */ + void setResolvedNamespace(const NamespaceString& nss, ResolvedNamespace resolvedNamespace) { + _resolvedNamespaces[nss.coll()] = std::move(resolvedNamespace); + } +}; + +} // namespace mongo diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 13e96dfef0c..315a2faeac7 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -36,7 +36,7 @@ #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/expression.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value_comparator.h" #include "mongo/dbtests/dbtests.h" #include "mongo/unittest/unittest.h" @@ -62,12 +62,11 @@ static void assertExpectedResults(string expression, initializer_list<pair<vector<Value>, Value>> operations) { for (auto&& op : operations) { try { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const BSONObj obj = BSON(expression << Value(op.first)); - auto expression = Expression::parseExpression(obj, vps); - expression->injectExpressionContext(expCtx); + auto expression = Expression::parseExpression(expCtx, obj, vps); Value result = expression->evaluate(Document()); ASSERT_VALUE_EQ(op.second, result); ASSERT_EQUALS(op.second.getType(), result.getType()); @@ -190,7 +189,10 @@ public: private: Testable(bool isAssociative, bool isCommutative) - : _isAssociative(isAssociative), _isCommutative(isCommutative) {} + : ExpressionNary( + boost::intrusive_ptr<ExpressionContextForTest>(new ExpressionContextForTest())), + _isAssociative(isAssociative), + _isCommutative(isCommutative) {} bool _isAssociative; bool _isCommutative; }; @@ -224,12 +226,13 @@ protected: } void addOperandArrayToExpr(const intrusive_ptr<Testable>& expr, const BSONArray& operands) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); BSONObjIterator i(operands); while (i.more()) { BSONElement element = i.next(); - expr->addOperand(Expression::parseOperand(element, vps)); + expr->addOperand(Expression::parseOperand(expCtx, element, vps)); } } @@ -244,7 +247,7 @@ TEST_F(ExpressionNaryTest, AddedConstantOperandIsSerialized) { } TEST_F(ExpressionNaryTest, AddedFieldPathOperandIsSerialized) { - _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create("ab.c")); + _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create(nullptr, "ab.c")); assertContents(_notAssociativeNorCommutative, BSON_ARRAY("$ab.c")); } @@ -258,7 +261,7 @@ TEST_F(ExpressionNaryTest, ValidateConstantExpressionDependency) { } TEST_F(ExpressionNaryTest, ValidateFieldPathExpressionDependency) { - _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create("ab.c")); + _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create(nullptr, "ab.c")); assertDependencies(_notAssociativeNorCommutative, BSON_ARRAY("ab.c")); } @@ -267,10 +270,12 @@ TEST_F(ExpressionNaryTest, ValidateObjectExpressionDependency) { << "$x" << "q" << "$r")); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONElement specElement = spec.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - _notAssociativeNorCommutative->addOperand(Expression::parseObject(specElement.Obj(), vps)); + _notAssociativeNorCommutative->addOperand( + Expression::parseObject(expCtx, specElement.Obj(), vps)); assertDependencies(_notAssociativeNorCommutative, BSON_ARRAY("r" << "x")); @@ -683,7 +688,8 @@ TEST_F(ExpressionNaryTest, FlattenInnerOperandsOptimizationOnCommutativeAndAssoc class ExpressionCeilTest : public ExpressionNaryTestOneArg { public: virtual void assertEvaluates(Value input, Value output) override { - _expr = new ExpressionCeil(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + _expr = new ExpressionCeil(expCtx); ExpressionNaryTestOneArg::assertEvaluates(input, output); } }; @@ -741,7 +747,8 @@ TEST_F(ExpressionCeilTest, NullArg) { class ExpressionFloorTest : public ExpressionNaryTestOneArg { public: virtual void assertEvaluates(Value input, Value output) override { - _expr = new ExpressionFloor(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + _expr = new ExpressionFloor(expCtx); ExpressionNaryTestOneArg::assertEvaluates(input, output); } }; @@ -856,7 +863,8 @@ TEST(ExpressionReverseArrayTest, ReturnsNullWithNullishInput) { class ExpressionTruncTest : public ExpressionNaryTestOneArg { public: virtual void assertEvaluates(Value input, Value output) override { - _expr = new ExpressionTrunc(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + _expr = new ExpressionTrunc(expCtx); ExpressionNaryTestOneArg::assertEvaluates(input, output); } }; @@ -917,7 +925,8 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(expCtx); populateOperands(expression); ASSERT_BSONOBJ_EQ(expectedResult(), toBson(expression->evaluate(Document()))); } @@ -932,7 +941,8 @@ protected: class NullDocument { public: void run() { - intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(expCtx); expression->addOperand(ExpressionConstant::create(nullptr, Value(2))); ASSERT_BSONOBJ_EQ(BSON("" << 2), toBson(expression->evaluate(Document()))); } @@ -950,7 +960,8 @@ class NoOperands : public ExpectedResultBase { class String { public: void run() { - intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(expCtx); expression->addOperand(ExpressionConstant::create(nullptr, Value("a"_sd))); ASSERT_THROWS(expression->evaluate(Document()), UserException); } @@ -960,7 +971,8 @@ public: class Bool { public: void run() { - intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(expCtx); expression->addOperand(ExpressionConstant::create(nullptr, Value(true))); ASSERT_THROWS(expression->evaluate(Document()), UserException); } @@ -1193,13 +1205,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(fromBson(BSON("a" << 1))))); @@ -1217,13 +1228,12 @@ class OptimizeBase { public: virtual ~OptimizeBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); intrusive_ptr<Expression> optimized = expression->optimize(); ASSERT_BSONOBJ_EQ(expectedOptimized(), expressionToBson(optimized)); @@ -1515,7 +1525,7 @@ public: class Dependencies { public: void run() { - intrusive_ptr<Expression> nested = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> nested = ExpressionFieldPath::create(nullptr, "a.b"); intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nullptr, nested); DepsTracker dependencies; expression->addDependencies(&dependencies); @@ -1531,7 +1541,7 @@ class AddToBsonObj { public: void run() { intrusive_ptr<Expression> expression = - ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create("foo")); + ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create(nullptr, "foo")); // serialized as $and because CoerceToBool isn't an ExpressionNary assertBinaryEqual(fromjson("{field:{$and:['$foo']}}"), toBsonObj(expression)); @@ -1548,7 +1558,7 @@ class AddToBsonArray { public: void run() { intrusive_ptr<Expression> expression = - ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create("foo")); + ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create(nullptr, "foo")); // serialized as $and because CoerceToBool isn't an ExpressionNary assertBinaryEqual(BSON_ARRAY(fromjson("{$and:['$foo']}")), toBsonArray(expression)); @@ -1577,9 +1587,8 @@ public: BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); intrusive_ptr<Expression> optimized = expression->optimize(); ASSERT_BSONOBJ_EQ(constify(expectedOptimized()), expressionToBson(optimized)); } @@ -1610,9 +1619,8 @@ public: BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); // Check expression spec round trip. ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); // Check evaluation result. @@ -1648,11 +1656,12 @@ class ParseError { public: virtual ~ParseError() {} void run() { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - ASSERT_THROWS(Expression::parseOperand(specElement, vps), UserException); + ASSERT_THROWS(Expression::parseOperand(expCtx, specElement, vps), UserException); } protected: @@ -1855,9 +1864,8 @@ public: BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_VALUE_EQ(expression->evaluate(Document()), Value(true)); } }; @@ -2000,10 +2008,11 @@ public: void run() { BSONObj spec = BSON("IGNORED_FIELD_NAME" << "foo"); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONElement specElement = spec.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = ExpressionConstant::parse(specElement, vps); + intrusive_ptr<Expression> expression = ExpressionConstant::parse(expCtx, specElement, vps); assertBinaryEqual(BSON("" << "foo"), toBson(expression->evaluate(Document()))); @@ -2140,7 +2149,7 @@ namespace FieldPath { class Invalid { public: void run() { - ASSERT_THROWS(ExpressionFieldPath::create(""), UserException); + ASSERT_THROWS(ExpressionFieldPath::create(nullptr, ""), UserException); } }; @@ -2148,7 +2157,7 @@ public: class Optimize { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a"); // An attempt to optimize returns the Expression itself. ASSERT_EQUALS(expression, expression->optimize()); } @@ -2158,7 +2167,7 @@ public: class Dependencies { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); DepsTracker dependencies; expression->addDependencies(&dependencies); ASSERT_EQUALS(1U, dependencies.fields.size()); @@ -2172,7 +2181,7 @@ public: class Missing { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(Document()))); } }; @@ -2181,7 +2190,7 @@ public: class Present { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a"); assertBinaryEqual(fromjson("{'':123}"), toBson(expression->evaluate(fromBson(BSON("a" << 123))))); } @@ -2191,7 +2200,7 @@ public: class NestedBelowNull { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(fromjson("{a:null}"))))); } @@ -2201,7 +2210,7 @@ public: class NestedBelowUndefined { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(fromjson("{a:undefined}"))))); } @@ -2211,7 +2220,7 @@ public: class NestedBelowMissing { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(fromjson("{z:1}"))))); } @@ -2221,7 +2230,7 @@ public: class NestedBelowInt { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(BSON("a" << 2))))); } }; @@ -2230,7 +2239,7 @@ public: class NestedValue { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(BSON("" << 55), toBson(expression->evaluate(fromBson(BSON("a" << BSON("b" << 55)))))); } @@ -2240,7 +2249,7 @@ public: class NestedBelowEmptyObject { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(BSON("a" << BSONObj()))))); } @@ -2250,7 +2259,7 @@ public: class NestedBelowEmptyArray { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(BSON("" << BSONArray()), toBson(expression->evaluate(fromBson(BSON("a" << BSONArray()))))); } @@ -2260,7 +2269,7 @@ public: class NestedBelowArrayWithNull { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[null]}"))))); } @@ -2270,7 +2279,7 @@ public: class NestedBelowArrayWithUndefined { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[undefined]}"))))); } @@ -2280,7 +2289,7 @@ public: class NestedBelowArrayWithInt { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[1]}"))))); } @@ -2290,7 +2299,7 @@ public: class NestedWithinArray { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[9]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[{b:9}]}"))))); } @@ -2300,7 +2309,7 @@ public: class MultipleArrayValues { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[9,20]}"), toBson(expression->evaluate( fromBson(fromjson("{a:[{b:9},null,undefined,{g:4},{b:20},{}]}"))))); @@ -2311,7 +2320,7 @@ public: class ExpandNestedArrays { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b.c"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b.c"); assertBinaryEqual(fromjson("{'':[[1,2],3,[4],[[5]],[6,7]]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[{b:[{c:1},{c:2}]}," "{b:{c:3}}," @@ -2325,7 +2334,7 @@ public: class AddToBsonObj { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b.c"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b.c"); assertBinaryEqual(BSON("foo" << "$a.b.c"), BSON("foo" << expression->serialize(false))); @@ -2336,7 +2345,7 @@ public: class AddToBsonArray { public: void run() { - intrusive_ptr<Expression> expression = ExpressionFieldPath::create("a.b.c"); + intrusive_ptr<Expression> expression = ExpressionFieldPath::create(nullptr, "a.b.c"); BSONArrayBuilder bab; bab << expression->serialize(false); assertBinaryEqual(BSON_ARRAY("$a.b.c"), bab.arr()); @@ -2358,16 +2367,19 @@ Document literal(T&& value) { // TEST(ExpressionObjectParse, ShouldAcceptEmptyObject) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSONObj(), vps); + auto object = ExpressionObject::parse(expCtx, BSONObj(), vps); ASSERT_VALUE_EQ(Value(Document{}), object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptLiteralsAsValues) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("a" << 5 << "b" + auto object = ExpressionObject::parse(expCtx, + BSON("a" << 5 << "b" << "string" << "c" << BSONNULL), @@ -2378,25 +2390,29 @@ TEST(ExpressionObjectParse, ShouldAcceptLiteralsAsValues) { } TEST(ExpressionObjectParse, ShouldAccept_idAsFieldName) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("_id" << 5), vps); + auto object = ExpressionObject::parse(expCtx, BSON("_id" << 5), vps); auto expectedResult = Value(Document{{"_id", literal(5)}}); ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptFieldNameContainingDollar) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("a$b" << 5), vps); + auto object = ExpressionObject::parse(expCtx, BSON("a$b" << 5), vps); auto expectedResult = Value(Document{{"a$b", literal(5)}}); ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptNestedObjects) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(fromjson("{a: {b: 1}, c: {d: {e: 1, f: 1}}}"), vps); + auto object = + ExpressionObject::parse(expCtx, fromjson("{a: {b: 1}, c: {d: {e: 1, f: 1}}}"), vps); auto expectedResult = Value(Document{{"a", Document{{"b", literal(1)}}}, {"c", Document{{"d", Document{{"e", literal(1)}, {"f", literal(1)}}}}}}); @@ -2404,18 +2420,20 @@ TEST(ExpressionObjectParse, ShouldAcceptNestedObjects) { } TEST(ExpressionObjectParse, ShouldAcceptArrays) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(fromjson("{a: [1, 2]}"), vps); + auto object = ExpressionObject::parse(expCtx, fromjson("{a: [1, 2]}"), vps); auto expectedResult = Value(Document{{"a", vector<Value>{Value(literal(1)), Value(literal(2))}}}); ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ObjectParsing, ShouldAcceptExpressionAsValue) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("a" << BSON("$and" << BSONArray())), vps); + auto object = ExpressionObject::parse(expCtx, BSON("a" << BSON("$and" << BSONArray())), vps); ASSERT_VALUE_EQ(object->serialize(false), Value(Document{{"a", Document{{"$and", BSONArray()}}}})); } @@ -2425,48 +2443,60 @@ TEST(ObjectParsing, ShouldAcceptExpressionAsValue) { // TEST(ExpressionObjectParse, ShouldRejectDottedFieldNames) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("a.b" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("c" << 3 << "a.b" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a.b" << 1 << "c" << 3), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a.b" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("c" << 3 << "a.b" << 1), vps), + UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a.b" << 1 << "c" << 3), vps), + UserException); } TEST(ExpressionObjectParse, ShouldRejectDuplicateFieldNames) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << 1 << "a" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << 1 << "b" << 2 << "a" << 1), vps), - UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << BSON("c" << 1) << "b" << 2 << "a" << 1), vps), - UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << 1 << "b" << 2 << "a" << BSON("c" << 1)), vps), + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a" << 1 << "a" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a" << 1 << "b" << 2 << "a" << 1), vps), UserException); + ASSERT_THROWS( + ExpressionObject::parse(expCtx, BSON("a" << BSON("c" << 1) << "b" << 2 << "a" << 1), vps), + UserException); + ASSERT_THROWS( + ExpressionObject::parse(expCtx, BSON("a" << 1 << "b" << 2 << "a" << BSON("c" << 1)), vps), + UserException); } TEST(ExpressionObjectParse, ShouldRejectInvalidFieldName) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("$a" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON(std::string("a\0b", 3) << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("$a" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON(std::string("a\0b", 3) << 1), vps), + UserException); } TEST(ExpressionObjectParse, ShouldRejectInvalidFieldPathAsValue) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" + ASSERT_THROWS(ExpressionObject::parse(expCtx, + BSON("a" << "$field."), vps), UserException); } TEST(ParseObject, ShouldRejectExpressionAsTheSecondField) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse( - BSON("a" << BSON("$and" << BSONArray()) << "$or" << BSONArray()), vps), - UserException); + ASSERT_THROWS( + ExpressionObject::parse( + expCtx, BSON("a" << BSON("$and" << BSONArray()) << "$or" << BSONArray()), vps), + UserException); } // @@ -2474,14 +2504,17 @@ TEST(ParseObject, ShouldRejectExpressionAsTheSecondField) { // TEST(ExpressionObjectEvaluate, EmptyObjectShouldEvaluateToEmptyDocument) { - auto object = ExpressionObject::create({}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, {}); ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document{{"a", 1}})); ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document{{"_id", "ID"_sd}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateEachField) { - auto object = ExpressionObject::create({{"a", makeConstant(1)}, {"b", makeConstant(5)}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = + ExpressionObject::create(expCtx, {{"a", makeConstant(1)}, {"b", makeConstant(5)}}); ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"a", 1}})); ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), @@ -2489,36 +2522,45 @@ TEST(ExpressionObjectEvaluate, ShouldEvaluateEachField) { } TEST(ExpressionObjectEvaluate, OrderOfFieldsInOutputShouldMatchOrderInSpecification) { - auto object = ExpressionObject::create({{"a", ExpressionFieldPath::create("a")}, - {"b", ExpressionFieldPath::create("b")}, - {"c", ExpressionFieldPath::create("c")}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, + {{"a", ExpressionFieldPath::create(expCtx, "a")}, + {"b", ExpressionFieldPath::create(expCtx, "b")}, + {"c", ExpressionFieldPath::create(expCtx, "c")}}); ASSERT_VALUE_EQ( Value(Document{{"a", "A"_sd}, {"b", "B"_sd}, {"c", "C"_sd}}), object->evaluate(Document{{"c", "C"_sd}, {"a", "A"_sd}, {"b", "B"_sd}, {"_id", "ID"_sd}})); } TEST(ExpressionObjectEvaluate, ShouldRemoveFieldsThatHaveMissingValues) { - auto object = ExpressionObject::create( - {{"a", ExpressionFieldPath::create("a.b")}, {"b", ExpressionFieldPath::create("missing")}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, + {{"a", ExpressionFieldPath::create(expCtx, "a.b")}, + {"b", ExpressionFieldPath::create(expCtx, "missing")}}); ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document{{"a", 1}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateFieldsWithinNestedObject) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); auto object = ExpressionObject::create( + expCtx, {{"a", ExpressionObject::create( - {{"b", makeConstant(1)}, {"c", ExpressionFieldPath::create("_id")}})}}); + expCtx, + {{"b", makeConstant(1)}, {"c", ExpressionFieldPath::create(expCtx, "_id")}})}}); ASSERT_VALUE_EQ(Value(Document{{"a", Document{{"b", 1}}}}), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document{{"a", Document{{"b", 1}, {"c", "ID"_sd}}}}), object->evaluate(Document{{"_id", "ID"_sd}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateToEmptyDocumentIfAllFieldsAreMissing) { - auto object = ExpressionObject::create({{"a", ExpressionFieldPath::create("missing")}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = + ExpressionObject::create(expCtx, {{"a", ExpressionFieldPath::create(expCtx, "missing")}}); ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document())); - auto objectWithNestedObject = ExpressionObject::create({{"nested", object}}); + auto objectWithNestedObject = ExpressionObject::create(expCtx, {{"nested", object}}); ASSERT_VALUE_EQ(Value(Document{{"nested", Document{}}}), objectWithNestedObject->evaluate(Document())); } @@ -2528,14 +2570,17 @@ TEST(ExpressionObjectEvaluate, ShouldEvaluateToEmptyDocumentIfAllFieldsAreMissin // TEST(ExpressionObjectDependencies, ConstantValuesShouldNotBeAddedToDependencies) { - auto object = ExpressionObject::create({{"a", makeConstant(5)}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, {{"a", makeConstant(5)}}); DepsTracker deps; object->addDependencies(&deps); ASSERT_EQ(deps.fields.size(), 0UL); } TEST(ExpressionObjectDependencies, FieldPathsShouldBeAddedToDependencies) { - auto object = ExpressionObject::create({{"x", ExpressionFieldPath::create("c.d")}}); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto object = + ExpressionObject::create(expCtx, {{"x", ExpressionFieldPath::create(expCtx, "c.d")}}); DepsTracker deps; object->addDependencies(&deps); ASSERT_EQ(deps.fields.size(), 1UL); @@ -2548,11 +2593,12 @@ TEST(ExpressionObjectDependencies, FieldPathsShouldBeAddedToDependencies) { TEST(ExpressionObjectOptimizations, OptimizingAnObjectShouldOptimizeSubExpressions) { // Build up the object {a: {$add: [1, 2]}}. + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); auto addExpression = - ExpressionAdd::parse(BSON("$add" << BSON_ARRAY(1 << 2)).firstElement(), vps); - auto object = ExpressionObject::create({{"a", addExpression}}); + ExpressionAdd::parse(expCtx, BSON("$add" << BSON_ARRAY(1 << 2)).firstElement(), vps); + auto object = ExpressionObject::create(expCtx, {{"a", addExpression}}); ASSERT_EQ(object->getChildExpressions().size(), 1UL); auto optimized = object->optimize(); @@ -2575,11 +2621,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(fromBson(BSON("a" << 1))))); @@ -2597,11 +2644,12 @@ class OptimizeBase { public: virtual ~OptimizeBase() {} void run() { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); intrusive_ptr<Expression> optimized = expression->optimize(); ASSERT_BSONOBJ_EQ(expectedOptimized(), expressionToBson(optimized)); @@ -2876,10 +2924,11 @@ namespace Object { * Parses the object given by 'specification', with the options given by 'parseContextOptions'. */ boost::intrusive_ptr<Expression> parseObject(BSONObj specification) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - return Expression::parseObject(specification, vps); + return Expression::parseObject(expCtx, specification, vps); }; TEST(ParseObject, ShouldAcceptEmptyObject) { @@ -2910,9 +2959,10 @@ using mongo::Expression; * Parses an expression from the given BSON specification. */ boost::intrusive_ptr<Expression> parseExpression(BSONObj specification) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - return Expression::parseExpression(specification, vps); + return Expression::parseExpression(expCtx, specification, vps); } TEST(ParseExpression, ShouldRecognizeConstExpression) { @@ -3016,10 +3066,11 @@ using mongo::Expression; * case the field name would be the name of the expression. */ intrusive_ptr<Expression> parseOperand(BSONObj specification) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONElement specElement = specification.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - return Expression::parseOperand(specElement, vps); + return Expression::parseOperand(expCtx, specElement, vps); } TEST(ParseOperand, ShouldRecognizeFieldPath) { @@ -3081,7 +3132,7 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); const Document spec = getSpec(); const Value args = spec["input"]; if (!spec["expected"].missing()) { @@ -3092,8 +3143,8 @@ public: const BSONObj obj = BSON(field.first << args); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - const intrusive_ptr<Expression> expr = Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + const intrusive_ptr<Expression> expr = + Expression::parseExpression(expCtx, obj, vps); Value result = expr->evaluate(Document()); if (result.getType() == Array) { result = sortSet(result); @@ -3121,8 +3172,7 @@ public: // NOTE: parse and evaluatation failures are treated the // same const intrusive_ptr<Expression> expr = - Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + Expression::parseExpression(expCtx, obj, vps); expr->evaluate(Document()); }, UserException); @@ -3399,13 +3449,12 @@ private: return BSON("$strcasecmp" << BSON_ARRAY(b() << a())); } void assertResult(int expectedResult, const BSONObj& spec) { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult), toBson(expression->evaluate(Document()))); } @@ -3527,13 +3576,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3635,22 +3683,24 @@ class DropEndingNull : public ExpectedResultBase { namespace SubstrCP { TEST(ExpressionSubstrCPTest, DoesThrowWithBadContinuationByte) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const auto continuationByte = "\x80\x00"_sd; const auto expr = Expression::parseExpression( - BSON("$substrCP" << BSON_ARRAY(continuationByte << 0 << 1)), vps); + expCtx, BSON("$substrCP" << BSON_ARRAY(continuationByte << 0 << 1)), vps); ASSERT_THROWS({ expr->evaluate(Document()); }, UserException); } TEST(ExpressionSubstrCPTest, DoesThrowWithInvalidLeadingByte) { + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const auto leadingByte = "\xFF\x00"_sd; - const auto expr = - Expression::parseExpression(BSON("$substrCP" << BSON_ARRAY(leadingByte << 0 << 1)), vps); + const auto expr = Expression::parseExpression( + expCtx, BSON("$substrCP" << BSON_ARRAY(leadingByte << 0 << 1)), vps); ASSERT_THROWS({ expr->evaluate(Document()); }, UserException); } @@ -3792,13 +3842,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3851,13 +3900,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr<Expression> expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3909,7 +3957,7 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); + intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); const Document spec = getSpec(); const Value args = spec["input"]; if (!spec["expected"].missing()) { @@ -3920,8 +3968,8 @@ public: const BSONObj obj = BSON(field.first << args); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - const intrusive_ptr<Expression> expr = Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + const intrusive_ptr<Expression> expr = + Expression::parseExpression(expCtx, obj, vps); const Value result = expr->evaluate(Document()); if (ValueComparator().evaluate(result != expected)) { string errMsg = str::stream() @@ -3946,8 +3994,7 @@ public: // NOTE: parse and evaluatation failures are treated the // same const intrusive_ptr<Expression> expr = - Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + Expression::parseExpression(expCtx, obj, vps); expr->evaluate(Document()); }, UserException); diff --git a/src/mongo/db/pipeline/granularity_rounder.cpp b/src/mongo/db/pipeline/granularity_rounder.cpp index 0941c8b36fa..6071c6644db 100644 --- a/src/mongo/db/pipeline/granularity_rounder.cpp +++ b/src/mongo/db/pipeline/granularity_rounder.cpp @@ -50,11 +50,11 @@ void GranularityRounder::registerGranularityRounder(StringData name, Rounder rou } boost::intrusive_ptr<GranularityRounder> GranularityRounder::getGranularityRounder( - StringData granularity) { + const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData granularity) { auto it = rounderMap.find(granularity); uassert(40257, str::stream() << "Unknown rounding granularity '" << granularity << "'", it != rounderMap.end()); - return it->second(); + return it->second(expCtx); } } // namespace mongo diff --git a/src/mongo/db/pipeline/granularity_rounder.h b/src/mongo/db/pipeline/granularity_rounder.h index 862c604dcc0..24a8f53de97 100644 --- a/src/mongo/db/pipeline/granularity_rounder.h +++ b/src/mongo/db/pipeline/granularity_rounder.h @@ -32,6 +32,7 @@ #include "mongo/base/init.h" #include "mongo/db/jsobj.h" +#include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" #include "mongo/stdx/functional.h" #include "mongo/util/intrusive_counter.h" @@ -73,7 +74,8 @@ namespace mongo { */ class GranularityRounder : public RefCountable { public: - using Rounder = stdx::function<boost::intrusive_ptr<GranularityRounder>()>; + using Rounder = stdx::function<boost::intrusive_ptr<GranularityRounder>( + const boost::intrusive_ptr<ExpressionContext>&)>; /** * Registers a GranularityRounder with a parsing function so that when a granularity @@ -89,7 +91,8 @@ public: * Retrieves the GranularityRounder for the granularity given by 'granularity', and raises an * error if there is no such granularity registered. */ - static boost::intrusive_ptr<GranularityRounder> getGranularityRounder(StringData granularity); + static boost::intrusive_ptr<GranularityRounder> getGranularityRounder( + const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData granularity); /** * Rounds up 'value' to the first value greater than 'value' in the granularity series. If @@ -109,6 +112,16 @@ public: * Returns the name of the granularity series that the GranularityRounder is using for rounding. */ virtual std::string getName() = 0; + +protected: + GranularityRounder(const boost::intrusive_ptr<ExpressionContext>& expCtx) : _expCtx(expCtx) {} + + ExpressionContext* getExpCtx() { + return _expCtx.get(); + } + +private: + boost::intrusive_ptr<ExpressionContext> _expCtx; }; /** @@ -124,8 +137,10 @@ public: * 'baseSeries'. This method requires that baseSeries has at least 2 numbers and is in sorted * order. */ - static boost::intrusive_ptr<GranularityRounder> create(const std::vector<double> baseSeries, - std::string name); + static boost::intrusive_ptr<GranularityRounder> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const std::vector<double> baseSeries, + std::string name); Value roundUp(Value value); Value roundDown(Value value); @@ -138,7 +153,9 @@ public: const std::vector<double> getSeries() const; private: - GranularityRounderPreferredNumbers(std::vector<double> baseSeries, std::string name); + GranularityRounderPreferredNumbers(const boost::intrusive_ptr<ExpressionContext>& expCtx, + std::vector<double> baseSeries, + std::string name); // '_baseSeries' is the preferred number series that is used for rounding. A preferred numbers // series is infinite, but we represent it with a finite vector of numbers. When rounding, we @@ -153,13 +170,16 @@ private: */ class GranularityRounderPowersOfTwo final : public GranularityRounder { public: - static boost::intrusive_ptr<GranularityRounder> create(); + static boost::intrusive_ptr<GranularityRounder> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx); Value roundUp(Value value); Value roundDown(Value value); std::string getName(); private: - GranularityRounderPowersOfTwo() = default; + GranularityRounderPowersOfTwo(const boost::intrusive_ptr<ExpressionContext>& expCtx) + : GranularityRounder(expCtx) {} + std::string _name = "POWERSOF2"; }; } // namespace mongo diff --git a/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp b/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp index 56d3ae02267..078936e8145 100644 --- a/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp @@ -40,8 +40,9 @@ using std::string; REGISTER_GRANULARITY_ROUNDER(POWERSOF2, GranularityRounderPowersOfTwo::create); -intrusive_ptr<GranularityRounder> GranularityRounderPowersOfTwo::create() { - return new GranularityRounderPowersOfTwo(); +intrusive_ptr<GranularityRounder> GranularityRounderPowersOfTwo::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return new GranularityRounderPowersOfTwo(expCtx); } namespace { @@ -80,7 +81,7 @@ Value GranularityRounderPowersOfTwo::roundUp(Value value) { } Variables vars; - return ExpressionPow::create(Value(2), exp)->evaluate(&vars); + return ExpressionPow::create(getExpCtx(), Value(2), exp)->evaluate(&vars); } Value GranularityRounderPowersOfTwo::roundDown(Value value) { @@ -113,7 +114,7 @@ Value GranularityRounderPowersOfTwo::roundDown(Value value) { } Variables vars; - return ExpressionPow::create(Value(2), exp)->evaluate(&vars); + return ExpressionPow::create(getExpCtx(), Value(2), exp)->evaluate(&vars); } string GranularityRounderPowersOfTwo::getName() { diff --git a/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp b/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp index adb7046c0c1..eefcdb749e5 100644 --- a/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp @@ -31,6 +31,7 @@ #include "mongo/db/pipeline/granularity_rounder.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" @@ -50,7 +51,8 @@ void testEquals(Value actual, Value expected, double delta = DELTA) { } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundUp(Value(0.5)), Value(1)); testEquals(rounder->roundUp(Value(1)), Value(2)); @@ -64,7 +66,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpPowersOfTwoToNextPowerOfTwo } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnDoubleIfExceedsNumberLong) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); long long input = 4611686018427387905; // 2^62 + 1 double output = 9223372036854775808.0; // 2^63 @@ -78,7 +81,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnDoubleIfExceedsNumberLong) { } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfExceedsNumberInt) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); int input = 1073741824; // 2^30 long long output = 2147483648; // 2^31 @@ -92,7 +96,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfExceedsNumberInt } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfRoundedDownDoubleIsSmallEnough) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); double input = 9223372036854775808.0; // 2^63 long long output = 4611686018427387904; // 2^62 @@ -106,7 +111,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfRoundedDownDoubl } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberIntIfRoundedDownNumberLongIsSmallEnough) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); long long input = 2147483648; // 2^31 int output = 1073741824; // 2^30 @@ -120,7 +126,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberIntIfRoundedDownNumber } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingUpNumberDecimal) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Decimal128 input = Decimal128(0.12); Decimal128 output = Decimal128(0.125); @@ -134,7 +141,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingUpN } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingDownNumberDecimal) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Decimal128 input = Decimal128(0.13); Decimal128 output = Decimal128(0.125); @@ -148,7 +156,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingDow } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpNonPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundUp(Value(3)), Value(4)); testEquals(rounder->roundUp(Value(5)), Value(8)); @@ -168,7 +177,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpNonPowersOfTwoToNextPowerOf } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundDown(Value(16)), Value(8)); testEquals(rounder->roundDown(Value(8)), Value(4)); @@ -183,7 +193,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownPowersOfTwoToNextPowerOfT } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownNonPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundDown(Value(10)), Value(8)); testEquals(rounder->roundDown(Value(9)), Value(8)); @@ -201,14 +212,16 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownNonPowersOfTwoToNextPower } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundZeroToZero) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundUp(Value(0)), Value(0)); testEquals(rounder->roundDown(Value(0)), Value(0)); } TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNonNumericValues) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); // Make sure that each GranularityRounder fails when rounding a non-numeric value. Value stringValue = Value("test"_sd); @@ -217,7 +230,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNonNumericValues) { } TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNaN) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Value nan = Value(std::nan("NaN")); ASSERT_THROWS_CODE(rounder->roundUp(nan), UserException, 40266); @@ -232,7 +246,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNaN) { } TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNegativeNumber) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Value negativeNumber = Value(-1); ASSERT_THROWS_CODE(rounder->roundUp(negativeNumber), UserException, 40267); diff --git a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp index 81a845bed6e..5b62d435a14 100644 --- a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp @@ -95,56 +95,57 @@ const vector<double> e192Series{ } // namespace // Register the GranularityRounders for the Renard number series. -REGISTER_GRANULARITY_ROUNDER(R5, []() { - return GranularityRounderPreferredNumbers::create(r5Series, "R5"); +REGISTER_GRANULARITY_ROUNDER(R5, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r5Series, "R5"); }); -REGISTER_GRANULARITY_ROUNDER(R10, []() { - return GranularityRounderPreferredNumbers::create(r10Series, "R10"); +REGISTER_GRANULARITY_ROUNDER(R10, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r10Series, "R10"); }); -REGISTER_GRANULARITY_ROUNDER(R20, []() { - return GranularityRounderPreferredNumbers::create(r20Series, "R20"); +REGISTER_GRANULARITY_ROUNDER(R20, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r20Series, "R20"); }); -REGISTER_GRANULARITY_ROUNDER(R40, []() { - return GranularityRounderPreferredNumbers::create(r40Series, "R40"); +REGISTER_GRANULARITY_ROUNDER(R40, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r40Series, "R40"); }); -REGISTER_GRANULARITY_ROUNDER(R80, []() { - return GranularityRounderPreferredNumbers::create(r80Series, "R80"); +REGISTER_GRANULARITY_ROUNDER(R80, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r80Series, "R80"); }); -REGISTER_GRANULARITY_ROUNDER_GENERAL("1-2-5", 1_2_5, []() { - return GranularityRounderPreferredNumbers::create(series125, "1-2-5"); -}); +REGISTER_GRANULARITY_ROUNDER_GENERAL( + "1-2-5", 1_2_5, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, series125, "1-2-5"); + }); // Register the GranularityRounders for the E series. -REGISTER_GRANULARITY_ROUNDER(E6, []() { - return GranularityRounderPreferredNumbers::create(e6Series, "E6"); +REGISTER_GRANULARITY_ROUNDER(E6, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e6Series, "E6"); }); -REGISTER_GRANULARITY_ROUNDER(E12, []() { - return GranularityRounderPreferredNumbers::create(e12Series, "E12"); +REGISTER_GRANULARITY_ROUNDER(E12, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e12Series, "E12"); }); -REGISTER_GRANULARITY_ROUNDER(E24, []() { - return GranularityRounderPreferredNumbers::create(e24Series, "E24"); +REGISTER_GRANULARITY_ROUNDER(E24, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e24Series, "E24"); }); -REGISTER_GRANULARITY_ROUNDER(E48, []() { - return GranularityRounderPreferredNumbers::create(e48Series, "E48"); +REGISTER_GRANULARITY_ROUNDER(E48, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e48Series, "E48"); }); -REGISTER_GRANULARITY_ROUNDER(E96, []() { - return GranularityRounderPreferredNumbers::create(e96Series, "E96"); +REGISTER_GRANULARITY_ROUNDER(E96, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e96Series, "E96"); }); -REGISTER_GRANULARITY_ROUNDER(E192, []() { - return GranularityRounderPreferredNumbers::create(e192Series, "E192"); +REGISTER_GRANULARITY_ROUNDER(E192, [](const boost::intrusive_ptr<ExpressionContext>& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e192Series, "E192"); }); -GranularityRounderPreferredNumbers::GranularityRounderPreferredNumbers(vector<double> baseSeries, - string name) - : _baseSeries(baseSeries), _name(name) { +GranularityRounderPreferredNumbers::GranularityRounderPreferredNumbers( + const boost::intrusive_ptr<ExpressionContext>& expCtx, vector<double> baseSeries, string name) + : GranularityRounder(expCtx), _baseSeries(baseSeries), _name(name) { invariant(_baseSeries.size() > 1); invariant(std::is_sorted(_baseSeries.begin(), _baseSeries.end())); } intrusive_ptr<GranularityRounder> GranularityRounderPreferredNumbers::create( - vector<double> baseSeries, string name) { - return new GranularityRounderPreferredNumbers(baseSeries, name); + const boost::intrusive_ptr<ExpressionContext>& expCtx, vector<double> baseSeries, string name) { + return new GranularityRounderPreferredNumbers(expCtx, baseSeries, name); } namespace { diff --git a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp index 17d10d07977..89df39621ad 100644 --- a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp @@ -32,6 +32,7 @@ #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/util/assert_util.h" namespace mongo { @@ -461,7 +462,8 @@ void testSeriesWrappingAroundDecimal(intrusive_ptr<GranularityRounder> rounder) TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpNumberInSeriesToNextNumberInSeries) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingUpInSeries(rounder); testRoundingUpInSeriesDecimal(rounder); @@ -471,7 +473,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpNumberInSeriesToNextNu TEST(GranularityRounderPreferredNumbersTest, ShouldRoundDownNumberInSeriesToPreviousNumberInSeries) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingDownInSeries(rounder); testRoundingDownInSeriesDecimal(rounder); @@ -480,7 +483,8 @@ TEST(GranularityRounderPreferredNumbersTest, TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpValueInBetweenSeriesNumbers) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingUpBetweenSeries(rounder); testRoundingUpBetweenSeriesDecimal(rounder); @@ -489,7 +493,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpValueInBetweenSeriesNu TEST(GranularityRounderPreferredNumbersTest, ShouldRoundDownValueInBetweenSeriesNumbers) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingDownBetweenSeries(rounder); testRoundingDownBetweenSeriesDecimal(rounder); @@ -498,7 +503,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundDownValueInBetweenSeries TEST(GranularityRounderPreferredNumbersTest, SeriesShouldWrapAroundWhenRounding) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testSeriesWrappingAround(rounder); testSeriesWrappingAroundDecimal(rounder); @@ -507,7 +513,8 @@ TEST(GranularityRounderPreferredNumbersTest, SeriesShouldWrapAroundWhenRounding) TEST(GranularityRounderPreferredNumbersTest, ShouldRoundZeroToZero) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder rounds zero to zero. testEquals(rounder->roundUp(Value(0)), Value(0)); @@ -520,7 +527,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundZeroToZero) { TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNonNumericValues) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder fails when rounding a non-numeric value. Value stringValue = Value("test"_sd); @@ -531,7 +539,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNonNumericValue TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNaN) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder fails when rounding NaN. Value nan = Value(std::nan("NaN")); @@ -549,7 +558,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNaN) { TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNegativeNumber) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder fails when rounding a negative number. Value negativeNumber = Value(-1); diff --git a/src/mongo/db/pipeline/lookup_set_cache.h b/src/mongo/db/pipeline/lookup_set_cache.h index d482b267ee8..1b7adbd00a6 100644 --- a/src/mongo/db/pipeline/lookup_set_cache.h +++ b/src/mongo/db/pipeline/lookup_set_cache.h @@ -81,9 +81,12 @@ public: * ValueComparator. This requires instantiating the multi_index_container with comparison and * hasher functions obtained from the comparator. */ - LookupSetCache(ValueComparator valueComparator) - : _valueComparator(std::move(valueComparator)), - _container(makeIndexedContainer(_valueComparator)) {} + explicit LookupSetCache(const ValueComparator& comparator) + : _container(boost::make_tuple(IndexedContainer::nth_index<0>::type::ctor_args(), + boost::make_tuple(0, + member<Cached, Value, &Cached::first>(), + comparator.getHasher(), + comparator.getEqualTo()))) {} /** * Insert "value" into the set with key "key". If "key" is already present in the cache, move it @@ -186,29 +189,7 @@ public: return boost::none; } - /** - * Binds the cache to a new comparator that should be used to make all subsequent Value - * comparisons. - * - * TODO SERVER-25535: Remove this method. - */ - void setValueComparator(ValueComparator valueComparator) { - _valueComparator = std::move(valueComparator); - _container = makeIndexedContainer(_valueComparator); - } - private: - IndexedContainer makeIndexedContainer(const ValueComparator& valueComparator) const { - return IndexedContainer( - boost::make_tuple(IndexedContainer::nth_index<0>::type::ctor_args(), - boost::make_tuple(0, - member<Cached, Value, &Cached::first>(), - valueComparator.getHasher(), - valueComparator.getEqualTo()))); - } - - ValueComparator _valueComparator; - IndexedContainer _container; size_t _memoryUsage = 0; diff --git a/src/mongo/db/pipeline/lookup_set_cache_test.cpp b/src/mongo/db/pipeline/lookup_set_cache_test.cpp index ce7503d4c6e..54a8b990ae7 100644 --- a/src/mongo/db/pipeline/lookup_set_cache_test.cpp +++ b/src/mongo/db/pipeline/lookup_set_cache_test.cpp @@ -52,9 +52,10 @@ BSONObj intToObj(int value) { return BSON("n" << value); } +const ValueComparator defaultComparator{nullptr}; + TEST(LookupSetCacheTest, InsertAndRetrieveWorksCorrectly) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(1)); cache.insert(Value(0), intToObj(2)); cache.insert(Value(0), intToObj(3)); @@ -70,8 +71,7 @@ TEST(LookupSetCacheTest, InsertAndRetrieveWorksCorrectly) { } TEST(LookupSetCacheTest, CacheDoesEvictInExpectedOrder) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -90,8 +90,7 @@ TEST(LookupSetCacheTest, CacheDoesEvictInExpectedOrder) { } TEST(LookupSetCacheTest, ReadDoesMoveKeyToFrontOfCache) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -106,8 +105,7 @@ TEST(LookupSetCacheTest, ReadDoesMoveKeyToFrontOfCache) { } TEST(LookupSetCacheTest, InsertDoesPutKeyInMiddle) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -120,8 +118,7 @@ TEST(LookupSetCacheTest, InsertDoesPutKeyInMiddle) { } TEST(LookupSetCacheTest, EvictDoesRespectMemoryUsage) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -134,8 +131,7 @@ TEST(LookupSetCacheTest, EvictDoesRespectMemoryUsage) { } TEST(LookupSetCacheTest, ComplexAccessPatternDoesBehaveCorrectly) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); for (int i = 0; i < 5; i++) { for (int j = 0; j < 5; j++) { @@ -173,7 +169,8 @@ TEST(LookupSetCacheTest, ComplexAccessPatternDoesBehaveCorrectly) { TEST(LookupSetCacheTest, CacheKeysRespectCollation) { CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kToLowerString); - LookupSetCache cache(&collator); + ValueComparator comparator{&collator}; + LookupSetCache cache(comparator); cache.insert(Value("foo"_sd), intToObj(1)); cache.insert(Value("FOO"_sd), intToObj(2)); @@ -196,7 +193,8 @@ TEST(LookupSetCacheTest, CacheKeysRespectCollation) { // foreign collection. TEST(LookupSetCacheTest, CachedValuesDontRespectCollation) { CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kToLowerString); - LookupSetCache cache(&collator); + ValueComparator comparator{&collator}; + LookupSetCache cache(comparator); cache.insert(Value("foo"_sd), BSON("foo" diff --git a/src/mongo/db/pipeline/parsed_add_fields.cpp b/src/mongo/db/pipeline/parsed_add_fields.cpp index 2bca98f21bd..ae78664a8a5 100644 --- a/src/mongo/db/pipeline/parsed_add_fields.cpp +++ b/src/mongo/db/pipeline/parsed_add_fields.cpp @@ -38,7 +38,8 @@ namespace mongo { namespace parsed_aggregation_projection { -std::unique_ptr<ParsedAddFields> ParsedAddFields::create(const BSONObj& spec) { +std::unique_ptr<ParsedAddFields> ParsedAddFields::create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec) { // Verify that we don't have conflicting field paths, etc. Status status = ProjectionSpecValidator::validate(spec); if (!status.isOK()) { @@ -48,17 +49,19 @@ std::unique_ptr<ParsedAddFields> ParsedAddFields::create(const BSONObj& spec) { std::unique_ptr<ParsedAddFields> parsedAddFields = stdx::make_unique<ParsedAddFields>(); // Actually parse the specification. - parsedAddFields->parse(spec); + parsedAddFields->parse(expCtx, spec); return parsedAddFields; } -void ParsedAddFields::parse(const BSONObj& spec, const VariablesParseState& variablesParseState) { +void ParsedAddFields::parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, + const VariablesParseState& variablesParseState) { for (auto elem : spec) { auto fieldName = elem.fieldNameStringData(); if (elem.type() == BSONType::Object) { // This is either an expression, or a nested specification. - if (parseObjectAsExpression(fieldName, elem.Obj(), variablesParseState)) { + if (parseObjectAsExpression(expCtx, fieldName, elem.Obj(), variablesParseState)) { // It was an expression. } else { // The field name might be a dotted path. If so, we need to keep adding children @@ -72,12 +75,12 @@ void ParsedAddFields::parse(const BSONObj& spec, const VariablesParseState& vari // It is illegal to construct an empty FieldPath, so the above loop ends one // iteration too soon. Add the last path here. child = child->addOrGetChild(remainingPath.fullPath()); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); } } else { // This is a literal or regular value. _root->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } @@ -96,7 +99,8 @@ Document ParsedAddFields::applyProjection(Document inputDoc, Variables* vars) co return output.freeze(); } -bool ParsedAddFields::parseObjectAsExpression(StringData pathToObject, +bool ParsedAddFields::parseObjectAsExpression(const boost::intrusive_ptr<ExpressionContext>& expCtx, + StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState) { if (objSpec.firstElementFieldName()[0] == '$') { @@ -104,13 +108,14 @@ bool ParsedAddFields::parseObjectAsExpression(StringData pathToObject, // field. invariant(objSpec.nFields() == 1); _root->addComputedField(pathToObject, - Expression::parseExpression(objSpec, variablesParseState)); + Expression::parseExpression(expCtx, objSpec, variablesParseState)); return true; } return false; } -void ParsedAddFields::parseSubObject(const BSONObj& subObj, +void ParsedAddFields::parseSubObject(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& subObj, const VariablesParseState& variablesParseState, InclusionNode* node) { for (auto&& elem : subObj) { @@ -123,17 +128,18 @@ void ParsedAddFields::parseSubObject(const BSONObj& subObj, // This is either an expression, or a nested specification. auto fieldName = elem.fieldNameStringData().toString(); if (!parseObjectAsExpression( + expCtx, FieldPath::getFullyQualifiedPath(node->getPath(), fieldName), elem.Obj(), variablesParseState)) { // It was a nested subobject auto child = node->addOrGetChild(fieldName); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); } } else { // This is a literal or regular value. node->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } diff --git a/src/mongo/db/pipeline/parsed_add_fields.h b/src/mongo/db/pipeline/parsed_add_fields.h index 58b8445e727..982fe6a3017 100644 --- a/src/mongo/db/pipeline/parsed_add_fields.h +++ b/src/mongo/db/pipeline/parsed_add_fields.h @@ -58,7 +58,8 @@ public: * Verifies that there are no conflicting paths in the specification. * Overrides the ParsedAggregationProjection's create method. */ - static std::unique_ptr<ParsedAddFields> create(const BSONObj& spec); + static std::unique_ptr<ParsedAddFields> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec); ProjectionType getType() const final { return ProjectionType::kComputed; @@ -67,10 +68,10 @@ public: /** * Parses the addFields specification given by 'spec', populating internal data structures. */ - void parse(const BSONObj& spec) final { + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec) final { VariablesIdGenerator idGenerator; VariablesParseState variablesParseState(&idGenerator); - parse(spec, variablesParseState); + parse(expCtx, spec, variablesParseState); _variables = stdx::make_unique<Variables>(idGenerator.getIdCount()); } @@ -87,10 +88,6 @@ public: _root->optimize(); } - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) final { - _root->injectExpressionContext(expCtx); - } - DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const final { _root->addDependencies(deps); return DocumentSource::SEE_NEXT; @@ -124,7 +121,9 @@ private: /** * Parses 'spec' to determine which fields to add. */ - void parse(const BSONObj& spec, const VariablesParseState& variablesParseState); + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, + const VariablesParseState& variablesParseState); /** * Attempts to parse 'objSpec' as an expression like {$add: [...]}. Adds a computed field to @@ -134,7 +133,8 @@ private: * Throws an error if it was determined to be an expression specification, but failed to parse * as a valid expression. */ - bool parseObjectAsExpression(StringData pathToObject, + bool parseObjectAsExpression(const boost::intrusive_ptr<ExpressionContext>& expCtx, + StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState); @@ -142,7 +142,8 @@ private: * Traverses 'subObj' and parses each field. Adds any computed fields at this level * to 'node'. */ - void parseSubObject(const BSONObj& subObj, + void parseSubObject(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& subObj, const VariablesParseState& variablesParseState, InclusionNode* node); diff --git a/src/mongo/db/pipeline/parsed_add_fields_test.cpp b/src/mongo/db/pipeline/parsed_add_fields_test.cpp index 557e2450dce..d7b68afca1d 100644 --- a/src/mongo/db/pipeline/parsed_add_fields_test.cpp +++ b/src/mongo/db/pipeline/parsed_add_fields_test.cpp @@ -38,6 +38,7 @@ #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" @@ -47,69 +48,78 @@ namespace { using std::vector; // These ParsedAddFields spec tests are a subset of the ParsedAggregationProjection creation tests. -// ParsedAddField should behave the same way, but does not use the same creation, so we include +// ParsedAddFields should behave the same way, but does not use the same creation, so we include // an abbreviation of the same tests here. // Verify that ParsedAddFields rejects specifications with conflicting field paths. TEST(ParsedAddFieldsSpec, ThrowsOnCreationWithConflictingFieldPaths) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // These specs contain the same exact path. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << 1 << "a" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("b" << 1 << "b" << 2))), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("_id" << 3 << "_id" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << 1 << "a" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSON("b" << 1 << "b" << 2))), + UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("_id" << 3 << "_id" << true)), + UserException); // These specs contain overlapping paths. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << 1 << "a.b" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a.b.c" << 1 << "a" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("_id" << true << "_id.x" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << 1 << "a.b" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a.b.c" << 1 << "a" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("_id" << true << "_id.x" << true)), + UserException); } // Verify that ParsedAddFields rejects specifications that contain invalid field paths. TEST(ParsedAddFieldsSpec, ThrowsOnCreationWithInvalidFieldPath) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Dotted subfields are not allowed. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("b.c" << true))), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSON("b.c" << true))), UserException); // The user cannot start a field with $. - ASSERT_THROWS(ParsedAddFields::create(BSON("$dollar" << 0)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("c.$d" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("$dollar" << 0)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("c.$d" << true)), UserException); // Empty field names should throw an error. - ASSERT_THROWS(ParsedAddFields::create(BSON("" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("" << true))), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("" << BSON("a" << true))), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a." << true)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON(".a" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSON("" << true))), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("" << BSON("a" << true))), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a." << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON(".a" << true)), UserException); } // Verify that ParsedAddFields rejects specifications that contain empty objects or invalid // expressions. TEST(ParsedAddFieldsSpec, ThrowsOnCreationWithInvalidObjectsOrExpressions) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Invalid expressions should be rejected. - ASSERT_THROWS( - ParsedAddFields::create(BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), - UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("$gt" << BSON("bad" + ASSERT_THROWS(ParsedAddFields::create( + expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), + UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, + BSON("a" << BSON("$gt" << BSON("bad" << "arguments")))), UserException); ASSERT_THROWS(ParsedAddFields::create( - BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), + expCtx, BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), UserException); // Empty specifications are not allowed. - ASSERT_THROWS(ParsedAddFields::create(BSONObj()), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSONObj()), UserException); // Empty nested objects are not allowed. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSONObj())), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSONObj())), UserException); } TEST(ParsedAddFields, DoesNotErrorOnTwoNestedFields) { - ParsedAddFields::create(BSON("a.b" << true << "a.c" << true)); - ParsedAddFields::create(BSON("a.b" << true << "a" << BSON("c" << true))); + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ParsedAddFields::create(expCtx, BSON("a.b" << true << "a.c" << true)); + ParsedAddFields::create(expCtx, BSON("a.b" << true << "a" << BSON("c" << true))); } // Verify that replaced fields are not included as dependencies. TEST(ParsedAddFieldsDeps, RemovesReplaceFieldsFromDependencies) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << true)); + addition.parse(expCtx, BSON("a" << true)); DepsTracker deps; addition.addDependencies(&deps); @@ -121,8 +131,9 @@ TEST(ParsedAddFieldsDeps, RemovesReplaceFieldsFromDependencies) { // Verify that adding nested fields keeps the top-level field as a dependency. TEST(ParsedAddFieldsDeps, IncludesTopLevelFieldInDependenciesWhenAddingNestedFields) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("x.y" << true)); + addition.parse(expCtx, BSON("x.y" << true)); DepsTracker deps; addition.addDependencies(&deps); @@ -135,8 +146,10 @@ TEST(ParsedAddFieldsDeps, IncludesTopLevelFieldInDependenciesWhenAddingNestedFie // Verify that fields that an expression depends on are added to the dependencies. TEST(ParsedAddFieldsDeps, AddsDependenciesForComputedFields) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("x.y" + addition.parse(expCtx, + BSON("x.y" << "$z" << "a" << "$b")); @@ -155,8 +168,9 @@ TEST(ParsedAddFieldsDeps, AddsDependenciesForComputedFields) { // Verify that the serialization produces the correct output: converting numbers and literals to // their corresponding $const form. TEST(ParsedAddFieldsSerialize, SerializesToCorrectForm) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); + addition.parse(expCtx, fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); auto expectedSerialization = Document( fromjson("{a: {$add: [\"$a\", {$const: 2}]}, b: {d: {$const: 3}}, x: {y: {$const: 4}}}")); @@ -168,8 +182,9 @@ TEST(ParsedAddFieldsSerialize, SerializesToCorrectForm) { // Verify that serialize treats the _id field as any other field: including when explicity included. TEST(ParsedAddFieldsSerialize, AddsIdToSerializeWhenExplicitlyIncluded) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("_id" << false)); + addition.parse(expCtx, BSON("_id" << false)); // Adds explicit "_id" setting field, serializes expressions. auto expectedSerialization = Document(fromjson("{_id: {$const: false}}")); @@ -184,8 +199,9 @@ TEST(ParsedAddFieldsSerialize, AddsIdToSerializeWhenExplicitlyIncluded) { // yet they derive from the same parent class. If the parent class were to change, this test would // fail. TEST(ParsedAddFieldsSerialize, OmitsIdFromSerializeWhenNotIncluded) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << true)); + addition.parse(expCtx, BSON("a" << true)); // Does not implicitly include "_id" field. auto expectedSerialization = Document(fromjson("{a: {$const: true}}")); @@ -197,8 +213,9 @@ TEST(ParsedAddFieldsSerialize, OmitsIdFromSerializeWhenNotIncluded) { // Verify that the $addFields stage optimizes expressions into simpler forms when possible. TEST(ParsedAddFieldsOptimize, OptimizesTopLevelExpressions) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); + addition.parse(expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); addition.optimize(); auto expectedSerialization = Document{{"a", Document{{"$const", 3}}}}; @@ -209,8 +226,9 @@ TEST(ParsedAddFieldsOptimize, OptimizesTopLevelExpressions) { // Verify that the $addFields stage optimizes expressions even when they are nested. TEST(ParsedAddFieldsOptimize, ShouldOptimizeNestedExpressions) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); + addition.parse(expCtx, BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); addition.optimize(); auto expectedSerialization = Document{{"a", Document{{"b", Document{{"$const", 3}}}}}}; @@ -225,8 +243,9 @@ TEST(ParsedAddFieldsOptimize, ShouldOptimizeNestedExpressions) { // Verify that a new field is added to the end of the document. TEST(ParsedAddFieldsExecutionTest, AddsNewFieldToEndOfDocument) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("c" << 3)); + addition.parse(expCtx, BSON("c" << 3)); // There are no fields in the document. auto result = addition.applyProjection(Document{}); @@ -241,8 +260,9 @@ TEST(ParsedAddFieldsExecutionTest, AddsNewFieldToEndOfDocument) { // Verify that an existing field is replaced and stays in the same order in the document. TEST(ParsedAddFieldsExecutionTest, ReplacesFieldThatAlreadyExistsInDocument) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("c" << 3)); + addition.parse(expCtx, BSON("c" << 3)); // Specified field is the only field in the document, and is replaced. auto result = addition.applyProjection(Document{{"c", 1}}); @@ -257,8 +277,10 @@ TEST(ParsedAddFieldsExecutionTest, ReplacesFieldThatAlreadyExistsInDocument) { // Verify that replacing multiple fields preserves the original field order in the document. TEST(ParsedAddFieldsExecutionTest, ReplacesMultipleFieldsWhilePreservingInputFieldOrder) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("second" + addition.parse(expCtx, + BSON("second" << "SECOND" << "first" << "FIRST")); @@ -269,8 +291,10 @@ TEST(ParsedAddFieldsExecutionTest, ReplacesMultipleFieldsWhilePreservingInputFie // Verify that adding multiple fields adds the fields in the order specified. TEST(ParsedAddFieldsExecutionTest, AddsNewFieldsAfterExistingFieldsInOrderSpecified) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("firstComputed" + addition.parse(expCtx, + BSON("firstComputed" << "FIRST" << "secondComputed" << "SECOND")); @@ -286,8 +310,10 @@ TEST(ParsedAddFieldsExecutionTest, AddsNewFieldsAfterExistingFieldsInOrderSpecif // Verify that both adding and replacing fields at the same time follows the same rules as doing // each independently. TEST(ParsedAddFieldsExecutionTest, ReplacesAndAddsNewFieldsWithSameOrderingRulesAsSeparately) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("firstComputed" + addition.parse(expCtx, + BSON("firstComputed" << "FIRST" << "second" << "SECOND")); @@ -300,8 +326,10 @@ TEST(ParsedAddFieldsExecutionTest, ReplacesAndAddsNewFieldsWithSameOrderingRules // Verify that _id is included just like a regular field, in whatever order it appears in the // input document, when adding new fields. TEST(ParsedAddFieldsExecutionTest, IdFieldIsKeptInOrderItAppearsInInputDocument) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("newField" + addition.parse(expCtx, + BSON("newField" << "computedVal")); auto result = addition.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"a", 1}, {"newField", "computedVal"_sd}}; @@ -314,8 +342,10 @@ TEST(ParsedAddFieldsExecutionTest, IdFieldIsKeptInOrderItAppearsInInputDocument) // Verify that replacing or adding _id works just like any other field. TEST(ParsedAddFieldsExecutionTest, ShouldReplaceIdWithComputedId) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("_id" + addition.parse(expCtx, + BSON("_id" << "newId")); auto result = addition.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}}); auto expectedResult = Document{{"_id", "newId"_sd}, {"a", 1}}; @@ -336,8 +366,9 @@ TEST(ParsedAddFieldsExecutionTest, ShouldReplaceIdWithComputedId) { // Verify that adding a dotted field keeps the other fields in the subdocument. TEST(ParsedAddFieldsExecutionTest, KeepsExistingSubFieldsWhenAddingSimpleDottedFieldToSubDoc) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b" << true)); + addition.parse(expCtx, BSON("a.b" << true)); // More than one field in sub document. auto result = addition.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); @@ -362,8 +393,9 @@ TEST(ParsedAddFieldsExecutionTest, KeepsExistingSubFieldsWhenAddingSimpleDottedF // Verify that creating a dotted field creates the subdocument structure necessary. TEST(ParsedAddFieldsExecutionTest, CreatesSubDocIfDottedAddedFieldDoesNotExist) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("sub.target" << true)); + addition.parse(expCtx, BSON("sub.target" << true)); // Should add the path if it doesn't exist. auto result = addition.applyProjection(Document{}); @@ -379,8 +411,9 @@ TEST(ParsedAddFieldsExecutionTest, CreatesSubDocIfDottedAddedFieldDoesNotExist) // Verify that adding a dotted value to an array field sets the field in every element of the array. // SERVER-25200: make this agree with $set. TEST(ParsedAddFieldsExecutionTest, AppliesDottedAdditionToEachElementInArray) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b" << true)); + addition.parse(expCtx, BSON("a.b" << true)); vector<Value> nestedValues = {Value(1), Value(Document{}), @@ -404,8 +437,10 @@ TEST(ParsedAddFieldsExecutionTest, AppliesDottedAdditionToEachElementInArray) { // Verify that creation of the subdocument structure works for many layers of nesting. TEST(ParsedAddFieldsExecutionTest, CreatesNestedSubDocumentsAllTheWayToAddedField) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b.c.d" + addition.parse(expCtx, + BSON("a.b.c.d" << "computedVal")); // Should add the path if it doesn't exist. @@ -421,8 +456,10 @@ TEST(ParsedAddFieldsExecutionTest, CreatesNestedSubDocumentsAllTheWayToAddedFiel // Verify that _id is not special: we can add subfields to it as well. TEST(ParsedAddFieldsExecutionTest, AddsSubFieldsOfId) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("_id.X" << true << "_id.Z" + addition.parse(expCtx, + BSON("_id.X" << true << "_id.Z" << "NEW")); auto result = addition.applyProjection(Document{{"_id", Document{{"X", 1}, {"Y", 2}}}}); auto expectedResult = Document{{"_id", Document{{"X", true}, {"Y", 2}, {"Z", "NEW"_sd}}}}; @@ -432,10 +469,12 @@ TEST(ParsedAddFieldsExecutionTest, AddsSubFieldsOfId) { // Verify that both ways of specifying nested fields -- both dotted notation and nesting -- // can be used together in the same specification. TEST(ParsedAddFieldsExecutionTest, ShouldAllowMixedNestedAndDottedFields) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; // Include all of "a.b", "a.c", "a.d", and "a.e". // Add new computed fields "a.W", "a.X", "a.Y", and "a.Z". - addition.parse(BSON("a.b" << true << "a.c" << true << "a.W" + addition.parse(expCtx, + BSON("a.b" << true << "a.c" << true << "a.W" << "W" << "a.X" << "X" @@ -462,8 +501,10 @@ TEST(ParsedAddFieldsExecutionTest, ShouldAllowMixedNestedAndDottedFields) { // Verify that adding nested fields preserves the addition order in the spec. TEST(ParsedAddFieldsExecutionTest, AddsNestedAddedFieldsInOrderSpecified) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("b.d" + addition.parse(expCtx, + BSON("b.d" << "FIRST" << "b.c" << "SECOND")); @@ -478,8 +519,9 @@ TEST(ParsedAddFieldsExecutionTest, AddsNestedAddedFieldsInOrderSpecified) { // Verify that the metadata is kept from the original input document. TEST(ParsedAddFieldsExecutionTest, AlwaysKeepsMetadataFromOriginalDoc) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << true)); + addition.parse(expCtx, BSON("a" << true)); MutableDocument inputDocBuilder(Document{{"a", 1}}); inputDocBuilder.setRandMetaField(1.0); diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection.cpp b/src/mongo/db/pipeline/parsed_aggregation_projection.cpp index a73a9589ed8..69b196faab3 100644 --- a/src/mongo/db/pipeline/parsed_aggregation_projection.cpp +++ b/src/mongo/db/pipeline/parsed_aggregation_projection.cpp @@ -259,7 +259,7 @@ private: } // namespace std::unique_ptr<ParsedAggregationProjection> ParsedAggregationProjection::create( - const BSONObj& spec) { + const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec) { // Check that the specification was valid. Status returned is unspecific because validate() // is used by the $addFields stage as well as $project. // If there was an error, uassert with a $project-specific message. @@ -282,7 +282,7 @@ std::unique_ptr<ParsedAggregationProjection> ParsedAggregationProjection::create : static_cast<ParsedAggregationProjection*>(new ParsedExclusionProjection())); // Actually parse the specification. - parsedProject->parse(spec); + parsedProject->parse(expCtx, spec); return parsedProject; } diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection.h b/src/mongo/db/pipeline/parsed_aggregation_projection.h index a637365565e..8d92e4ab768 100644 --- a/src/mongo/db/pipeline/parsed_aggregation_projection.h +++ b/src/mongo/db/pipeline/parsed_aggregation_projection.h @@ -41,7 +41,7 @@ namespace mongo { class BSONObj; class Document; -struct ExpressionContext; +class ExpressionContext; namespace parsed_aggregation_projection { @@ -117,7 +117,8 @@ public: * * Throws a UserException if 'spec' is an invalid projection specification. */ - static std::unique_ptr<ParsedAggregationProjection> create(const BSONObj& spec); + static std::unique_ptr<ParsedAggregationProjection> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec); virtual ~ParsedAggregationProjection() = default; @@ -132,7 +133,8 @@ public: * inclusions and exclusions. 'variablesParseState' is used by any contained expressions to * track which variables are defined so that they can later be referenced at execution time. */ - virtual void parse(const BSONObj& spec) = 0; + virtual void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec) = 0; /** * Optimize any expressions contained within this projection. @@ -140,11 +142,6 @@ public: virtual void optimize() {} /** - * Inject the ExpressionContext into any expressions contained within this projection. - */ - virtual void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) {} - - /** * Add any dependencies needed by this projection or any sub-expressions to 'deps'. */ virtual DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const { diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp b/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp index 5b55071a6bf..71dd1c2378f 100644 --- a/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp @@ -36,6 +36,7 @@ #include "mongo/bson/bsonobjbuilder.h" #include "mongo/bson/json.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" @@ -53,262 +54,296 @@ BSONObj wrapInLiteral(const T& arg) { // TEST(ParsedAggregationProjectionErrors, ShouldRejectDuplicateFieldNames) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Include/exclude the same field twice. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "a" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << true << "a" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "a" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << false << "a" << false)), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << false << "b" << false))), UserException); - ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << BSON("b" << false << "b" << false))), - UserException); // Mix of include/exclude and adding a field. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1) << "a" << true)), - UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << false << "a" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1) << "a" << true)), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << false << "a" << wrapInLiteral(0))), UserException); // Adding the same field twice. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << wrapInLiteral(1) << "a" << wrapInLiteral(0))), + expCtx, BSON("a" << wrapInLiteral(1) << "a" << wrapInLiteral(0))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectDuplicateIds) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Include/exclude _id twice. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "_id" << true)), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << false << "_id" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "_id" << true)), UserException); - - // Mix of including/excluding and adding _id. ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << wrapInLiteral(1) << "_id" << true)), - UserException); - ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << false << "_id" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("_id" << false << "_id" << false)), UserException); + // Mix of including/excluding and adding _id. + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << wrapInLiteral(1) << "_id" << true)), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << false << "_id" << wrapInLiteral(0))), + UserException); + // Adding _id twice. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("_id" << wrapInLiteral(1) << "_id" << wrapInLiteral(0))), + expCtx, BSON("_id" << wrapInLiteral(1) << "_id" << wrapInLiteral(0))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectFieldsWithSharedPrefix) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Include/exclude Fields with a shared prefix. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "a.b" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << true << "a.b" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a.b" << false << "a" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a.b" << false << "a" << false)), UserException); // Mix of include/exclude and adding a shared prefix. ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1) << "a.b" << true)), - UserException); - ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a.b" << false << "a" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1) << "a.b" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a.b" << false << "a" << wrapInLiteral(0))), + UserException); // Adding a shared prefix twice. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << wrapInLiteral(1) << "a.b" << wrapInLiteral(0))), + expCtx, BSON("a" << wrapInLiteral(1) << "a.b" << wrapInLiteral(0))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a.b.c.d" << wrapInLiteral(1) << "a.b.c" << wrapInLiteral(0))), + expCtx, BSON("a.b.c.d" << wrapInLiteral(1) << "a.b.c" << wrapInLiteral(0))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectMixOfIdAndSubFieldsOfId) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Include/exclude _id twice. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "_id.x" << true)), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id.x" << false << "_id" << false)), - UserException); - - // Mix of including/excluding and adding _id. ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << wrapInLiteral(1) << "_id.x" << true)), + ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "_id.x" << true)), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id.x" << false << "_id" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("_id.x" << false << "_id" << false)), UserException); - // Adding _id twice. + // Mix of including/excluding and adding _id. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("_id" << wrapInLiteral(1) << "_id.x" << wrapInLiteral(0))), + expCtx, BSON("_id" << wrapInLiteral(1) << "_id.x" << true)), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("_id.b.c.d" << wrapInLiteral(1) << "_id.b.c" << wrapInLiteral(0))), + expCtx, BSON("_id.x" << false << "_id" << wrapInLiteral(0))), UserException); + + // Adding _id twice. + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << wrapInLiteral(1) << "_id.x" << wrapInLiteral(0))), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create( + expCtx, BSON("_id.b.c.d" << wrapInLiteral(1) << "_id.b.c" << wrapInLiteral(0))), + UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectMixOfInclusionAndExclusion) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Simple mix. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "b" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << true << "b" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "b" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << false << "b" << true)), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << BSON("b" << false << "c" << true))), + ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b" << false << "c" << true))), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << BSON("b" << false << "c" << true))), + UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << BSON("b" << false << "c" << true))), + ParsedAggregationProjection::create(expCtx, BSON("_id.b" << false << "a.c" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id.b" << false << "a.c" << true)), - UserException); // Mix while also adding a field. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << true << "b" << wrapInLiteral(1) << "c" << false)), + expCtx, BSON("a" << true << "b" << wrapInLiteral(1) << "c" << false)), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << false << "b" << wrapInLiteral(1) << "c" << true)), + expCtx, BSON("a" << false << "b" << wrapInLiteral(1) << "c" << true)), UserException); // Mixing "_id" inclusion with exclusion. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "a" << false)), - UserException); - - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "_id" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "a" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "a.b.c" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << false << "_id" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id.x" << true << "a.b.c" << false)), - UserException); -} - -TEST(ParsedAggregationProjectionType, ShouldRejectMixOfExclusionAndComputedFields) { ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << false << "b" << wrapInLiteral(1))), + ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "a.b.c" << false)), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1) << "b" << false)), + ParsedAggregationProjection::create(expCtx, BSON("_id.x" << true << "a.b.c" << false)), UserException); +} +TEST(ParsedAggregationProjectionType, ShouldRejectMixOfExclusionAndComputedFields) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a.b" << false << "a.c" << wrapInLiteral(1))), + ParsedAggregationProjection::create(expCtx, BSON("a" << false << "b" << wrapInLiteral(1))), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a.b" << wrapInLiteral(1) << "a.c" << false)), + ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1) << "b" << false)), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << false << "c" << wrapInLiteral(1)))), + expCtx, BSON("a.b" << false << "a.c" << wrapInLiteral(1))), + UserException); + + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a.b" << wrapInLiteral(1) << "a.c" << false)), + UserException); + + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << false << "c" << wrapInLiteral(1)))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << wrapInLiteral(1) << "c" << false))), + expCtx, BSON("a" << BSON("b" << wrapInLiteral(1) << "c" << false))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectDottedFieldInSubDocument) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("b.c" << true))), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("b.c" << wrapInLiteral(1)))), + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b.c" << true))), UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b.c" << wrapInLiteral(1)))), + UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectFieldNamesStartingWithADollar) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$dollar" << 0)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$dollar" << 1)), UserException); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$dollar" << 0)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$dollar" << 1)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b.$dollar" << 0)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b.$dollar" << 1)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b.$dollar" << 0)), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b.$dollar" << 1)), + UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b" << BSON("$dollar" << 0))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b" << BSON("$dollar" << 0))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b" << BSON("$dollar" << 1))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b" << BSON("$dollar" << 1))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$add" << 0)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$add" << 1)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$add" << 0)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$add" << 1)), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectTopLevelExpressions) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$add" << BSON_ARRAY(4 << 2))), + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$add" << BSON_ARRAY(4 << 2))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectExpressionWithMultipleFieldNames) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), + expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << 1 << "$add" << BSON_ARRAY(4 << 2)))), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << BSON("c" << 1 << "$add" << BSON_ARRAY(4 << 2))))), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << BSON("$add" << BSON_ARRAY(4 << 2) << "c" << 1)))), + expCtx, BSON("a" << BSON("b" << 1 << "$add" << BSON_ARRAY(4 << 2)))), UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << BSON("c" << 1 << "$add" << BSON_ARRAY(4 << 2))))), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << BSON("$add" << BSON_ARRAY(4 << 2) << "c" << 1)))), + UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectEmptyProjection) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSONObj()), UserException); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSONObj()), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectEmptyNestedObject) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "b" << BSONObj())), + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "b" << BSONObj())), + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << false << "b" << BSONObj())), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << true << "b" << BSONObj())), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a.b" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a.b" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("b" << BSONObj()))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b" << BSONObj()))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldErrorOnInvalidExpression) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), + expCtx, BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << true << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), + expCtx, BSON("a" << true << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldErrorOnInvalidFieldPath) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Empty field names. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << wrapInLiteral(2))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << false)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << wrapInLiteral(2))), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("" << true))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("" << true))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("" << false))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("" << false))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << BSON("a" << true))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << BSON("a" << true))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << BSON("a" << false))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << BSON("a" << false))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a." << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a." << false)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a." << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a." << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON(".a" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON(".a" << false)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON(".a" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON(".a" << false)), UserException); // Not testing field names with null bytes, since that is invalid BSON, and won't make it to the // $project stage without a previous error. // Field names starting with '$'. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$x" << wrapInLiteral(2))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$x" << wrapInLiteral(2))), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("c.$d" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("c.$d" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("c.$d" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("c.$d" << false)), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldNotErrorOnTwoNestedFields) { - ParsedAggregationProjection::create(BSON("a.b" << true << "a.c" << true)); - ParsedAggregationProjection::create(BSON("a.b" << true << "a" << BSON("c" << true))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ParsedAggregationProjection::create(expCtx, BSON("a.b" << true << "a.c" << true)); + ParsedAggregationProjection::create(expCtx, BSON("a.b" << true << "a" << BSON("c" << true))); } // @@ -316,102 +351,112 @@ TEST(ParsedAggregationProjectionErrors, ShouldNotErrorOnTwoNestedFields) { // TEST(ParsedAggregationProjectionType, ShouldDefaultToInclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("_id" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldDetectExclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("a" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << false)); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id.x" << false)); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id.x" << false)); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << BSON("x" << false))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << BSON("x" << false))); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("x" << BSON("_id" << false))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("x" << BSON("_id" << false))); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << false)); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << false)); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); } TEST(ParsedAggregationProjectionType, ShouldDetectInclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << false << "a" << true)); + parsedProject = + ParsedAggregationProjection::create(expCtx, BSON("_id" << false << "a" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << false << "a.b.c" << true)); + parsedProject = + ParsedAggregationProjection::create(expCtx, BSON("_id" << false << "a.b.c" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id.x" << true)); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id.x" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << BSON("x" << true))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << BSON("x" << true))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("x" << BSON("_id" << true))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("x" << BSON("_id" << true))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldTreatOnlyComputedFieldsAsAnInclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = - ParsedAggregationProjection::create(BSON("_id" << false << "a" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create( + expCtx, BSON("_id" << false << "a" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = - ParsedAggregationProjection::create(BSON("_id" << false << "a.b.c" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create( + expCtx, BSON("_id" << false << "a.b.c" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id.x" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id.x" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); parsedProject = - ParsedAggregationProjection::create(BSON("_id" << BSON("x" << wrapInLiteral(1)))); + ParsedAggregationProjection::create(expCtx, BSON("_id" << BSON("x" << wrapInLiteral(1)))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); parsedProject = - ParsedAggregationProjection::create(BSON("x" << BSON("_id" << wrapInLiteral(1)))); + ParsedAggregationProjection::create(expCtx, BSON("x" << BSON("_id" << wrapInLiteral(1)))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldAllowMixOfInclusionAndComputedFields) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); auto parsedProject = - ParsedAggregationProjection::create(BSON("a" << true << "b" << wrapInLiteral(1))); + ParsedAggregationProjection::create(expCtx, BSON("a" << true << "b" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = - ParsedAggregationProjection::create(BSON("a.b" << true << "a.c" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create( + expCtx, BSON("a.b" << true << "a.c" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); parsedProject = ParsedAggregationProjection::create( - BSON("a" << BSON("b" << true << "c" << wrapInLiteral(1)))); + expCtx, BSON("a" << BSON("b" << true << "c" << wrapInLiteral(1)))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldCoerceNumericsToBools) { + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); std::vector<Value> zeros = {Value(0), Value(0LL), Value(0.0), Value(Decimal128(0))}; for (auto&& zero : zeros) { - auto parsedProject = ParsedAggregationProjection::create(Document{{"a", zero}}.toBson()); + auto parsedProject = + ParsedAggregationProjection::create(expCtx, Document{{"a", zero}}.toBson()); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); } std::vector<Value> nonZeroes = { Value(1), Value(-1), Value(3), Value(1LL), Value(1.0), Value(Decimal128(1))}; for (auto&& nonZero : nonZeroes) { - auto parsedProject = ParsedAggregationProjection::create(Document{{"a", nonZero}}.toBson()); + auto parsedProject = + ParsedAggregationProjection::create(expCtx, Document{{"a", nonZero}}.toBson()); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } } diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection.cpp b/src/mongo/db/pipeline/parsed_exclusion_projection.cpp index 8226d146009..0f800f9a112 100644 --- a/src/mongo/db/pipeline/parsed_exclusion_projection.cpp +++ b/src/mongo/db/pipeline/parsed_exclusion_projection.cpp @@ -143,7 +143,10 @@ Document ParsedExclusionProjection::applyProjection(Document inputDoc) const { return _root->applyProjection(inputDoc); } -void ParsedExclusionProjection::parse(const BSONObj& spec, ExclusionNode* node, size_t depth) { +void ParsedExclusionProjection::parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, + ExclusionNode* node, + size_t depth) { for (auto elem : spec) { const auto fieldName = elem.fieldNameStringData().toString(); @@ -188,7 +191,7 @@ void ParsedExclusionProjection::parse(const BSONObj& spec, ExclusionNode* node, child = child->addOrGetChild(fullPath.fullPath()); } - parse(elem.Obj(), child, depth + 1); + parse(expCtx, elem.Obj(), child, depth + 1); break; } default: { MONGO_UNREACHABLE; } diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection.h b/src/mongo/db/pipeline/parsed_exclusion_projection.h index ea7b25ac33f..d0988d2d2cb 100644 --- a/src/mongo/db/pipeline/parsed_exclusion_projection.h +++ b/src/mongo/db/pipeline/parsed_exclusion_projection.h @@ -108,8 +108,8 @@ public: /** * Parses the projection specification given by 'spec', populating internal data structures. */ - void parse(const BSONObj& spec) final { - parse(spec, _root.get(), 0); + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec) final { + parse(expCtx, spec, _root.get(), 0); } /** @@ -134,7 +134,10 @@ private: * Traverses 'spec' and parses each field. Adds any excluded fields at this level to 'node', * and recurses on any sub-objects. */ - void parse(const BSONObj& spec, ExclusionNode* node, size_t depth); + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, + ExclusionNode* node, + size_t depth); // The ExclusionNode tree does most of the execution work once constructed. diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp b/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp index 62a23b2c729..76b458260d0 100644 --- a/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp @@ -40,6 +40,7 @@ #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/death_test.h" #include "mongo/unittest/unittest.h" @@ -57,28 +58,32 @@ DEATH_TEST(ExclusionProjection, ShouldRejectComputedField, "Invariant failure fieldName[0] != '$'") { ParsedExclusionProjection exclusion; + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Top-level expression. - exclusion.parse(BSON("a" << false << "b" << BSON("$literal" << 1))); + exclusion.parse(expCtx, BSON("a" << false << "b" << BSON("$literal" << 1))); } DEATH_TEST(ExclusionProjection, ShouldFailWhenGivenIncludedField, "Invariant failure !elem.trueValue()") { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << true)); } DEATH_TEST(ExclusionProjection, ShouldFailWhenGivenIncludedId, "Invariant failure !elem.trueValue()") { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << true << "a" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id" << true << "a" << false)); } TEST(ExclusionProjection, ShouldSerializeToEquivalentProjection) { ParsedExclusionProjection exclusion; + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); exclusion.parse( - fromjson("{a: 0, b: {c: NumberLong(0), d: 0.0}, 'x.y': false, _id: NumberInt(0)}")); + expCtx, fromjson("{a: 0, b: {c: NumberLong(0), d: 0.0}, 'x.y': false, _id: NumberInt(0)}")); // Converts numbers to bools, converts dotted paths to nested documents. Note order of excluded // fields is subject to change. @@ -107,7 +112,9 @@ TEST(ExclusionProjection, ShouldNotAddAnyDependencies) { // later. If there are no later stages, then we will finish the dependency computation // cycle without full knowledge of which fields are needed, and thus include all the fields. ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << false << "a" << false << "b.c" << false << "x.y.z" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, + BSON("_id" << false << "a" << false << "b.c" << false << "x.y.z" << false)); DepsTracker deps; exclusion.addDependencies(&deps); @@ -119,7 +126,8 @@ TEST(ExclusionProjection, ShouldNotAddAnyDependencies) { TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModified) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << false << "a" << false << "b.c" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id" << false << "a" << false << "b.c" << false)); auto modifiedPaths = exclusion.getModifiedPaths(); ASSERT(modifiedPaths.type == DocumentSource::GetModPathsReturn::Type::kFiniteSet); @@ -131,7 +139,8 @@ TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModified) { TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModifiedWhenSpecifiedAsNestedObj) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << BSON("b" << false << "c" << BSON("d" << false)))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << BSON("b" << false << "c" << BSON("d" << false)))); auto modifiedPaths = exclusion.getModifiedPaths(); ASSERT(modifiedPaths.type == DocumentSource::GetModPathsReturn::Type::kFiniteSet); @@ -146,7 +155,8 @@ TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModifiedWhenSpecifiedAsNes TEST(ExclusionProjectionExecutionTest, ShouldExcludeTopLevelField) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false)); // More than one field in document. auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}}); @@ -171,7 +181,9 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeTopLevelField) { TEST(ExclusionProjectionExecutionTest, ShouldCoerceNumericsToBools) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << Value(0) << "b" << Value(0LL) << "c" << Value(0.0) << "d" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, + BSON("a" << Value(0) << "b" << Value(0LL) << "c" << Value(0.0) << "d" << Value(Decimal128(0)))); auto result = @@ -182,7 +194,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldCoerceNumericsToBools) { TEST(ExclusionProjectionExecutionTest, ShouldPreserveOrderOfExistingFields) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("second" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("second" << false)); auto result = exclusion.applyProjection(Document{{"first", 0}, {"second", 1}, {"third", 2}}); auto expectedResult = Document{{"first", 0}, {"third", 2}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -190,7 +203,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldPreserveOrderOfExistingFields) { TEST(ExclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"b", 2}, {"_id", "ID"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -198,7 +212,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false << "_id" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false << "_id" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"b", 2}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -206,7 +221,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdAndKeepAllOtherFields) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"a", 1}, {"b", 2}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -218,7 +234,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdAndKeepAllOtherFields) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeSubFieldsOfId) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id.x" << false << "_id" << BSON("y" << false))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id.x" << false << "_id" << BSON("y" << false))); auto result = exclusion.applyProjection( Document{{"_id", Document{{"x", 1}, {"y", 2}, {"z", 3}}}, {"a", 1}}); auto expectedResult = Document{{"_id", Document{{"z", 3}}}, {"a", 1}}; @@ -227,7 +244,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeSubFieldsOfId) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeSimpleDottedFieldFromSubDoc) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a.b" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a.b" << false)); // More than one field in sub document. auto result = exclusion.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); @@ -252,7 +270,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeSimpleDottedFieldFromSubDoc) TEST(ExclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedExcludedFieldDoesNotExist) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("sub.target" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("sub.target" << false)); // Should not add the path if it doesn't exist. auto result = exclusion.applyProjection(Document{}); @@ -267,7 +286,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedExcludedFiel TEST(ExclusionProjectionExecutionTest, ShouldApplyDottedExclusionToEachElementInArray) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a.b" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a.b" << false)); std::vector<Value> nestedValues = { Value(1), @@ -290,8 +310,10 @@ TEST(ExclusionProjectionExecutionTest, ShouldApplyDottedExclusionToEachElementIn TEST(ExclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { ParsedExclusionProjection exclusion; + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Exclude all of "a.b", "a.c", "a.d", and "a.e". exclusion.parse( + expCtx, BSON("a.b" << false << "a.c" << false << "a" << BSON("d" << false << "e" << false))); auto result = exclusion.applyProjection( Document{{"a", Document{{"b", 1}, {"c", 2}, {"d", 3}, {"e", 4}, {"f", 5}}}}); @@ -301,7 +323,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { TEST(ExclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false)); MutableDocument inputDocBuilder(Document{{"_id", "ID"_sd}, {"a", 1}}); inputDocBuilder.setRandMetaField(1.0); diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection.cpp b/src/mongo/db/pipeline/parsed_inclusion_projection.cpp index abac372be7e..511d1b37fff 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection.cpp +++ b/src/mongo/db/pipeline/parsed_inclusion_projection.cpp @@ -54,16 +54,6 @@ void InclusionNode::optimize() { } } -void InclusionNode::injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { - for (auto&& expressionIt : _expressions) { - expressionIt.second->injectExpressionContext(expCtx); - } - - for (auto&& childPair : _children) { - childPair.second->injectExpressionContext(expCtx); - } -} - void InclusionNode::serialize(MutableDocument* output, bool explain) const { // Always put "_id" first if it was included (implicitly or explicitly). if (_inclusions.find("_id") != _inclusions.end()) { @@ -252,7 +242,8 @@ void InclusionNode::addPreservedPaths(std::set<std::string>* preservedPaths) con // ParsedInclusionProjection // -void ParsedInclusionProjection::parse(const BSONObj& spec, +void ParsedInclusionProjection::parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, const VariablesParseState& variablesParseState) { // It is illegal to specify a projection with no output fields. bool atLeastOneFieldInOutput = false; @@ -289,7 +280,7 @@ void ParsedInclusionProjection::parse(const BSONObj& spec, } case BSONType::Object: { // This is either an expression, or a nested specification. - if (parseObjectAsExpression(fieldName, elem.Obj(), variablesParseState)) { + if (parseObjectAsExpression(expCtx, fieldName, elem.Obj(), variablesParseState)) { // It was an expression. break; } @@ -306,13 +297,14 @@ void ParsedInclusionProjection::parse(const BSONObj& spec, // iteration too soon. Add the last path here. child = child->addOrGetChild(remainingPath.fullPath()); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); break; } default: { // This is a literal value. - _root->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + _root->addComputedField( + FieldPath(elem.fieldName()), + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } @@ -343,6 +335,7 @@ Document ParsedInclusionProjection::applyProjection(Document inputDoc, Variables } bool ParsedInclusionProjection::parseObjectAsExpression( + const boost::intrusive_ptr<ExpressionContext>& expCtx, StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState) { @@ -351,15 +344,17 @@ bool ParsedInclusionProjection::parseObjectAsExpression( // field. invariant(objSpec.nFields() == 1); _root->addComputedField(pathToObject, - Expression::parseExpression(objSpec, variablesParseState)); + Expression::parseExpression(expCtx, objSpec, variablesParseState)); return true; } return false; } -void ParsedInclusionProjection::parseSubObject(const BSONObj& subObj, - const VariablesParseState& variablesParseState, - InclusionNode* node) { +void ParsedInclusionProjection::parseSubObject( + const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& subObj, + const VariablesParseState& variablesParseState, + InclusionNode* node) { for (auto elem : subObj) { invariant(elem.fieldName()[0] != '$'); // Dotted paths in a sub-object have already been disallowed in @@ -381,19 +376,20 @@ void ParsedInclusionProjection::parseSubObject(const BSONObj& subObj, // This is either an expression, or a nested specification. auto fieldName = elem.fieldNameStringData().toString(); if (parseObjectAsExpression( + expCtx, FieldPath::getFullyQualifiedPath(node->getPath(), fieldName), elem.Obj(), variablesParseState)) { break; } auto child = node->addOrGetChild(fieldName); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); break; } default: { // This is a literal value. node->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection.h b/src/mongo/db/pipeline/parsed_inclusion_projection.h index 8cd04514b53..2f812af300d 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection.h +++ b/src/mongo/db/pipeline/parsed_inclusion_projection.h @@ -119,8 +119,6 @@ public: return _pathToNode; } - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx); - /** * Recursively add all paths that are preserved by this inclusion projection. */ @@ -184,10 +182,10 @@ public: /** * Parses the projection specification given by 'spec', populating internal data structures. */ - void parse(const BSONObj& spec) final { + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, const BSONObj& spec) final { VariablesIdGenerator idGenerator; VariablesParseState variablesParseState(&idGenerator); - parse(spec, variablesParseState); + parse(expCtx, spec, variablesParseState); _variables = stdx::make_unique<Variables>(idGenerator.getIdCount()); } @@ -210,10 +208,6 @@ public: _root->optimize(); } - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) final { - _root->injectExpressionContext(expCtx); - } - DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const final { _root->addDependencies(deps); return DocumentSource::EXHAUSTIVE_FIELDS; @@ -246,7 +240,9 @@ private: * Parses 'spec' to determine which fields to include, which are computed, and whether to * include '_id' or not. */ - void parse(const BSONObj& spec, const VariablesParseState& variablesParseState); + void parse(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& spec, + const VariablesParseState& variablesParseState); /** * Attempts to parse 'objSpec' as an expression like {$add: [...]}. Adds a computed field to @@ -256,7 +252,8 @@ private: * Throws an error if it was determined to be an expression specification, but failed to parse * as a valid expression. */ - bool parseObjectAsExpression(StringData pathToObject, + bool parseObjectAsExpression(const boost::intrusive_ptr<ExpressionContext>& expCtx, + StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState); @@ -264,7 +261,8 @@ private: * Traverses 'subObj' and parses each field. Adds any included or computed fields at this level * to 'node'. */ - void parseSubObject(const BSONObj& subObj, + void parseSubObject(const boost::intrusive_ptr<ExpressionContext>& expCtx, + const BSONObj& subObj, const VariablesParseState& variablesParseState, InclusionNode* node); diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp b/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp index 611f5367e9a..c2610418a40 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp @@ -38,6 +38,7 @@ #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" @@ -53,19 +54,23 @@ BSONObj wrapInLiteral(const T& arg) { TEST(InclusionProjection, ShouldThrowWhenParsingInvalidExpression) { ParsedInclusionProjection inclusion; - ASSERT_THROWS(inclusion.parse(BSON("a" << BSON("$gt" << BSON("bad" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(inclusion.parse(expCtx, + BSON("a" << BSON("$gt" << BSON("bad" << "arguments")))), UserException); } TEST(InclusionProjection, ShouldRejectProjectionWithNoOutputFields) { ParsedInclusionProjection inclusion; - ASSERT_THROWS(inclusion.parse(BSON("_id" << false)), UserException); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(inclusion.parse(expCtx, BSON("_id" << false)), UserException); } TEST(InclusionProjection, ShouldAddIncludedFieldsToDependencies) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << false << "a" << true << "x.y" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id" << false << "a" << true << "x.y" << true)); DepsTracker deps; inclusion.addDependencies(&deps); @@ -78,7 +83,8 @@ TEST(InclusionProjection, ShouldAddIncludedFieldsToDependencies) { TEST(InclusionProjection, ShouldAddIdToDependenciesIfNotSpecified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); DepsTracker deps; inclusion.addDependencies(&deps); @@ -90,7 +96,9 @@ TEST(InclusionProjection, ShouldAddIdToDependenciesIfNotSpecified) { TEST(InclusionProjection, ShouldAddDependenciesOfComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("a" << "$a" << "x" << "$z")); @@ -106,7 +114,9 @@ TEST(InclusionProjection, ShouldAddDependenciesOfComputedFields) { TEST(InclusionProjection, ShouldAddPathToDependenciesForNestedComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("x.y" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("x.y" << "$z")); DepsTracker deps; @@ -123,7 +133,8 @@ TEST(InclusionProjection, ShouldAddPathToDependenciesForNestedComputedFields) { TEST(InclusionProjection, ShouldSerializeToEquivalentProjection) { ParsedInclusionProjection inclusion; - inclusion.parse(fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); // Adds implicit "_id" inclusion, converts numbers to bools, serializes expressions. auto expectedSerialization = Document(fromjson( @@ -136,7 +147,8 @@ TEST(InclusionProjection, ShouldSerializeToEquivalentProjection) { TEST(InclusionProjection, ShouldSerializeExplicitExclusionOfId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << false << "a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id" << false << "a" << true)); // Adds implicit "_id" inclusion, converts numbers to bools, serializes expressions. auto expectedSerialization = Document{{"_id", false}, {"a", true}}; @@ -149,7 +161,8 @@ TEST(InclusionProjection, ShouldSerializeExplicitExclusionOfId) { TEST(InclusionProjection, ShouldOptimizeTopLevelExpressions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); inclusion.optimize(); @@ -162,7 +175,8 @@ TEST(InclusionProjection, ShouldOptimizeTopLevelExpressions) { TEST(InclusionProjection, ShouldOptimizeNestedExpressions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); inclusion.optimize(); @@ -176,10 +190,13 @@ TEST(InclusionProjection, ShouldOptimizeNestedExpressions) { TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON( - "a" << wrapInLiteral("computedVal") << "b.c" << wrapInLiteral("computedVal") << "d" << true - << "e.f" - << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse( + expCtx, + BSON("a" << wrapInLiteral("computedVal") << "b.c" << wrapInLiteral("computedVal") << "d" + << true + << "e.f" + << true)); auto modifiedPaths = inclusion.getModifiedPaths(); ASSERT(modifiedPaths.type == DocumentSource::GetModPathsReturn::Type::kAllExcept); @@ -195,7 +212,9 @@ TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModified) { TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModifiedWithIdExclusion) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << false << "a" << wrapInLiteral("computedVal") << "b.c" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("_id" << false << "a" << wrapInLiteral("computedVal") << "b.c" << wrapInLiteral("computedVal") << "d" << true @@ -222,7 +241,8 @@ TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModifiedWith TEST(InclusionProjectionExecutionTest, ShouldIncludeTopLevelField) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); // More than one field in document. auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}}); @@ -247,7 +267,8 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeTopLevelField) { TEST(InclusionProjectionExecutionTest, ShouldAddComputedTopLevelField) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("newField" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{{"newField", "computedVal"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -260,7 +281,8 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedTopLevelField) { TEST(InclusionProjectionExecutionTest, ShouldApplyBothInclusionsAndComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true << "newField" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true << "newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{{"a", 1}}); auto expectedResult = Document{{"a", 1}, {"newField", "computedVal"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -268,7 +290,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyBothInclusionsAndComputedField TEST(InclusionProjectionExecutionTest, ShouldIncludeFieldsInOrderOfInputDoc) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("first" << true << "second" << true << "third" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("first" << true << "second" << true << "third" << true)); auto inputDoc = Document{{"second", 1}, {"first", 0}, {"third", 2}}; auto result = inclusion.applyProjection(inputDoc); ASSERT_DOCUMENT_EQ(result, inputDoc); @@ -276,7 +299,9 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeFieldsInOrderOfInputDoc) { TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsInOrderSpecified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("firstComputed" << wrapInLiteral("FIRST") << "secondComputed" + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("firstComputed" << wrapInLiteral("FIRST") << "secondComputed" << wrapInLiteral("SECOND"))); auto result = inclusion.applyProjection(Document{{"first", 0}, {"second", 1}, {"third", 2}}); auto expectedResult = Document{{"firstComputed", "FIRST"_sd}, {"secondComputed", "SECOND"_sd}}; @@ -285,7 +310,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsInOrderSpecified TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); auto result = inclusion.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}, {"b", 2}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"a", 1}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -298,7 +324,8 @@ TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeIdWithComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("newField" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"newField", "computedVal"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -306,7 +333,8 @@ TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeIdWithComputedFiel TEST(InclusionProjectionExecutionTest, ShouldIncludeIdIfExplicitlyIncluded) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true << "_id" << true << "b" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true << "_id" << true << "b" << true)); auto result = inclusion.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}, {"b", 2}, {"c", 3}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"a", 1}, {"b", 2}}; @@ -315,7 +343,8 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeIdIfExplicitlyIncluded) { TEST(InclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true << "_id" << false)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true << "_id" << false)); auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"a", 1}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -323,7 +352,8 @@ TEST(InclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { TEST(InclusionProjectionExecutionTest, ShouldReplaceIdWithComputedId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << wrapInLiteral("newId"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id" << wrapInLiteral("newId"))); auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"_id", "newId"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -335,7 +365,8 @@ TEST(InclusionProjectionExecutionTest, ShouldReplaceIdWithComputedId) { TEST(InclusionProjectionExecutionTest, ShouldIncludeSimpleDottedFieldFromSubDoc) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << true)); // More than one field in sub document. auto result = inclusion.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); @@ -360,7 +391,8 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeSimpleDottedFieldFromSubDoc) TEST(InclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedIncludedFieldDoesNotExist) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("sub.target" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("sub.target" << true)); // Should not add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); @@ -375,7 +407,8 @@ TEST(InclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedIncludedFiel TEST(InclusionProjectionExecutionTest, ShouldApplyDottedInclusionToEachElementInArray) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << true)); vector<Value> nestedValues = {Value(1), Value(Document{}), @@ -399,7 +432,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyDottedInclusionToEachElementIn TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToSubDocument) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("sub.target" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("sub.target" << wrapInLiteral("computedVal"))); // Other fields exist in sub document, one of which is the specified field. auto result = inclusion.applyProjection(Document{{"sub", Document{{"target", 1}, {"c", 2}}}}); @@ -419,7 +453,8 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToSubDocument TEST(InclusionProjectionExecutionTest, ShouldCreateSubDocIfDottedComputedFieldDoesntExist) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("sub.target" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("sub.target" << wrapInLiteral("computedVal"))); // Should add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); @@ -433,7 +468,8 @@ TEST(InclusionProjectionExecutionTest, ShouldCreateSubDocIfDottedComputedFieldDo TEST(InclusionProjectionExecutionTest, ShouldCreateNestedSubDocumentsAllTheWayToComputedField) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b.c.d" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b.c.d" << wrapInLiteral("computedVal"))); // Should add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); @@ -448,7 +484,8 @@ TEST(InclusionProjectionExecutionTest, ShouldCreateNestedSubDocumentsAllTheWayTo TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToEachElementInArray) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << wrapInLiteral("COMPUTED"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << wrapInLiteral("COMPUTED"))); vector<Value> nestedValues = {Value(1), Value(Document{}), @@ -471,7 +508,8 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToEachElement TEST(InclusionProjectionExecutionTest, ShouldApplyInclusionsAndAdditionsToEachElementInArray) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.inc" << true << "a.comp" << wrapInLiteral("COMPUTED"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.inc" << true << "a.comp" << wrapInLiteral("COMPUTED"))); vector<Value> nestedValues = {Value(1), Value(Document{}), @@ -498,7 +536,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyInclusionsAndAdditionsToEachEl TEST(InclusionProjectionExecutionTest, ShouldAddOrIncludeSubFieldsOfId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id.X" << true << "_id.Z" << wrapInLiteral("NEW"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id.X" << true << "_id.Z" << wrapInLiteral("NEW"))); auto result = inclusion.applyProjection(Document{{"_id", Document{{"X", 1}, {"Y", 2}}}}); auto expectedResult = Document{{"_id", Document{{"X", 1}, {"Z", "NEW"_sd}}}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -506,13 +545,16 @@ TEST(InclusionProjectionExecutionTest, ShouldAddOrIncludeSubFieldsOfId) { TEST(InclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { ParsedInclusionProjection inclusion; + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); // Include all of "a.b", "a.c", "a.d", and "a.e". // Add new computed fields "a.W", "a.X", "a.Y", and "a.Z". - inclusion.parse(BSON( - "a.b" << true << "a.c" << true << "a.W" << wrapInLiteral("W") << "a.X" << wrapInLiteral("X") - << "a" - << BSON("d" << true << "e" << true << "Y" << wrapInLiteral("Y") << "Z" - << wrapInLiteral("Z")))); + inclusion.parse( + expCtx, + BSON("a.b" << true << "a.c" << true << "a.W" << wrapInLiteral("W") << "a.X" + << wrapInLiteral("X") + << "a" + << BSON("d" << true << "e" << true << "Y" << wrapInLiteral("Y") << "Z" + << wrapInLiteral("Z")))); auto result = inclusion.applyProjection(Document{ {"a", Document{{"b", "b"_sd}, {"c", "c"_sd}, {"d", "d"_sd}, {"e", "e"_sd}, {"f", "f"_sd}}}}); @@ -530,7 +572,9 @@ TEST(InclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { TEST(InclusionProjectionExecutionTest, ShouldApplyNestedComputedFieldsInOrderSpecified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << wrapInLiteral("FIRST") << "b.c" << wrapInLiteral("SECOND"))); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("a" << wrapInLiteral("FIRST") << "b.c" << wrapInLiteral("SECOND"))); auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{{"a", "FIRST"_sd}, {"b", Document{{"c", "SECOND"_sd}}}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -538,7 +582,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyNestedComputedFieldsInOrderSpe TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("b.c" << wrapInLiteral("NEW") << "a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("b.c" << wrapInLiteral("NEW") << "a" << true)); auto result = inclusion.applyProjection(Document{{"a", 1}}); auto expectedResult = Document{{"a", 1}, {"b", Document{{"c", "NEW"_sd}}}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -557,7 +602,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusio TEST(InclusionProjectionExecutionTest, ComputedFieldReplacingExistingShouldAppearAfterInclusions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("b" << wrapInLiteral("NEW") << "a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("b" << wrapInLiteral("NEW") << "a" << true)); auto result = inclusion.applyProjection(Document{{"b", 1}, {"a", 1}}); auto expectedResult = Document{{"a", 1}, {"b", "NEW"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -572,7 +618,8 @@ TEST(InclusionProjectionExecutionTest, ComputedFieldReplacingExistingShouldAppea TEST(InclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); MutableDocument inputDocBuilder(Document{{"a", 1}}); inputDocBuilder.setRandMetaField(1.0); diff --git a/src/mongo/db/pipeline/pipeline.cpp b/src/mongo/db/pipeline/pipeline.cpp index b296d3f85df..8bb745fba13 100644 --- a/src/mongo/db/pipeline/pipeline.cpp +++ b/src/mongo/db/pipeline/pipeline.cpp @@ -165,13 +165,6 @@ void Pipeline::reattachToOperationContext(OperationContext* opCtx) { } } -void Pipeline::injectExpressionContext(const intrusive_ptr<ExpressionContext>& expCtx) { - pCtx = expCtx; - for (auto&& stage : _sources) { - stage->injectExpressionContext(pCtx); - } -} - intrusive_ptr<Pipeline> Pipeline::splitForSharded() { // Create and initialize the shard spec we'll return. We start with an empty pipeline on the // shards and all work being done in the merger. Optimizations can move operations between diff --git a/src/mongo/db/pipeline/pipeline.h b/src/mongo/db/pipeline/pipeline.h index aff72fa40bc..fe4ca1de424 100644 --- a/src/mongo/db/pipeline/pipeline.h +++ b/src/mongo/db/pipeline/pipeline.h @@ -45,7 +45,7 @@ class BSONObjBuilder; class CollatorInterface; class DocumentSource; class OperationContext; -struct ExpressionContext; +class ExpressionContext; /** * A Pipeline object represents a list of DocumentSources and is responsible for optimizing the @@ -129,12 +129,6 @@ public: void optimizePipeline(); /** - * Propagates a reference to the ExpressionContext to all of the pipeline's contained stages and - * expressions. - */ - void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx); - - /** * Returns any other collections involved in the pipeline in addition to the collection the * aggregation is run on. */ diff --git a/src/mongo/db/pipeline/pipeline_d.cpp b/src/mongo/db/pipeline/pipeline_d.cpp index bb681edd3f1..b2a7bee2fa8 100644 --- a/src/mongo/db/pipeline/pipeline_d.cpp +++ b/src/mongo/db/pipeline/pipeline_d.cpp @@ -188,7 +188,6 @@ public: return pipeline.getStatus(); } - pipeline.getValue()->injectExpressionContext(expCtx); pipeline.getValue()->optimizePipeline(); AutoGetCollectionForRead autoColl(expCtx->opCtx, expCtx->ns); @@ -294,7 +293,8 @@ StatusWith<std::unique_ptr<PlanExecutor>> attemptToGetExecutor( // that the user omitted. // // If pipeline has a null collator (representing the "simple" collation), we simply set the - // collation option to the original user BSON. + // collation option to the original user BSON, which is either the empty object (unspecified), + // or the specification for the "simple" collation. qr->setCollation(pExpCtx->getCollator() ? pExpCtx->getCollator()->getSpec().toBSON() : pExpCtx->collation); diff --git a/src/mongo/db/pipeline/pipeline_d.h b/src/mongo/db/pipeline/pipeline_d.h index 35d93a4cfd2..5be4893cf3b 100644 --- a/src/mongo/db/pipeline/pipeline_d.h +++ b/src/mongo/db/pipeline/pipeline_d.h @@ -38,7 +38,7 @@ namespace mongo { class Collection; class DocumentSourceCursor; class DocumentSourceSort; -struct ExpressionContext; +class ExpressionContext; class OperationContext; class Pipeline; class PlanExecutor; diff --git a/src/mongo/db/pipeline/pipeline_test.cpp b/src/mongo/db/pipeline/pipeline_test.cpp index 41cd6e567d3..dba575017ce 100644 --- a/src/mongo/db/pipeline/pipeline_test.cpp +++ b/src/mongo/db/pipeline/pipeline_test.cpp @@ -39,7 +39,7 @@ #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document_value_test_util.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/field_path.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/query/collation/collator_interface_mock.h" @@ -86,7 +86,14 @@ public: rawPipeline.push_back(stageElem.embeddedObject()); } AggregationRequest request(NamespaceString("a.collection"), rawPipeline); - intrusive_ptr<ExpressionContext> ctx = new ExpressionContext(&_opCtx, request); + intrusive_ptr<ExpressionContextForTest> ctx = + new ExpressionContextForTest(&_opCtx, request); + + // For $graphLookup and $lookup, we have to populate the resolvedNamespaces so that the + // operations will be able to have a resolved view definition. + NamespaceString lookupCollNs("a", "lookupColl"); + ctx->setResolvedNamespace(lookupCollNs, {lookupCollNs, std::vector<BSONObj>{}}); + auto outputPipe = uassertStatusOK(Pipeline::parse(request.getPipeline(), ctx)); outputPipe->optimizePipeline(); @@ -240,17 +247,17 @@ class MoveMatchBeforeSort : public Base { class LookupShouldCoalesceWithUnwindOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same'}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same'}}" "]"; @@ -259,17 +266,17 @@ class LookupShouldCoalesceWithUnwindOnAs : public Base { class LookupShouldCoalesceWithUnwindOnAsWithPreserveEmpty : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', preserveNullAndEmptyArrays: true}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: true}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', preserveNullAndEmptyArrays: true}}" "]"; @@ -278,18 +285,18 @@ class LookupShouldCoalesceWithUnwindOnAsWithPreserveEmpty : public Base { class LookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', includeArrayIndex: 'index'}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false, includeArrayIndex: " "'index'}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', includeArrayIndex: 'index'}}" "]"; @@ -298,13 +305,13 @@ class LookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base { class LookupShouldNotCoalesceWithUnwindNotOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; @@ -313,51 +320,59 @@ class LookupShouldNotCoalesceWithUnwindNotOnAs : public Base { class LookupShouldSwapWithMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'independent': 0}}]"; } string outputPipeJson() { return "[{$match: {independent: 0}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}]"; + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}]"; } }; class LookupShouldSplitMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'independent': 0, asField: {$eq: 3}}}]"; } string outputPipeJson() { return "[{$match: {independent: {$eq: 0}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {asField: {$eq: 3}}}]"; } }; class LookupShouldNotAbsorbMatchOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'asField.subfield': 0}}]"; } string outputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'asField.subfield': 0}}]"; } }; class LookupShouldAbsorbUnwindMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " "{$unwind: '$asField'}, " "{$match: {'asField.subfield': {$eq: 1}}}]"; } string outputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z', " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: 'z', " " unwinding: {preserveNullAndEmptyArrays: false}, " " matching: {subfield: {$eq: 1}}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " "{$unwind: {path: '$asField'}}, " "{$match: {'asField.subfield': {$eq: 1}}}]"; } @@ -365,14 +380,15 @@ class LookupShouldAbsorbUnwindMatch : public Base { class LookupShouldAbsorbUnwindAndSplitAndAbsorbMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: '$asField'}, " " {$match: {'asField.subfield': {$eq: 1}, independentField: {$gt: 2}}}]"; } string outputPipeJson() { return "[{$match: {independentField: {$gt: 2}}}, " " {$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'asField', " " localField: 'y', " " foreignField: 'z', " @@ -386,7 +402,8 @@ class LookupShouldAbsorbUnwindAndSplitAndAbsorbMatch : public Base { } string serializedPipeJson() { return "[{$match: {independentField: {$gt: 2}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField'}}, " " {$match: {'asField.subfield': {$eq: 1}}}]"; } @@ -397,19 +414,21 @@ class LookupShouldNotSplitIndependentAndDependentOrClauses : public Base { // the $lookup, and if any child of the $or is independent of the 'asField', then the $match // cannot be absorbed by the $lookup. string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: '$asField'}, " " {$match: {$or: [{'independent': {$gt: 4}}, " " {'asField.dependent': {$elemMatch: {a: {$eq: 1}}}}]}}]"; } string outputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z', " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: 'z', " " unwinding: {preserveNullAndEmptyArrays: false}}}, " " {$match: {$or: [{'independent': {$gt: 4}}, " " {'asField.dependent': {$elemMatch: {a: {$eq: 1}}}}]}}]"; } string serializedPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField'}}, " " {$match: {$or: [{'independent': {$gt: 4}}, " " {'asField.dependent': {$elemMatch: {a: {$eq: 1}}}}]}}]"; @@ -418,14 +437,15 @@ class LookupShouldNotSplitIndependentAndDependentOrClauses : public Base { class LookupWithMatchOnArrayIndexFieldShouldNotCoalesce : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', includeArrayIndex: 'index'}}, " " {$match: {index: 0, 'asField.value': {$gt: 0}, independent: 1}}]"; } string outputPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " " {$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'asField', " " localField: 'y', " " foreignField: 'z', " @@ -438,7 +458,8 @@ class LookupWithMatchOnArrayIndexFieldShouldNotCoalesce : public Base { } string serializedPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', includeArrayIndex: 'index'}}, " " {$match: {$and: [{index: {$eq: 0}}, {'asField.value': {$gt: 0}}]}}]"; } @@ -446,14 +467,15 @@ class LookupWithMatchOnArrayIndexFieldShouldNotCoalesce : public Base { class LookupWithUnwindPreservingNullAndEmptyArraysShouldNotCoalesce : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', preserveNullAndEmptyArrays: true}}, " " {$match: {'asField.value': {$gt: 0}, independent: 1}}]"; } string outputPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " " {$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'asField', " " localField: 'y', " " foreignField: 'z', " @@ -465,7 +487,8 @@ class LookupWithUnwindPreservingNullAndEmptyArraysShouldNotCoalesce : public Bas } string serializedPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', preserveNullAndEmptyArrays: true}}, " " {$match: {'asField.value': {$gt: 0}}}]"; } @@ -473,13 +496,13 @@ class LookupWithUnwindPreservingNullAndEmptyArraysShouldNotCoalesce : public Bas class LookupDoesNotAbsorbElemMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$unwind: '$x'}, " " {$match: {x: {$elemMatch: {a: 1}}}}]"; } string outputPipeJson() { return "[{$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'x', " " localField: 'y', " " foreignField: 'z', " @@ -491,7 +514,7 @@ class LookupDoesNotAbsorbElemMatch : public Base { " {$match: {x: {$elemMatch: {a: 1}}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$unwind: {path: '$x'}}, " " {$match: {x: {$elemMatch: {a: 1}}}}]"; } @@ -499,35 +522,35 @@ class LookupDoesNotAbsorbElemMatch : public Base { class LookupDoesSwapWithMatchOnLocalField : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$match: {y: {$eq: 3}}}]"; } string outputPipeJson() { return "[{$match: {y: {$eq: 3}}}, " - " {$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}]"; + " {$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}]"; } }; class LookupDoesSwapWithMatchOnFieldWithSameNameAsForeignField : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$match: {z: {$eq: 3}}}]"; } string outputPipeJson() { return "[{$match: {z: {$eq: 3}}}, " - " {$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}]"; + " {$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}]"; } }; class LookupDoesNotAbsorbUnwindOnSubfieldOfAsButStillMovesMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$unwind: {path: '$x.subfield'}}, " " {$match: {'independent': 2, 'x.dependent': 2}}]"; } string outputPipeJson() { return "[{$match: {'independent': {$eq: 2}}}, " - " {$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$match: {'x.dependent': {$eq: 2}}}, " " {$unwind: {path: '$x.subfield'}}]"; } @@ -639,58 +662,61 @@ class UnwindBeforeDoubleMatchShouldRepeatedlyOptimize : public Base { class GraphLookupShouldCoalesceWithUnwindOnAs : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: '$out'}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d', unwinding: {preserveNullAndEmptyArrays: " - "false}}}]"; + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d', " + " unwinding: {preserveNullAndEmptyArrays: false}}}]"; } string serializedPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out'}}]"; } }; class GraphLookupShouldCoalesceWithUnwindOnAsWithPreserveEmpty : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out', preserveNullAndEmptyArrays: true}}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d', unwinding: {preserveNullAndEmptyArrays: true}}}]"; + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d', " + " unwinding: {preserveNullAndEmptyArrays: true}}}]"; } string serializedPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out', preserveNullAndEmptyArrays: true}}]"; } }; class GraphLookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out', includeArrayIndex: 'index'}}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d', unwinding: {preserveNullAndEmptyArrays: false, " - " includeArrayIndex: 'index'}}}]"; + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d', " + " unwinding: {preserveNullAndEmptyArrays: false, " + " includeArrayIndex: 'index'}}}]"; } string serializedPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', " " startWith: '$d'}}, " " {$unwind: {path: '$out', includeArrayIndex: 'index'}}]"; } @@ -698,14 +724,14 @@ class GraphLookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base class GraphLookupShouldNotCoalesceWithUnwindNotOnAs : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: '$nottherightthing'}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$nottherightthing'}}]"; } }; @@ -713,7 +739,7 @@ class GraphLookupShouldNotCoalesceWithUnwindNotOnAs : public Base { class GraphLookupShouldSwapWithMatch : public Base { string inputPipeJson() { return "[{$graphLookup: {" - " from: 'coll2'," + " from: 'lookupColl'," " as: 'results'," " connectToField: 'to'," " connectFromField: 'from'," @@ -725,7 +751,7 @@ class GraphLookupShouldSwapWithMatch : public Base { string outputPipeJson() { return "[{$match: {independent: 'x'}}," " {$graphLookup: {" - " from: 'coll2'," + " from: 'lookupColl'," " as: 'results'," " connectToField: 'to'," " connectFromField: 'from'," @@ -863,7 +889,14 @@ public: rawPipeline.push_back(stageElem.embeddedObject()); } AggregationRequest request(NamespaceString("a.collection"), rawPipeline); - intrusive_ptr<ExpressionContext> ctx = new ExpressionContext(&_opCtx, request); + intrusive_ptr<ExpressionContextForTest> ctx = + new ExpressionContextForTest(&_opCtx, request); + + // For $graphLookup and $lookup, we have to populate the resolvedNamespaces so that the + // operations will be able to have a resolved view definition. + NamespaceString lookupCollNs("a", "lookupColl"); + ctx->setResolvedNamespace(lookupCollNs, {lookupCollNs, std::vector<BSONObj>{}}); + mergePipe = uassertStatusOK(Pipeline::parse(request.getPipeline(), ctx)); mergePipe->optimizePipeline(); @@ -1057,7 +1090,7 @@ namespace coalesceLookUpAndUnwind { class ShouldCoalesceUnwindOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same'}}" "]"; @@ -1066,14 +1099,14 @@ class ShouldCoalesceUnwindOnAs : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false}}}]"; } }; class ShouldCoalesceUnwindOnAsWithPreserveEmpty : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', preserveNullAndEmptyArrays: true}}" "]"; @@ -1082,14 +1115,14 @@ class ShouldCoalesceUnwindOnAsWithPreserveEmpty : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: true}}}]"; } }; class ShouldCoalesceUnwindOnAsWithIncludeArrayIndex : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', includeArrayIndex: 'index'}}" "]"; @@ -1098,7 +1131,7 @@ class ShouldCoalesceUnwindOnAsWithIncludeArrayIndex : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false, includeArrayIndex: " "'index'}}}]"; } @@ -1106,7 +1139,7 @@ class ShouldCoalesceUnwindOnAsWithIncludeArrayIndex : public Base { class ShouldNotCoalesceUnwindNotOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; @@ -1115,7 +1148,7 @@ class ShouldNotCoalesceUnwindNotOnAs : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; @@ -1170,14 +1203,14 @@ class LookUp : public needsPrimaryShardMergerBase { return true; } string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}]"; } string shardPipeJson() { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}]"; } }; @@ -1192,7 +1225,7 @@ TEST(PipelineInitialSource, GeoNearInitialQuery) { OperationContextNoop _opCtx; const std::vector<BSONObj> rawPipeline = { fromjson("{$geoNear: {distanceField: 'd', near: [0, 0], query: {a: 1}}}")}; - intrusive_ptr<ExpressionContext> ctx = new ExpressionContext( + intrusive_ptr<ExpressionContextForTest> ctx = new ExpressionContextForTest( &_opCtx, AggregationRequest(NamespaceString("a.collection"), rawPipeline)); auto pipe = uassertStatusOK(Pipeline::parse(rawPipeline, ctx)); ASSERT_BSONOBJ_EQ(pipe->getInitialQuery(), BSON("a" << 1)); @@ -1201,28 +1234,13 @@ TEST(PipelineInitialSource, GeoNearInitialQuery) { TEST(PipelineInitialSource, MatchInitialQuery) { OperationContextNoop _opCtx; const std::vector<BSONObj> rawPipeline = {fromjson("{$match: {'a': 4}}")}; - intrusive_ptr<ExpressionContext> ctx = new ExpressionContext( + intrusive_ptr<ExpressionContextForTest> ctx = new ExpressionContextForTest( &_opCtx, AggregationRequest(NamespaceString("a.collection"), rawPipeline)); auto pipe = uassertStatusOK(Pipeline::parse(rawPipeline, ctx)); ASSERT_BSONOBJ_EQ(pipe->getInitialQuery(), BSON("a" << 4)); } -TEST(PipelineInitialSource, ParseCollation) { - QueryTestServiceContext serviceContext; - auto opCtx = serviceContext.makeOperationContext(); - - const BSONObj inputBson = - fromjson("{pipeline: [{$match: {a: 'abc'}}], collation: {locale: 'reverse'}}"); - auto request = AggregationRequest::parseFromBSON(NamespaceString("a.collection"), inputBson); - ASSERT_OK(request.getStatus()); - - intrusive_ptr<ExpressionContext> ctx = new ExpressionContext(opCtx.get(), request.getValue()); - ASSERT(ctx->getCollator()); - CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); - ASSERT_TRUE(CollatorInterface::collatorsMatch(ctx->getCollator(), &collator)); -} - namespace Dependencies { using PipelineDependenciesTest = AggregationContextFixture; diff --git a/src/mongo/db/views/view_catalog.cpp b/src/mongo/db/views/view_catalog.cpp index 379f0364839..d8e0e842824 100644 --- a/src/mongo/db/views/view_catalog.cpp +++ b/src/mongo/db/views/view_catalog.cpp @@ -43,6 +43,7 @@ #include "mongo/db/pipeline/aggregation_request.h" #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/lite_parsed_pipeline.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/query/collation/collator_factory_interface.h" #include "mongo/db/server_options.h" @@ -156,10 +157,24 @@ Status ViewCatalog::_upsertIntoGraph(OperationContext* txn, const ViewDefinition // Performs the insert into the graph. auto doInsert = [this, &txn](const ViewDefinition& viewDef, bool needsValidation) -> Status { - // Parse the pipeline for this view to get the namespaces it references. + // Make a LiteParsedPipeline to determine the namespaces referenced by this pipeline. AggregationRequest request(viewDef.viewOn(), viewDef.pipeline()); - boost::intrusive_ptr<ExpressionContext> expCtx = new ExpressionContext(txn, request); - expCtx->setCollator(CollatorInterface::cloneCollator(viewDef.defaultCollator())); + const LiteParsedPipeline liteParsedPipeline(request); + const auto involvedNamespaces = liteParsedPipeline.getInvolvedNamespaces(); + + // Verify that this is a legitimate pipeline specification by making sure it parses + // correctly. In order to parse a pipeline we need to resolve any namespaces involved to a + // collection and a pipeline, but in this case we don't need this map to be accurate since + // we will not be evaluating the pipeline. + StringMap<ExpressionContext::ResolvedNamespace> resolvedNamespaces; + for (auto&& nss : liteParsedPipeline.getInvolvedNamespaces()) { + resolvedNamespaces[nss.coll()] = {nss, {}}; + } + boost::intrusive_ptr<ExpressionContext> expCtx = + new ExpressionContext(txn, + request, + CollatorInterface::cloneCollator(viewDef.defaultCollator()), + std::move(resolvedNamespaces)); auto pipelineStatus = Pipeline::parse(viewDef.pipeline(), expCtx); if (!pipelineStatus.isOK()) { uassert(40255, @@ -169,7 +184,7 @@ Status ViewCatalog::_upsertIntoGraph(OperationContext* txn, const ViewDefinition return pipelineStatus.getStatus(); } - std::vector<NamespaceString> refs = pipelineStatus.getValue()->getInvolvedCollections(); + std::vector<NamespaceString> refs(involvedNamespaces.begin(), involvedNamespaces.end()); refs.push_back(viewDef.viewOn()); int pipelineSize = 0; diff --git a/src/mongo/dbtests/documentsourcetests.cpp b/src/mongo/dbtests/documentsourcetests.cpp index 1c99ca82663..51d34334ecd 100644 --- a/src/mongo/dbtests/documentsourcetests.cpp +++ b/src/mongo/dbtests/documentsourcetests.cpp @@ -41,7 +41,7 @@ #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/document_source_cursor.h" #include "mongo/db/pipeline/document_value_test_util.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/query/get_executor.h" #include "mongo/db/query/plan_executor.h" @@ -86,7 +86,7 @@ using mongo::DocumentSourceCursor; class Base : public CollectionBase { public: - Base() : _ctx(new ExpressionContext(&_opCtx, AggregationRequest(nss, {}))) { + Base() : _ctx(new ExpressionContextForTest(&_opCtx, AggregationRequest(nss, {}))) { _ctx->tempDir = storageGlobalParams.dbpath + "/_tmp"; } @@ -113,7 +113,7 @@ protected: _source = DocumentSourceCursor::create(nss.ns(), std::move(exec), _ctx); } - intrusive_ptr<ExpressionContext> ctx() { + intrusive_ptr<ExpressionContextForTest> ctx() { return _ctx; } @@ -123,7 +123,7 @@ protected: private: // It is important that these are ordered to ensure correct destruction order. - intrusive_ptr<ExpressionContext> _ctx; + intrusive_ptr<ExpressionContextForTest> _ctx; intrusive_ptr<DocumentSourceCursor> _source; }; diff --git a/src/mongo/dbtests/query_plan_executor.cpp b/src/mongo/dbtests/query_plan_executor.cpp index 70b20566a7b..e3e19873228 100644 --- a/src/mongo/dbtests/query_plan_executor.cpp +++ b/src/mongo/dbtests/query_plan_executor.cpp @@ -45,7 +45,7 @@ #include "mongo/db/matcher/expression_parser.h" #include "mongo/db/matcher/extensions_callback_disallow_extensions.h" #include "mongo/db/pipeline/document_source_cursor.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/query/plan_executor.h" #include "mongo/db/query/query_solution.h" @@ -281,8 +281,8 @@ public: // Create the aggregation pipeline. std::vector<BSONObj> rawPipeline = {fromjson("{$match: {a: {$gte: 7, $lte: 10}}}")}; - boost::intrusive_ptr<ExpressionContext> expCtx = - new ExpressionContext(&_txn, AggregationRequest(nss, rawPipeline)); + boost::intrusive_ptr<ExpressionContextForTest> expCtx = + new ExpressionContextForTest(&_txn, AggregationRequest(nss, rawPipeline)); // Create an "inner" plan executor and register it with the cursor manager so that it can // get notified when the collection is dropped. diff --git a/src/mongo/s/commands/cluster_aggregate.cpp b/src/mongo/s/commands/cluster_aggregate.cpp index fd9e787579e..1b687de25dd 100644 --- a/src/mongo/s/commands/cluster_aggregate.cpp +++ b/src/mongo/s/commands/cluster_aggregate.cpp @@ -84,41 +84,45 @@ Status ClusterAggregate::runAggregate(OperationContext* txn, return request.getStatus(); } - boost::intrusive_ptr<ExpressionContext> mergeCtx = - new ExpressionContext(txn, request.getValue()); - mergeCtx->inRouter = true; - // explicitly *not* setting mergeCtx->tempDir - + // Determine the appropriate collation and 'resolve' involved namespaces to make the + // ExpressionContext. + + // We won't try to execute anything on a mongos, but we still have to populate this map so that + // any $lookups, etc. will be able to have a resolved view definition. It's okay that this is + // incorrect, we will repopulate the real resolved namespace map on the mongod. Note that we + // need to check if any involved collections are sharded before forwarding an aggregation + // command on an unsharded collection. + StringMap<ExpressionContext::ResolvedNamespace> resolvedNamespaces; LiteParsedPipeline liteParsedPipeline(request.getValue()); - for (auto&& ns : liteParsedPipeline.getInvolvedNamespaces()) { uassert(28769, str::stream() << ns.ns() << " cannot be sharded", !conf->isSharded(ns.ns())); - // We won't try to execute anything on a mongos, but we still have to populate this map - // so that any $lookups etc will be able to have a resolved view definition. It's okay - // that this is incorrect, we will repopulate the real resolved namespace map on the - // mongod. - mergeCtx->resolvedNamespaces[ns.coll()] = {ns, std::vector<BSONObj>{}}; + resolvedNamespaces[ns.coll()] = {ns, std::vector<BSONObj>{}}; } if (!conf->isSharded(namespaces.executionNss.ns())) { return aggPassthrough(txn, namespaces, conf, cmdObj, result, options); } - auto chunkMgr = conf->getChunkManager(txn, namespaces.executionNss.ns()); - // If there was no collation specified, but there is a default collation for the collation, - // use that. - if (request.getValue().getCollation().isEmpty() && chunkMgr->getDefaultCollator()) { - mergeCtx->setCollator(chunkMgr->getDefaultCollator()->clone()); + std::unique_ptr<CollatorInterface> collation; + if (!request.getValue().getCollation().isEmpty()) { + collation = uassertStatusOK(CollatorFactoryInterface::get(txn->getServiceContext()) + ->makeFromBSON(request.getValue().getCollation())); + } else if (chunkMgr->getDefaultCollator()) { + collation = chunkMgr->getDefaultCollator()->clone(); } + boost::intrusive_ptr<ExpressionContext> mergeCtx = new ExpressionContext( + txn, request.getValue(), std::move(collation), std::move(resolvedNamespaces)); + mergeCtx->inRouter = true; + // explicitly *not* setting mergeCtx->tempDir + // Parse and optimize the pipeline specification. auto pipeline = Pipeline::parse(request.getValue().getPipeline(), mergeCtx); if (!pipeline.isOK()) { return pipeline.getStatus(); } - pipeline.getValue()->injectExpressionContext(mergeCtx); pipeline.getValue()->optimizePipeline(); // If the first $match stage is an exact match on the shard key (with a simple collation or |