From 37e720678f6e468726c6cc775a5dc898d080f0f3 Mon Sep 17 00:00:00 2001 From: Charlie Swanson Date: Tue, 13 Dec 2016 10:15:08 -0500 Subject: SERVER-25535 Remove injectExpressionContext(). These methods were formally used to propagate a new ExpressionContext to stages, accumulators, or expressions which potentially needed to comparisons. Originally, this was necessary since Pipeline parsing happened outside of the collection lock and thus could not determine if there was a default collation on the collection. This meant that the collation could change after parsing and any operators that might compare strings would need to know about it. We have since moved parsing within the lock, so the collation can be known at parse time and the ExpressionContext should not change. This patch requires an ExpressionContext at construction time, and disallows changing the collation on an ExpressionContext. --- src/mongo/db/pipeline/accumulation_statement.cpp | 6 +- src/mongo/db/pipeline/accumulation_statement.h | 6 +- src/mongo/db/pipeline/accumulator.h | 86 ++--- src/mongo/db/pipeline/accumulator_add_to_set.cpp | 18 +- src/mongo/db/pipeline/accumulator_avg.cpp | 8 +- src/mongo/db/pipeline/accumulator_first.cpp | 8 +- src/mongo/db/pipeline/accumulator_last.cpp | 8 +- src/mongo/db/pipeline/accumulator_min_max.cpp | 14 +- src/mongo/db/pipeline/accumulator_push.cpp | 8 +- src/mongo/db/pipeline/accumulator_std_dev.cpp | 14 +- src/mongo/db/pipeline/accumulator_sum.cpp | 8 +- src/mongo/db/pipeline/accumulator_test.cpp | 35 +- .../db/pipeline/aggregation_context_fixture.h | 8 +- src/mongo/db/pipeline/document_source.cpp | 2 +- src/mongo/db/pipeline/document_source.h | 23 +- .../db/pipeline/document_source_add_fields.cpp | 3 +- src/mongo/db/pipeline/document_source_add_fields.h | 2 + src/mongo/db/pipeline/document_source_bucket.cpp | 12 +- src/mongo/db/pipeline/document_source_bucket.h | 1 + .../db/pipeline/document_source_bucket_auto.cpp | 26 +- .../db/pipeline/document_source_bucket_auto.h | 5 +- .../pipeline/document_source_bucket_auto_test.cpp | 8 +- .../db/pipeline/document_source_bucket_test.cpp | 3 - src/mongo/db/pipeline/document_source_count.h | 1 + src/mongo/db/pipeline/document_source_cursor.cpp | 7 - src/mongo/db/pipeline/document_source_cursor.h | 3 - src/mongo/db/pipeline/document_source_facet.cpp | 6 - src/mongo/db/pipeline/document_source_facet.h | 7 +- src/mongo/db/pipeline/document_source_geo_near.cpp | 1 - .../db/pipeline/document_source_graph_lookup.cpp | 58 ++- .../db/pipeline/document_source_graph_lookup.h | 7 +- .../pipeline/document_source_graph_lookup_test.cpp | 118 +++--- src/mongo/db/pipeline/document_source_group.cpp | 50 +-- src/mongo/db/pipeline/document_source_group.h | 3 - .../db/pipeline/document_source_group_test.cpp | 26 +- src/mongo/db/pipeline/document_source_limit.cpp | 1 - src/mongo/db/pipeline/document_source_lookup.cpp | 24 +- src/mongo/db/pipeline/document_source_lookup.h | 3 - .../db/pipeline/document_source_lookup_test.cpp | 24 +- src/mongo/db/pipeline/document_source_match.cpp | 7 +- src/mongo/db/pipeline/document_source_match.h | 4 +- .../db/pipeline/document_source_merge_cursors.cpp | 1 - src/mongo/db/pipeline/document_source_mock.cpp | 4 +- src/mongo/db/pipeline/document_source_mock.h | 4 - src/mongo/db/pipeline/document_source_project.cpp | 3 +- src/mongo/db/pipeline/document_source_project.h | 2 + src/mongo/db/pipeline/document_source_redact.cpp | 6 +- src/mongo/db/pipeline/document_source_redact.h | 2 - .../db/pipeline/document_source_replace_root.cpp | 15 +- .../db/pipeline/document_source_replace_root.h | 2 + src/mongo/db/pipeline/document_source_sample.cpp | 4 - src/mongo/db/pipeline/document_source_sample.h | 2 - .../document_source_sample_from_random_cursor.cpp | 10 +- .../document_source_sample_from_random_cursor.h | 7 +- ...ument_source_single_document_transformation.cpp | 4 - ...ocument_source_single_document_transformation.h | 3 - src/mongo/db/pipeline/document_source_skip.cpp | 1 - src/mongo/db/pipeline/document_source_sort.cpp | 5 +- .../db/pipeline/document_source_sort_by_count.h | 2 + .../db/pipeline/document_source_tee_consumer.h | 2 +- src/mongo/db/pipeline/document_source_unwind.cpp | 1 - .../db/pipeline/document_source_unwind_test.cpp | 9 +- src/mongo/db/pipeline/expression.cpp | 351 ++++++++--------- src/mongo/db/pipeline/expression.h | 428 ++++++++++++++++----- src/mongo/db/pipeline/expression_context.cpp | 26 +- src/mongo/db/pipeline/expression_context.h | 52 ++- .../db/pipeline/expression_context_for_test.h | 64 +++ src/mongo/db/pipeline/expression_test.cpp | 293 ++++++++------ src/mongo/db/pipeline/granularity_rounder.cpp | 4 +- src/mongo/db/pipeline/granularity_rounder.h | 34 +- .../pipeline/granularity_rounder_powers_of_two.cpp | 9 +- .../granularity_rounder_powers_of_two_test.cpp | 43 ++- .../granularity_rounder_preferred_numbers.cpp | 61 +-- .../granularity_rounder_preferred_numbers_test.cpp | 28 +- src/mongo/db/pipeline/lookup_set_cache.h | 31 +- src/mongo/db/pipeline/lookup_set_cache_test.cpp | 26 +- src/mongo/db/pipeline/parsed_add_fields.cpp | 28 +- src/mongo/db/pipeline/parsed_add_fields.h | 21 +- src/mongo/db/pipeline/parsed_add_fields_test.cpp | 136 ++++--- .../db/pipeline/parsed_aggregation_projection.cpp | 4 +- .../db/pipeline/parsed_aggregation_projection.h | 13 +- .../parsed_aggregation_projection_test.cpp | 313 ++++++++------- .../db/pipeline/parsed_exclusion_projection.cpp | 7 +- .../db/pipeline/parsed_exclusion_projection.h | 9 +- .../pipeline/parsed_exclusion_projection_test.cpp | 59 ++- .../db/pipeline/parsed_inclusion_projection.cpp | 38 +- .../db/pipeline/parsed_inclusion_projection.h | 20 +- .../pipeline/parsed_inclusion_projection_test.cpp | 133 ++++--- src/mongo/db/pipeline/pipeline.cpp | 7 - src/mongo/db/pipeline/pipeline.h | 8 +- src/mongo/db/pipeline/pipeline_d.cpp | 4 +- src/mongo/db/pipeline/pipeline_d.h | 2 +- src/mongo/db/pipeline/pipeline_test.cpp | 210 +++++----- 93 files changed, 1847 insertions(+), 1366 deletions(-) create mode 100644 src/mongo/db/pipeline/expression_context_for_test.h (limited to 'src/mongo/db/pipeline') diff --git a/src/mongo/db/pipeline/accumulation_statement.cpp b/src/mongo/db/pipeline/accumulation_statement.cpp index b191a78fdf1..9ac394b0018 100644 --- a/src/mongo/db/pipeline/accumulation_statement.cpp +++ b/src/mongo/db/pipeline/accumulation_statement.cpp @@ -64,7 +64,9 @@ Accumulator::Factory AccumulationStatement::getFactory(StringData name) { } AccumulationStatement AccumulationStatement::parseAccumulationStatement( - const BSONElement& elem, const VariablesParseState& vps) { + const boost::intrusive_ptr& expCtx, + const BSONElement& elem, + const VariablesParseState& vps) { auto fieldName = elem.fieldNameStringData(); uassert(40234, str::stream() << "The field '" << fieldName << "' must be an accumulator object", @@ -91,7 +93,7 @@ AccumulationStatement AccumulationStatement::parseAccumulationStatement( return {fieldName.toString(), AccumulationStatement::getFactory(accName), - Expression::parseOperand(specElem, vps)}; + Expression::parseOperand(expCtx, specElem, vps)}; } } // namespace mongo diff --git a/src/mongo/db/pipeline/accumulation_statement.h b/src/mongo/db/pipeline/accumulation_statement.h index 5a7c578ff65..92f7906ed9e 100644 --- a/src/mongo/db/pipeline/accumulation_statement.h +++ b/src/mongo/db/pipeline/accumulation_statement.h @@ -72,8 +72,10 @@ public: * * Throws a UserException if parsing fails. */ - static AccumulationStatement parseAccumulationStatement(const BSONElement& elem, - const VariablesParseState& vps); + static AccumulationStatement parseAccumulationStatement( + const boost::intrusive_ptr& expCtx, + const BSONElement& elem, + const VariablesParseState& vps); /** * Registers an Accumulator with a parsing function, so that when an accumulator with the given diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h index e0a2dbdb11a..28645e90a9b 100644 --- a/src/mongo/db/pipeline/accumulator.h +++ b/src/mongo/db/pipeline/accumulator.h @@ -48,9 +48,10 @@ namespace mongo { class Accumulator : public RefCountable { public: - using Factory = boost::intrusive_ptr (*)(); + using Factory = boost::intrusive_ptr (*)( + const boost::intrusive_ptr& expCtx); - Accumulator() = default; + Accumulator(const boost::intrusive_ptr& expCtx) : _expCtx(expCtx) {} /** Process input and update internal state. * merging should be true when processing outputs from getValue(true). @@ -84,26 +85,10 @@ public: return false; } - /** - * Injects the ExpressionContext so that it may be used during evaluation of the Accumulator. - * Construction of accumulators is done at parse time, but the ExpressionContext isn't finalized - * until later, at which point it is injected using this method. - */ - void injectExpressionContext(const boost::intrusive_ptr& expCtx) { - _expCtx = expCtx; - doInjectExpressionContext(); - } - protected: /// Update subclass's internal state based on input virtual void processInternal(const Value& input, bool merging) = 0; - /** - * Accumulators which need to update their internal state when attaching to a new - * ExpressionContext should override this method. - */ - virtual void doInjectExpressionContext() {} - const boost::intrusive_ptr& getExpressionContext() const { return _expCtx; } @@ -118,14 +103,15 @@ private: class AccumulatorAddToSet final : public Accumulator { public: - AccumulatorAddToSet(); + explicit AccumulatorAddToSet(const boost::intrusive_ptr& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr create(); + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); bool isAssociative() const final { return true; @@ -135,26 +121,22 @@ public: return true; } - void doInjectExpressionContext() final; - private: - // We use boost::optional to defer initialization until the ExpressionContext containing the - // correct comparator is injected, since this set must use the comparator's definition of - // equality. - boost::optional _set; + ValueUnorderedSet _set; }; class AccumulatorFirst final : public Accumulator { public: - AccumulatorFirst(); + explicit AccumulatorFirst(const boost::intrusive_ptr& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr create(); + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); private: bool _haveFirst; @@ -164,14 +146,15 @@ private: class AccumulatorLast final : public Accumulator { public: - AccumulatorLast(); + explicit AccumulatorLast(const boost::intrusive_ptr& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr create(); + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); private: Value _last; @@ -180,14 +163,15 @@ private: class AccumulatorSum final : public Accumulator { public: - AccumulatorSum(); + explicit AccumulatorSum(const boost::intrusive_ptr& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr create(); + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); bool isAssociative() const final { return true; @@ -211,7 +195,7 @@ public: MAX = -1, // Used to "scale" comparison. }; - explicit AccumulatorMinMax(Sense sense); + AccumulatorMinMax(const boost::intrusive_ptr& expCtx, Sense sense); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; @@ -233,27 +217,32 @@ private: class AccumulatorMax final : public AccumulatorMinMax { public: - AccumulatorMax() : AccumulatorMinMax(MAX) {} - static boost::intrusive_ptr create(); + explicit AccumulatorMax(const boost::intrusive_ptr& expCtx) + : AccumulatorMinMax(expCtx, MAX) {} + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); }; class AccumulatorMin final : public AccumulatorMinMax { public: - AccumulatorMin() : AccumulatorMinMax(MIN) {} - static boost::intrusive_ptr create(); + explicit AccumulatorMin(const boost::intrusive_ptr& expCtx) + : AccumulatorMinMax(expCtx, MIN) {} + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); }; class AccumulatorPush final : public Accumulator { public: - AccumulatorPush(); + explicit AccumulatorPush(const boost::intrusive_ptr& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr create(); + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); private: std::vector vpValue; @@ -262,14 +251,15 @@ private: class AccumulatorAvg final : public Accumulator { public: - AccumulatorAvg(); + explicit AccumulatorAvg(const boost::intrusive_ptr& expCtx); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; const char* getOpName() const final; void reset() final; - static boost::intrusive_ptr create(); + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); private: /** @@ -287,7 +277,7 @@ private: class AccumulatorStdDev : public Accumulator { public: - explicit AccumulatorStdDev(bool isSamp); + AccumulatorStdDev(const boost::intrusive_ptr& expCtx, bool isSamp); void processInternal(const Value& input, bool merging) final; Value getValue(bool toBeMerged) const final; @@ -303,13 +293,17 @@ private: class AccumulatorStdDevPop final : public AccumulatorStdDev { public: - AccumulatorStdDevPop() : AccumulatorStdDev(false) {} - static boost::intrusive_ptr create(); + explicit AccumulatorStdDevPop(const boost::intrusive_ptr& expCtx) + : AccumulatorStdDev(expCtx, false) {} + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); }; class AccumulatorStdDevSamp final : public AccumulatorStdDev { public: - AccumulatorStdDevSamp() : AccumulatorStdDev(true) {} - static boost::intrusive_ptr create(); + explicit AccumulatorStdDevSamp(const boost::intrusive_ptr& expCtx) + : AccumulatorStdDev(expCtx, true) {} + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); }; } diff --git a/src/mongo/db/pipeline/accumulator_add_to_set.cpp b/src/mongo/db/pipeline/accumulator_add_to_set.cpp index 6a54e846a85..00845b32fc8 100644 --- a/src/mongo/db/pipeline/accumulator_add_to_set.cpp +++ b/src/mongo/db/pipeline/accumulator_add_to_set.cpp @@ -48,7 +48,7 @@ const char* AccumulatorAddToSet::getOpName() const { void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { if (!merging) { if (!input.missing()) { - bool inserted = _set->insert(input).second; + bool inserted = _set.insert(input).second; if (inserted) { _memUsageBytes += input.getApproximateSize(); } @@ -62,7 +62,7 @@ void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { const vector& array = input.getArray(); for (size_t i = 0; i < array.size(); i++) { - bool inserted = _set->insert(array[i]).second; + bool inserted = _set.insert(array[i]).second; if (inserted) { _memUsageBytes += array[i].getApproximateSize(); } @@ -71,10 +71,11 @@ void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { } Value AccumulatorAddToSet::getValue(bool toBeMerged) const { - return Value(vector(_set->begin(), _set->end())); + return Value(vector(_set.begin(), _set.end())); } -AccumulatorAddToSet::AccumulatorAddToSet() { +AccumulatorAddToSet::AccumulatorAddToSet(const boost::intrusive_ptr& expCtx) + : Accumulator(expCtx), _set(expCtx->getValueComparator().makeUnorderedValueSet()) { _memUsageBytes = sizeof(*this); } @@ -83,12 +84,9 @@ void AccumulatorAddToSet::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr AccumulatorAddToSet::create() { - return new AccumulatorAddToSet(); -} - -void AccumulatorAddToSet::doInjectExpressionContext() { - _set = getExpressionContext()->getValueComparator().makeUnorderedValueSet(); +intrusive_ptr AccumulatorAddToSet::create( + const boost::intrusive_ptr& expCtx) { + return new AccumulatorAddToSet(expCtx); } } // namespace mongo diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp index c977ba99c15..da0c08fce9c 100644 --- a/src/mongo/db/pipeline/accumulator_avg.cpp +++ b/src/mongo/db/pipeline/accumulator_avg.cpp @@ -92,8 +92,9 @@ void AccumulatorAvg::processInternal(const Value& input, bool merging) { _count++; } -intrusive_ptr AccumulatorAvg::create() { - return new AccumulatorAvg(); +intrusive_ptr AccumulatorAvg::create( + const boost::intrusive_ptr& expCtx) { + return new AccumulatorAvg(expCtx); } Decimal128 AccumulatorAvg::_getDecimalTotal() const { @@ -120,7 +121,8 @@ Value AccumulatorAvg::getValue(bool toBeMerged) const { return Value(_nonDecimalTotal.getDouble() / static_cast(_count)); } -AccumulatorAvg::AccumulatorAvg() : _isDecimal(false), _count(0) { +AccumulatorAvg::AccumulatorAvg(const boost::intrusive_ptr& expCtx) + : Accumulator(expCtx), _isDecimal(false), _count(0) { // This is a fixed size Accumulator so we never need to update this _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_first.cpp b/src/mongo/db/pipeline/accumulator_first.cpp index 68ea563ff42..3092bd0e9da 100644 --- a/src/mongo/db/pipeline/accumulator_first.cpp +++ b/src/mongo/db/pipeline/accumulator_first.cpp @@ -57,7 +57,8 @@ Value AccumulatorFirst::getValue(bool toBeMerged) const { return _first; } -AccumulatorFirst::AccumulatorFirst() : _haveFirst(false) { +AccumulatorFirst::AccumulatorFirst(const boost::intrusive_ptr& expCtx) + : Accumulator(expCtx), _haveFirst(false) { _memUsageBytes = sizeof(*this); } @@ -68,7 +69,8 @@ void AccumulatorFirst::reset() { } -intrusive_ptr AccumulatorFirst::create() { - return new AccumulatorFirst(); +intrusive_ptr AccumulatorFirst::create( + const boost::intrusive_ptr& expCtx) { + return new AccumulatorFirst(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_last.cpp b/src/mongo/db/pipeline/accumulator_last.cpp index 64f545f7db1..5f465b1bf4e 100644 --- a/src/mongo/db/pipeline/accumulator_last.cpp +++ b/src/mongo/db/pipeline/accumulator_last.cpp @@ -53,7 +53,8 @@ Value AccumulatorLast::getValue(bool toBeMerged) const { return _last; } -AccumulatorLast::AccumulatorLast() { +AccumulatorLast::AccumulatorLast(const boost::intrusive_ptr& expCtx) + : Accumulator(expCtx) { _memUsageBytes = sizeof(*this); } @@ -62,7 +63,8 @@ void AccumulatorLast::reset() { _last = Value(); } -intrusive_ptr AccumulatorLast::create() { - return new AccumulatorLast(); +intrusive_ptr AccumulatorLast::create( + const boost::intrusive_ptr& expCtx) { + return new AccumulatorLast(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_min_max.cpp b/src/mongo/db/pipeline/accumulator_min_max.cpp index 08db8b9e0fc..f5c4b254b88 100644 --- a/src/mongo/db/pipeline/accumulator_min_max.cpp +++ b/src/mongo/db/pipeline/accumulator_min_max.cpp @@ -68,7 +68,9 @@ Value AccumulatorMinMax::getValue(bool toBeMerged) const { return _val; } -AccumulatorMinMax::AccumulatorMinMax(Sense sense) : _sense(sense) { +AccumulatorMinMax::AccumulatorMinMax(const boost::intrusive_ptr& expCtx, + Sense sense) + : Accumulator(expCtx), _sense(sense) { _memUsageBytes = sizeof(*this); } @@ -77,11 +79,13 @@ void AccumulatorMinMax::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr AccumulatorMin::create() { - return new AccumulatorMin(); +intrusive_ptr AccumulatorMin::create( + const boost::intrusive_ptr& expCtx) { + return new AccumulatorMin(expCtx); } -intrusive_ptr AccumulatorMax::create() { - return new AccumulatorMax(); +intrusive_ptr AccumulatorMax::create( + const boost::intrusive_ptr& expCtx) { + return new AccumulatorMax(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_push.cpp b/src/mongo/db/pipeline/accumulator_push.cpp index 16f722a08f3..9239d8b8a18 100644 --- a/src/mongo/db/pipeline/accumulator_push.cpp +++ b/src/mongo/db/pipeline/accumulator_push.cpp @@ -71,7 +71,8 @@ Value AccumulatorPush::getValue(bool toBeMerged) const { return Value(vpValue); } -AccumulatorPush::AccumulatorPush() { +AccumulatorPush::AccumulatorPush(const boost::intrusive_ptr& expCtx) + : Accumulator(expCtx) { _memUsageBytes = sizeof(*this); } @@ -80,7 +81,8 @@ void AccumulatorPush::reset() { _memUsageBytes = sizeof(*this); } -intrusive_ptr AccumulatorPush::create() { - return new AccumulatorPush(); +intrusive_ptr AccumulatorPush::create( + const boost::intrusive_ptr& expCtx) { + return new AccumulatorPush(expCtx); } } diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp index 28565250fe3..f45678dce27 100644 --- a/src/mongo/db/pipeline/accumulator_std_dev.cpp +++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp @@ -95,15 +95,19 @@ Value AccumulatorStdDev::getValue(bool toBeMerged) const { } } -intrusive_ptr AccumulatorStdDevSamp::create() { - return new AccumulatorStdDevSamp(); +intrusive_ptr AccumulatorStdDevSamp::create( + const boost::intrusive_ptr& expCtx) { + return new AccumulatorStdDevSamp(expCtx); } -intrusive_ptr AccumulatorStdDevPop::create() { - return new AccumulatorStdDevPop(); +intrusive_ptr AccumulatorStdDevPop::create( + const boost::intrusive_ptr& expCtx) { + return new AccumulatorStdDevPop(expCtx); } -AccumulatorStdDev::AccumulatorStdDev(bool isSamp) : _isSamp(isSamp), _count(0), _mean(0), _m2(0) { +AccumulatorStdDev::AccumulatorStdDev(const boost::intrusive_ptr& expCtx, + bool isSamp) + : Accumulator(expCtx), _isSamp(isSamp), _count(0), _mean(0), _m2(0) { // This is a fixed size Accumulator so we never need to update this _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp index 13aa12fbfca..bbbf596d4a8 100644 --- a/src/mongo/db/pipeline/accumulator_sum.cpp +++ b/src/mongo/db/pipeline/accumulator_sum.cpp @@ -84,8 +84,9 @@ void AccumulatorSum::processInternal(const Value& input, bool merging) { } } -intrusive_ptr AccumulatorSum::create() { - return new AccumulatorSum(); +intrusive_ptr AccumulatorSum::create( + const boost::intrusive_ptr& expCtx) { + return new AccumulatorSum(expCtx); } Value AccumulatorSum::getValue(bool toBeMerged) const { @@ -133,7 +134,8 @@ Value AccumulatorSum::getValue(bool toBeMerged) const { } } -AccumulatorSum::AccumulatorSum() { +AccumulatorSum::AccumulatorSum(const boost::intrusive_ptr& expCtx) + : Accumulator(expCtx) { // This is a fixed size Accumulator so we never need to update this. _memUsageBytes = sizeof(*this); } diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp index 29ab135ebb4..3fcb1d70ad7 100644 --- a/src/mongo/db/pipeline/accumulator_test.cpp +++ b/src/mongo/db/pipeline/accumulator_test.cpp @@ -32,7 +32,7 @@ #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/collation/collator_interface_mock.h" #include "mongo/dbtests/dbtests.h" #include "mongo/stdx/memory.h" @@ -57,8 +57,7 @@ static void assertExpectedResults( try { // Asserts that result equals expected result when not sharded. { - boost::intrusive_ptr accum = factory(); - accum->injectExpressionContext(expCtx); + boost::intrusive_ptr accum(factory(expCtx)); for (auto&& val : op.first) { accum->process(val, false); } @@ -69,10 +68,8 @@ static void assertExpectedResults( // Asserts that result equals expected result when all input is on one shard. { - boost::intrusive_ptr accum = factory(); - accum->injectExpressionContext(expCtx); - boost::intrusive_ptr shard = factory(); - shard->injectExpressionContext(expCtx); + boost::intrusive_ptr accum(factory(expCtx)); + boost::intrusive_ptr shard(factory(expCtx)); for (auto&& val : op.first) { shard->process(val, false); } @@ -84,11 +81,9 @@ static void assertExpectedResults( // Asserts that result equals expected result when each input is on a separate shard. { - boost::intrusive_ptr accum = factory(); - accum->injectExpressionContext(expCtx); + boost::intrusive_ptr accum(factory(expCtx)); for (auto&& val : op.first) { - boost::intrusive_ptr shard = factory(); - shard->injectExpressionContext(expCtx); + boost::intrusive_ptr shard(factory(expCtx)); shard->process(val, false); accum->process(shard->getValue(true), true); } @@ -104,7 +99,7 @@ static void assertExpectedResults( } TEST(Accumulators, Avg) { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); assertExpectedResults( "$avg", expCtx, @@ -160,7 +155,7 @@ TEST(Accumulators, Avg) { } TEST(Accumulators, First) { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); assertExpectedResults( "$first", expCtx, @@ -179,7 +174,7 @@ TEST(Accumulators, First) { } TEST(Accumulators, Last) { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); assertExpectedResults( "$last", expCtx, @@ -198,7 +193,7 @@ TEST(Accumulators, Last) { } TEST(Accumulators, Min) { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); assertExpectedResults( "$min", expCtx, @@ -217,14 +212,14 @@ TEST(Accumulators, Min) { } TEST(Accumulators, MinRespectsCollation) { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); expCtx->setCollator( stdx::make_unique(CollatorInterfaceMock::MockType::kReverseString)); assertExpectedResults("$min", expCtx, {{{Value("abc"_sd), Value("cba"_sd)}, Value("cba"_sd)}}); } TEST(Accumulators, Max) { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); assertExpectedResults( "$max", expCtx, @@ -243,14 +238,14 @@ TEST(Accumulators, Max) { } TEST(Accumulators, MaxRespectsCollation) { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); expCtx->setCollator( stdx::make_unique(CollatorInterfaceMock::MockType::kReverseString)); assertExpectedResults("$max", expCtx, {{{Value("abc"_sd), Value("cba"_sd)}, Value("abc"_sd)}}); } TEST(Accumulators, Sum) { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); assertExpectedResults( "$sum", expCtx, @@ -340,7 +335,7 @@ TEST(Accumulators, Sum) { } TEST(Accumulators, AddToSetRespectsCollation) { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); expCtx->setCollator( stdx::make_unique(CollatorInterfaceMock::MockType::kAlwaysEqual)); assertExpectedResults("$addToSet", diff --git a/src/mongo/db/pipeline/aggregation_context_fixture.h b/src/mongo/db/pipeline/aggregation_context_fixture.h index 75353c46e7f..786f50a9b5a 100644 --- a/src/mongo/db/pipeline/aggregation_context_fixture.h +++ b/src/mongo/db/pipeline/aggregation_context_fixture.h @@ -32,7 +32,7 @@ #include #include "mongo/db/client.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/query/query_test_service_context.h" #include "mongo/db/service_context_noop.h" #include "mongo/stdx/memory.h" @@ -48,16 +48,16 @@ public: AggregationContextFixture() : _queryServiceContext(stdx::make_unique()), _opCtx(_queryServiceContext->makeOperationContext()), - _expCtx(new ExpressionContext( + _expCtx(new ExpressionContextForTest( _opCtx.get(), AggregationRequest(NamespaceString("unittests.pipeline_test"), {}))) {} - boost::intrusive_ptr getExpCtx() { + boost::intrusive_ptr getExpCtx() { return _expCtx.get(); } private: std::unique_ptr _queryServiceContext; ServiceContext::UniqueOperationContext _opCtx; - boost::intrusive_ptr _expCtx; + boost::intrusive_ptr _expCtx; }; } // namespace mongo diff --git a/src/mongo/db/pipeline/document_source.cpp b/src/mongo/db/pipeline/document_source.cpp index b83ec851ebe..440dc762175 100644 --- a/src/mongo/db/pipeline/document_source.cpp +++ b/src/mongo/db/pipeline/document_source.cpp @@ -61,7 +61,7 @@ void DocumentSource::registerParser(string name, Parser parser) { } vector> DocumentSource::parse( - const intrusive_ptr expCtx, BSONObj stageObj) { + const intrusive_ptr& expCtx, BSONObj stageObj) { uassert(16435, "A pipeline stage specification object must contain exactly one field.", stageObj.nFields() == 1); diff --git a/src/mongo/db/pipeline/document_source.h b/src/mongo/db/pipeline/document_source.h index 81ba3fc61a6..bf4c65fa18b 100644 --- a/src/mongo/db/pipeline/document_source.h +++ b/src/mongo/db/pipeline/document_source.h @@ -268,23 +268,11 @@ public: virtual void reattachToOperationContext(OperationContext* opCtx) {} - /** - * Injects a new ExpressionContext into this DocumentSource and propagates the ExpressionContext - * to all child expressions, accumulators, etc. - * - * Stages which require work to propagate the ExpressionContext to their private execution - * machinery should override doInjectExpressionContext(). - */ - void injectExpressionContext(const boost::intrusive_ptr& expCtx) { - pExpCtx = expCtx; - doInjectExpressionContext(); - } - /** * Create a DocumentSource pipeline stage from 'stageObj'. */ static std::vector> parse( - const boost::intrusive_ptr expCtx, BSONObj stageObj); + const boost::intrusive_ptr& expCtx, BSONObj stageObj); /** * Registers a DocumentSource with a parsing function, so that when a stage with the given name @@ -443,15 +431,6 @@ protected: */ explicit DocumentSource(const boost::intrusive_ptr& pExpCtx); - /** - * DocumentSources which need to update their internal state when attaching to a new - * ExpressionContext should override this method. - * - * Any stage subclassing from DocumentSource should override this method if it contains - * expressions or accumulators which need to attach to the newly injected ExpressionContext. - */ - virtual void doInjectExpressionContext() {} - /** * Attempt to perform an optimization with the following source in the pipeline. 'container' * refers to the entire pipeline, and 'itr' points to this stage within the pipeline. The caller diff --git a/src/mongo/db/pipeline/document_source_add_fields.cpp b/src/mongo/db/pipeline/document_source_add_fields.cpp index 0f7366538f9..445feaa1a51 100644 --- a/src/mongo/db/pipeline/document_source_add_fields.cpp +++ b/src/mongo/db/pipeline/document_source_add_fields.cpp @@ -49,8 +49,7 @@ intrusive_ptr DocumentSourceAddFields::create( BSONObj addFieldsSpec, const intrusive_ptr& expCtx) { intrusive_ptr addFields( new DocumentSourceSingleDocumentTransformation( - expCtx, ParsedAddFields::create(addFieldsSpec), "$addFields")); - addFields->injectExpressionContext(expCtx); + expCtx, ParsedAddFields::create(expCtx, addFieldsSpec), "$addFields")); return addFields; } diff --git a/src/mongo/db/pipeline/document_source_add_fields.h b/src/mongo/db/pipeline/document_source_add_fields.h index b1ce0415ab3..77fe4d9d4e9 100644 --- a/src/mongo/db/pipeline/document_source_add_fields.h +++ b/src/mongo/db/pipeline/document_source_add_fields.h @@ -51,6 +51,8 @@ public: BSONElement elem, const boost::intrusive_ptr& expCtx); private: + // It is illegal to construct a DocumentSourceAddFields directly, use create() or + // createFromBson() instead. DocumentSourceAddFields() = default; }; diff --git a/src/mongo/db/pipeline/document_source_bucket.cpp b/src/mongo/db/pipeline/document_source_bucket.cpp index 05152b7e313..7c671ea2998 100644 --- a/src/mongo/db/pipeline/document_source_bucket.cpp +++ b/src/mongo/db/pipeline/document_source_bucket.cpp @@ -43,9 +43,11 @@ REGISTER_MULTI_STAGE_ALIAS(bucket, DocumentSourceBucket::createFromBson); namespace { -intrusive_ptr getExpressionConstant(BSONElement expressionElem, - VariablesParseState vps) { - auto expr = Expression::parseOperand(expressionElem, vps)->optimize(); +intrusive_ptr getExpressionConstant( + const boost::intrusive_ptr& expCtx, + BSONElement expressionElem, + VariablesParseState vps) { + auto expr = Expression::parseOperand(expCtx, expressionElem, vps)->optimize(); return dynamic_cast(expr.get()); } } // namespace @@ -95,7 +97,7 @@ vector> DocumentSourceBucket::createFromBson( argument.type() == BSONType::Array); for (auto&& boundaryElem : argument.embeddedObject()) { - auto exprConst = getExpressionConstant(boundaryElem, vps); + auto exprConst = getExpressionConstant(pExpCtx, boundaryElem, vps); uassert(40191, str::stream() << "The $bucket 'boundaries' field must be an array of " "constant values, but found value: " @@ -144,7 +146,7 @@ vector> DocumentSourceBucket::createFromBson( } else if ("default" == argName) { // If there is a default, make sure that it parses to a constant expression then add // default to switch. - auto exprConst = getExpressionConstant(argument, vps); + auto exprConst = getExpressionConstant(pExpCtx, argument, vps); uassert(40195, str::stream() << "The $bucket 'default' field must be a constant expression, but found: " diff --git a/src/mongo/db/pipeline/document_source_bucket.h b/src/mongo/db/pipeline/document_source_bucket.h index cd7fe31b14e..aa83ee07cc2 100644 --- a/src/mongo/db/pipeline/document_source_bucket.h +++ b/src/mongo/db/pipeline/document_source_bucket.h @@ -44,6 +44,7 @@ public: BSONElement elem, const boost::intrusive_ptr& pExpCtx); private: + // It is illegal to construct a DocumentSourceBucket directly, use createFromBson() instead. DocumentSourceBucket() = default; }; diff --git a/src/mongo/db/pipeline/document_source_bucket_auto.cpp b/src/mongo/db/pipeline/document_source_bucket_auto.cpp index c9a09354434..b02e8d0226f 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto.cpp +++ b/src/mongo/db/pipeline/document_source_bucket_auto.cpp @@ -196,7 +196,8 @@ void DocumentSourceBucketAuto::populateBuckets() { } // Initialize the current bucket. - Bucket currentBucket(currentValue.first, currentValue.first, _accumulatorFactories); + Bucket currentBucket( + pExpCtx, currentValue.first, currentValue.first, _accumulatorFactories); // Add the first value into the current bucket. addDocumentToBucket(currentValue, currentBucket); @@ -267,13 +268,14 @@ void DocumentSourceBucketAuto::populateBuckets() { } } -DocumentSourceBucketAuto::Bucket::Bucket(Value min, +DocumentSourceBucketAuto::Bucket::Bucket(const boost::intrusive_ptr& expCtx, + Value min, Value max, vector accumulatorFactories) : _min(min), _max(max) { _accums.reserve(accumulatorFactories.size()); for (auto&& factory : accumulatorFactories) { - _accums.push_back(factory()); + _accums.push_back(factory(expCtx)); } } @@ -344,7 +346,7 @@ Value DocumentSourceBucketAuto::serialize(bool explain) const { const size_t nOutputFields = _fieldNames.size(); MutableDocument outputSpec(nOutputFields); for (size_t i = 0; i < nOutputFields; i++) { - intrusive_ptr accum = _accumulatorFactories[i](); + intrusive_ptr accum = _accumulatorFactories[i](pExpCtx); outputSpec[_fieldNames[i]] = Value{Document{{accum->getOpName(), _expressions[i]->serialize(explain)}}}; } @@ -405,14 +407,16 @@ DocumentSourceBucketAuto::DocumentSourceBucketAuto( namespace { -boost::intrusive_ptr parseGroupByExpression(const BSONElement& groupByField, - const VariablesParseState& vps) { +boost::intrusive_ptr parseGroupByExpression( + const boost::intrusive_ptr& expCtx, + const BSONElement& groupByField, + const VariablesParseState& vps) { if (groupByField.type() == BSONType::Object && groupByField.embeddedObject().firstElementFieldName()[0] == '$') { - return Expression::parseObject(groupByField.embeddedObject(), vps); + return Expression::parseObject(expCtx, groupByField.embeddedObject(), vps); } else if (groupByField.type() == BSONType::String && groupByField.valueStringData()[0] == '$') { - return ExpressionFieldPath::parse(groupByField.str(), vps); + return ExpressionFieldPath::parse(expCtx, groupByField.str(), vps); } else { uasserted( 40239, @@ -440,7 +444,7 @@ intrusive_ptr DocumentSourceBucketAuto::createFromBson( for (auto&& argument : elem.Obj()) { const auto argName = argument.fieldNameStringData(); if ("groupBy" == argName) { - groupByExpression = parseGroupByExpression(argument, vps); + groupByExpression = parseGroupByExpression(pExpCtx, argument, vps); } else if ("buckets" == argName) { Value bucketsValue = Value(argument); @@ -467,7 +471,7 @@ intrusive_ptr DocumentSourceBucketAuto::createFromBson( for (auto&& outputField : argument.embeddedObject()) { accumulationStatements.push_back( - AccumulationStatement::parseAccumulationStatement(outputField, vps)); + AccumulationStatement::parseAccumulationStatement(pExpCtx, outputField, vps)); } } else if ("granularity" == argName) { uassert(40261, @@ -475,7 +479,7 @@ intrusive_ptr DocumentSourceBucketAuto::createFromBson( << "The $bucketAuto 'granularity' field must be a string, but found type: " << typeName(argument.type()), argument.type() == BSONType::String); - granularityRounder = GranularityRounder::getGranularityRounder(argument.str()); + granularityRounder = GranularityRounder::getGranularityRounder(pExpCtx, argument.str()); } else { uasserted(40245, str::stream() << "Unrecognized option to $bucketAuto: " << argName); } diff --git a/src/mongo/db/pipeline/document_source_bucket_auto.h b/src/mongo/db/pipeline/document_source_bucket_auto.h index 9279bcd52b9..61f7da30c71 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto.h +++ b/src/mongo/db/pipeline/document_source_bucket_auto.h @@ -93,7 +93,10 @@ private: // struct for holding information about a bucket. struct Bucket { - Bucket(Value min, Value max, std::vector accumulatorFactories); + Bucket(const boost::intrusive_ptr& expCtx, + Value min, + Value max, + std::vector accumulatorFactories); Value _min; Value _max; std::vector> _accums; diff --git a/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp b/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp index 530948d2d3d..2d71b676024 100644 --- a/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp +++ b/src/mongo/db/pipeline/document_source_bucket_auto_test.cpp @@ -355,7 +355,7 @@ TEST_F(BucketAutoTests, ShouldBeAbleToCorrectlySpillToDisk) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, @@ -397,7 +397,7 @@ TEST_F(BucketAutoTests, ShouldBeAbleToPauseLoadingWhileSpilled) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, @@ -643,7 +643,7 @@ void assertCannotSpillToDisk(const boost::intrusive_ptr& expC VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, @@ -685,7 +685,7 @@ TEST_F(BucketAutoTests, ShouldCorrectlyTrackMemoryUsageBetweenPauses) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto groupByExpression = ExpressionFieldPath::parse("$a", vps); + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$a", vps); const int numBuckets = 2; auto bucketAutoStage = DocumentSourceBucketAuto::create(expCtx, diff --git a/src/mongo/db/pipeline/document_source_bucket_test.cpp b/src/mongo/db/pipeline/document_source_bucket_test.cpp index b93916f4347..bdfb3635bb7 100644 --- a/src/mongo/db/pipeline/document_source_bucket_test.cpp +++ b/src/mongo/db/pipeline/document_source_bucket_test.cpp @@ -156,9 +156,6 @@ class InvalidBucketSpec : public AggregationContextFixture { public: vector> createBucket(BSONObj bucketSpec) { auto sources = DocumentSourceBucket::createFromBson(bucketSpec.firstElement(), getExpCtx()); - for (auto&& source : sources) { - source->injectExpressionContext(getExpCtx()); - } return sources; } }; diff --git a/src/mongo/db/pipeline/document_source_count.h b/src/mongo/db/pipeline/document_source_count.h index e1817fbfca3..50fbccdf41b 100644 --- a/src/mongo/db/pipeline/document_source_count.h +++ b/src/mongo/db/pipeline/document_source_count.h @@ -44,6 +44,7 @@ public: BSONElement elem, const boost::intrusive_ptr& pExpCtx); private: + // It is illegal to construct a DocumentSourceCount directly, use createFromBson() instead. DocumentSourceCount() = default; }; diff --git a/src/mongo/db/pipeline/document_source_cursor.cpp b/src/mongo/db/pipeline/document_source_cursor.cpp index 5c5be2aef1f..2def98d25bd 100644 --- a/src/mongo/db/pipeline/document_source_cursor.cpp +++ b/src/mongo/db/pipeline/document_source_cursor.cpp @@ -221,12 +221,6 @@ Value DocumentSourceCursor::serialize(bool explain) const { return Value(DOC(getSourceName() << out.freezeToValue())); } -void DocumentSourceCursor::doInjectExpressionContext() { - if (_limit) { - _limit->injectExpressionContext(pExpCtx); - } -} - void DocumentSourceCursor::detachFromOperationContext() { if (_exec) { _exec->detachFromOperationContext(); @@ -259,7 +253,6 @@ intrusive_ptr DocumentSourceCursor::create( const intrusive_ptr& pExpCtx) { intrusive_ptr source( new DocumentSourceCursor(ns, std::move(exec), pExpCtx)); - source->injectExpressionContext(pExpCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_cursor.h b/src/mongo/db/pipeline/document_source_cursor.h index ebbf9eff64d..799834f1487 100644 --- a/src/mongo/db/pipeline/document_source_cursor.h +++ b/src/mongo/db/pipeline/document_source_cursor.h @@ -132,9 +132,6 @@ public: const PlanSummaryStats& getPlanSummaryStats() const; -protected: - void doInjectExpressionContext() final; - private: DocumentSourceCursor(const std::string& ns, std::unique_ptr exec, diff --git a/src/mongo/db/pipeline/document_source_facet.cpp b/src/mongo/db/pipeline/document_source_facet.cpp index 7b8cd82de0f..9b03b233051 100644 --- a/src/mongo/db/pipeline/document_source_facet.cpp +++ b/src/mongo/db/pipeline/document_source_facet.cpp @@ -203,12 +203,6 @@ intrusive_ptr DocumentSourceFacet::optimize() { return this; } -void DocumentSourceFacet::doInjectExpressionContext() { - for (auto&& facet : _facets) { - facet.pipeline->injectExpressionContext(pExpCtx); - } -} - void DocumentSourceFacet::doInjectMongodInterface(std::shared_ptr mongod) { for (auto&& facet : _facets) { for (auto&& stage : facet.pipeline->getSources()) { diff --git a/src/mongo/db/pipeline/document_source_facet.h b/src/mongo/db/pipeline/document_source_facet.h index b0a9a0867ee..921b2e87609 100644 --- a/src/mongo/db/pipeline/document_source_facet.h +++ b/src/mongo/db/pipeline/document_source_facet.h @@ -43,7 +43,7 @@ namespace mongo { class BSONElement; class TeeBuffer; class DocumentSourceTeeConsumer; -struct ExpressionContext; +class ExpressionContext; class NamespaceString; /** @@ -97,11 +97,6 @@ public: */ boost::intrusive_ptr optimize() final; - /** - * Injects the expression context into inner pipelines. - */ - void doInjectExpressionContext() final; - /** * Takes a union of all sub-pipelines, and adds them to 'deps'. */ diff --git a/src/mongo/db/pipeline/document_source_geo_near.cpp b/src/mongo/db/pipeline/document_source_geo_near.cpp index 8c4a2ca3ff5..17e9997bbf5 100644 --- a/src/mongo/db/pipeline/document_source_geo_near.cpp +++ b/src/mongo/db/pipeline/document_source_geo_near.cpp @@ -178,7 +178,6 @@ void DocumentSourceGeoNear::runCommand() { intrusive_ptr DocumentSourceGeoNear::create( const intrusive_ptr& pCtx) { intrusive_ptr source(new DocumentSourceGeoNear(pCtx)); - source->injectExpressionContext(pCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_graph_lookup.cpp b/src/mongo/db/pipeline/document_source_graph_lookup.cpp index 7743e9c6058..03515560057 100644 --- a/src/mongo/db/pipeline/document_source_graph_lookup.cpp +++ b/src/mongo/db/pipeline/document_source_graph_lookup.cpp @@ -164,7 +164,7 @@ DocumentSource::GetNextResult DocumentSourceGraphLookUp::getNextUnwound() { void DocumentSourceGraphLookUp::dispose() { _cache.clear(); - _frontier->clear(); + _frontier.clear(); _visited.clear(); pSource->dispose(); } @@ -180,7 +180,7 @@ void DocumentSourceGraphLookUp::doBreadthFirstSearch() { auto matchStage = makeMatchStageFromFrontier(&cached); ValueUnorderedSet queried = pExpCtx->getValueComparator().makeUnorderedValueSet(); - _frontier->swap(queried); + _frontier.swap(queried); _frontierUsageBytes = 0; // Process cached values, populating '_frontier' for the next iteration of search. @@ -219,7 +219,7 @@ void DocumentSourceGraphLookUp::doBreadthFirstSearch() { } while (shouldPerformAnotherQuery && depth < std::numeric_limits::max() && (!_maxDepth || depth <= *_maxDepth)); - _frontier->clear(); + _frontier.clear(); _frontierUsageBytes = 0; } @@ -264,12 +264,12 @@ bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long lon Value recurseOn = Value(elem); if (recurseOn.isArray()) { for (auto&& subElem : recurseOn.getArray()) { - _frontier->insert(subElem); + _frontier.insert(subElem); _frontierUsageBytes += subElem.getApproximateSize(); } } else if (!recurseOn.missing()) { // Don't recurse on a missing value. - _frontier->insert(recurseOn); + _frontier.insert(recurseOn); _frontierUsageBytes += recurseOn.getApproximateSize(); } } @@ -304,13 +304,13 @@ void DocumentSourceGraphLookUp::addToCache(const BSONObj& result, boost::optional DocumentSourceGraphLookUp::makeMatchStageFromFrontier(BSONObjSet* cached) { // Add any cached values to 'cached' and remove them from '_frontier'. - for (auto it = _frontier->begin(); it != _frontier->end();) { + for (auto it = _frontier.begin(); it != _frontier.end();) { if (auto entry = _cache[*it]) { for (auto&& obj : *entry) { cached->insert(obj); } size_t valueSize = it->getApproximateSize(); - it = _frontier->erase(it); + it = _frontier.erase(it); // If the cached value increased in size while in the cache, we don't want to underflow // '_frontierUsageBytes'. @@ -340,7 +340,7 @@ boost::optional DocumentSourceGraphLookUp::makeMatchStageFromFrontier(B BSONObjBuilder subObj(connectToObj.subobjStart(_connectToField.fullPath())); { BSONArrayBuilder in(subObj.subarrayStart("$in")); - for (auto&& value : *_frontier) { + for (auto&& value : _frontier) { in << value; } } @@ -349,7 +349,7 @@ boost::optional DocumentSourceGraphLookUp::makeMatchStageFromFrontier(B } } - return _frontier->empty() ? boost::none : boost::optional(match.obj()); + return _frontier.empty() ? boost::none : boost::optional(match.obj()); } void DocumentSourceGraphLookUp::performSearch() { @@ -363,11 +363,11 @@ void DocumentSourceGraphLookUp::performSearch() { // If _startWith evaluates to an array, treat each value as a separate starting point. if (startingValue.isArray()) { for (auto value : startingValue.getArray()) { - _frontier->insert(value); + _frontier.insert(value); _frontierUsageBytes += value.getApproximateSize(); } } else { - _frontier->insert(startingValue); + _frontier.insert(startingValue); _frontierUsageBytes += startingValue.getApproximateSize(); } @@ -460,24 +460,6 @@ void DocumentSourceGraphLookUp::serializeToArray(std::vector& array, bool } } -void DocumentSourceGraphLookUp::doInjectExpressionContext() { - _startWith->injectExpressionContext(pExpCtx); - - auto it = pExpCtx->resolvedNamespaces.find(_from.coll()); - invariant(it != pExpCtx->resolvedNamespaces.end()); - const auto& resolvedNamespace = it->second; - _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); - _fromPipeline = resolvedNamespace.pipeline; - - // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage - // we'll eventually construct from the input document. - _fromPipeline.reserve(_fromPipeline.size() + 1); - _fromPipeline.push_back(BSONObj()); - - _frontier = pExpCtx->getValueComparator().makeUnorderedValueSet(); - _cache.setValueComparator(pExpCtx->getValueComparator()); -} - void DocumentSourceGraphLookUp::doDetachFromOperationContext() { _fromExpCtx->opCtx = nullptr; } @@ -506,9 +488,19 @@ DocumentSourceGraphLookUp::DocumentSourceGraphLookUp( _additionalFilter(additionalFilter), _depthField(depthField), _maxDepth(maxDepth), + _frontier(pExpCtx->getValueComparator().makeUnorderedValueSet()), _visited(ValueComparator::kInstance.makeUnorderedValueMap()), - _cache(expCtx->getValueComparator()), - _unwind(unwindSrc) {} + _cache(pExpCtx->getValueComparator()), + _unwind(unwindSrc) { + const auto& resolvedNamespace = pExpCtx->getResolvedNamespace(_from); + _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); + _fromPipeline = resolvedNamespace.pipeline; + + // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage + // we'll eventually construct from the input document. + _fromPipeline.reserve(_fromPipeline.size() + 1); + _fromPipeline.push_back(BSONObj()); +} intrusive_ptr DocumentSourceGraphLookUp::create( const intrusive_ptr& expCtx, @@ -533,8 +525,6 @@ intrusive_ptr DocumentSourceGraphLookUp::create( maxDepth, unwindSrc)); source->_variables.reset(new Variables()); - - source->injectExpressionContext(expCtx); return source; } @@ -556,7 +546,7 @@ intrusive_ptr DocumentSourceGraphLookUp::createFromBson( const auto argName = argument.fieldNameStringData(); if (argName == "startWith") { - startWith = Expression::parseOperand(argument, vps); + startWith = Expression::parseOperand(expCtx, argument, vps); continue; } else if (argName == "maxDepth") { uassert(40100, diff --git a/src/mongo/db/pipeline/document_source_graph_lookup.h b/src/mongo/db/pipeline/document_source_graph_lookup.h index 01473bf44fc..be12637d404 100644 --- a/src/mongo/db/pipeline/document_source_graph_lookup.h +++ b/src/mongo/db/pipeline/document_source_graph_lookup.h @@ -94,9 +94,6 @@ public: static boost::intrusive_ptr createFromBson( BSONElement elem, const boost::intrusive_ptr& pExpCtx); -protected: - void doInjectExpressionContext() final; - private: DocumentSourceGraphLookUp( const boost::intrusive_ptr& expCtx, @@ -188,9 +185,7 @@ private: size_t _frontierUsageBytes = 0; // Only used during the breadth-first search, tracks the set of values on the current frontier. - // We use boost::optional to defer initialization until the ExpressionContext containing the - // correct comparator is injected. - boost::optional _frontier; + ValueUnorderedSet _frontier; // Tracks nodes that have been discovered for a given input. Keys are the '_id' value of the // document from the foreign collection, value is the document itself. The keys are compared diff --git a/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp b/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp index 6350a8c39e1..53e3e6ea1ac 100644 --- a/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp +++ b/src/mongo/db/pipeline/document_source_graph_lookup_test.cpp @@ -70,7 +70,6 @@ public: } pipeline.getValue()->addInitialSource(DocumentSourceMock::create(_results)); - pipeline.getValue()->injectExpressionContext(expCtx); pipeline.getValue()->optimizePipeline(); return pipeline; @@ -90,17 +89,18 @@ TEST_F(DocumentSourceGraphLookUpTest, std::deque fromContents{Document{{"to", 0}}}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - boost::none); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + boost::none); graphLookupStage->setSource(inputMock.get()); graphLookupStage->injectMongodInterface( std::make_shared(std::move(fromContents))); @@ -119,17 +119,18 @@ TEST_F(DocumentSourceGraphLookUpTest, Document{{"_id", "a"_sd}, {"to", 0}, {"from", 1}}, Document{{"to", 1}}}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - boost::none); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + boost::none); graphLookupStage->setSource(inputMock.get()); graphLookupStage->injectMongodInterface( std::make_shared(std::move(fromContents))); @@ -147,18 +148,19 @@ TEST_F(DocumentSourceGraphLookUpTest, std::deque fromContents{Document{{"to", 0}}}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); auto unwindStage = DocumentSourceUnwind::create(expCtx, "results", false, boost::none); - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - unwindStage); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + unwindStage); graphLookupStage->injectMongodInterface( std::make_shared(std::move(fromContents))); graphLookupStage->setSource(inputMock.get()); @@ -190,17 +192,18 @@ TEST_F(DocumentSourceGraphLookUpTest, Document(to1), Document(to2), Document(to0from1), Document(to0from2)}; NamespaceString fromNs("test", "graph_lookup"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; - auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, - fromNs, - "results", - "from", - "to", - ExpressionFieldPath::create("_id"), - boost::none, - boost::none, - boost::none, - boost::none); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); + auto graphLookupStage = + DocumentSourceGraphLookUp::create(expCtx, + fromNs, + "results", + "from", + "to", + ExpressionFieldPath::create(expCtx, "_id"), + boost::none, + boost::none, + boost::none, + boost::none); graphLookupStage->setSource(inputMock.get()); graphLookupStage->injectMongodInterface( std::make_shared(std::move(fromContents))); @@ -254,14 +257,14 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePauses) { Document{{"_id", "a"_sd}, {"to", 0}, {"from", 1}}, Document{{"_id", "b"_sd}, {"to", 1}}}; NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, fromNs, "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -322,7 +325,7 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePausesWhileUnwinding) { Document{{"_id", "a"_sd}, {"to", 0}, {"from", 1}}, Document{{"_id", "b"_sd}, {"to", 1}}}; NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); const bool preserveNullAndEmptyArrays = false; const boost::optional includeArrayIndex = boost::none; @@ -335,7 +338,7 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePausesWhileUnwinding) { "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -387,14 +390,14 @@ TEST_F(DocumentSourceGraphLookUpTest, ShouldPropagatePausesWhileUnwinding) { TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportAsFieldIsModified) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); auto graphLookupStage = DocumentSourceGraphLookUp::create(expCtx, fromNs, "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -409,7 +412,7 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportAsFieldIsModified) TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportFieldsModifiedByAbsorbedUnwind) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); auto unwindStage = DocumentSourceUnwind::create(expCtx, "results", false, std::string("arrIndex")); auto graphLookupStage = @@ -418,7 +421,7 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupShouldReportFieldsModifiedByAbs "results", "from", "to", - ExpressionFieldPath::create("startPoint"), + ExpressionFieldPath::create(expCtx, "startPoint"), boost::none, boost::none, boost::none, @@ -437,7 +440,7 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupWithComparisonExpressionForStar auto inputMock = DocumentSourceMock::create(Document({{"_id", 0}, {"a", 1}, {"b", 2}})); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); std::deque fromContents{Document{{"_id", 0}, {"to", true}}, Document{{"_id", 1}, {"to", false}}}; @@ -447,9 +450,10 @@ TEST_F(DocumentSourceGraphLookUpTest, GraphLookupWithComparisonExpressionForStar "results", "from", "to", - ExpressionCompare::create(ExpressionCompare::GT, - ExpressionFieldPath::create("a"), - ExpressionFieldPath::create("b")), + ExpressionCompare::create(expCtx, + ExpressionCompare::GT, + ExpressionFieldPath::create(expCtx, "a"), + ExpressionFieldPath::create(expCtx, "b")), boost::none, boost::none, boost::none, diff --git a/src/mongo/db/pipeline/document_source_group.cpp b/src/mongo/db/pipeline/document_source_group.cpp index cf91c8eff5f..a70a65c06b9 100644 --- a/src/mongo/db/pipeline/document_source_group.cpp +++ b/src/mongo/db/pipeline/document_source_group.cpp @@ -198,23 +198,6 @@ intrusive_ptr DocumentSourceGroup::optimize() { return this; } -void DocumentSourceGroup::doInjectExpressionContext() { - // Groups map must respect new comparator. - _groups = pExpCtx->getValueComparator().makeUnorderedValueMap(); - - for (auto&& idExpr : _idExpressions) { - idExpr->injectExpressionContext(pExpCtx); - } - - for (auto&& expr : vpExpression) { - expr->injectExpressionContext(pExpCtx); - } - - for (auto&& accum : _currentAccumulators) { - accum->injectExpressionContext(pExpCtx); - } -} - Value DocumentSourceGroup::serialize(bool explain) const { MutableDocument insides; @@ -235,7 +218,7 @@ Value DocumentSourceGroup::serialize(bool explain) const { // Add the remaining fields. const size_t n = vFieldName.size(); for (size_t i = 0; i < n; ++i) { - intrusive_ptr accum = vpAccumulatorFactory[i](); + intrusive_ptr accum = vpAccumulatorFactory[i](pExpCtx); insides[vFieldName[i]] = Value(DOC(accum->getOpName() << vpExpression[i]->serialize(explain))); } @@ -280,7 +263,6 @@ intrusive_ptr DocumentSourceGroup::create( groupStage->addAccumulator(statement); } groupStage->_variables = stdx::make_unique(numVariables); - groupStage->injectExpressionContext(pExpCtx); return groupStage; } @@ -292,6 +274,7 @@ DocumentSourceGroup::DocumentSourceGroup(const intrusive_ptr& _inputSort(BSONObj()), _streaming(false), _initialized(false), + _groups(pExpCtx->getValueComparator().makeUnorderedValueMap()), _spilled(false), _extSortAllowed(pExpCtx->extSortAllowed && !pExpCtx->inRouter) {} @@ -303,7 +286,7 @@ void DocumentSourceGroup::addAccumulator(AccumulationStatement accumulationState namespace { -intrusive_ptr parseIdExpression(const intrusive_ptr expCtx, +intrusive_ptr parseIdExpression(const intrusive_ptr& expCtx, BSONElement groupField, const VariablesParseState& vps) { if (groupField.type() == Object && !groupField.Obj().isEmpty()) { @@ -312,18 +295,18 @@ intrusive_ptr parseIdExpression(const intrusive_ptr DocumentSourceGroup::createFromBson( } else { // Any other field will be treated as an accumulator specification. pGroup->addAccumulator( - AccumulationStatement::parseAccumulationStatement(groupField, vps)); + AccumulationStatement::parseAccumulationStatement(pExpCtx, groupField, vps)); } } @@ -493,8 +476,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { // Set up accumulators. _currentAccumulators.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { - _currentAccumulators.push_back(vpAccumulatorFactory[i]()); - _currentAccumulators.back()->injectExpressionContext(pExpCtx); + _currentAccumulators.push_back(vpAccumulatorFactory[i](pExpCtx)); } // We only need to load the first document. @@ -543,8 +525,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { // Add the accumulators group.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { - group.push_back(vpAccumulatorFactory[i]()); - group.back()->injectExpressionContext(pExpCtx); + group.push_back(vpAccumulatorFactory[i](pExpCtx)); } } else { for (size_t i = 0; i < numAccumulators; i++) { @@ -599,8 +580,7 @@ DocumentSource::GetNextResult DocumentSourceGroup::initialize() { // prepare current to accumulate data _currentAccumulators.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { - _currentAccumulators.push_back(vpAccumulatorFactory[i]()); - _currentAccumulators.back()->injectExpressionContext(pExpCtx); + _currentAccumulators.push_back(vpAccumulatorFactory[i](pExpCtx)); } verify(_sorterIterator->more()); // we put data in, we should get something out. @@ -876,7 +856,7 @@ intrusive_ptr DocumentSourceGroup::getMergeSource() { VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); /* the merger will use the same grouping key */ - pMerger->setIdExpression(ExpressionFieldPath::parse("$$ROOT._id", vps)); + pMerger->setIdExpression(ExpressionFieldPath::parse(pExpCtx, "$$ROOT._id", vps)); const size_t n = vFieldName.size(); for (size_t i = 0; i < n; ++i) { @@ -888,13 +868,13 @@ intrusive_ptr DocumentSourceGroup::getMergeSource() { expression or constant. Here, we accumulate the output of the same name from the prior group. */ - pMerger->addAccumulator({vFieldName[i], - vpAccumulatorFactory[i], - ExpressionFieldPath::parse("$$ROOT." + vFieldName[i], vps)}); + pMerger->addAccumulator( + {vFieldName[i], + vpAccumulatorFactory[i], + ExpressionFieldPath::parse(pExpCtx, "$$ROOT." + vFieldName[i], vps)}); } pMerger->_variables.reset(new Variables(idGenerator.getIdCount())); - pMerger->injectExpressionContext(pExpCtx); return pMerger; } diff --git a/src/mongo/db/pipeline/document_source_group.h b/src/mongo/db/pipeline/document_source_group.h index 76d9282ebb9..e79ce631817 100644 --- a/src/mongo/db/pipeline/document_source_group.h +++ b/src/mongo/db/pipeline/document_source_group.h @@ -96,9 +96,6 @@ public: boost::intrusive_ptr getShardSource() final; boost::intrusive_ptr getMergeSource() final; -protected: - void doInjectExpressionContext() final; - private: explicit DocumentSourceGroup(const boost::intrusive_ptr& pExpCtx, size_t maxMemoryUsageBytes = kDefaultMaxMemoryUsageBytes); diff --git a/src/mongo/db/pipeline/document_source_group_test.cpp b/src/mongo/db/pipeline/document_source_group_test.cpp index 599e282e6b9..25b10ddcb5b 100644 --- a/src/mongo/db/pipeline/document_source_group_test.cpp +++ b/src/mongo/db/pipeline/document_source_group_test.cpp @@ -45,7 +45,7 @@ #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/expression.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value_comparator.h" #include "mongo/db/query/query_test_service_context.h" #include "mongo/dbtests/dbtests.h" @@ -110,8 +110,8 @@ TEST_F(DocumentSourceGroupTest, ShouldBeAbleToPauseLoadingWhileSpilled) { VariablesParseState vps(&idGen); AccumulationStatement pushStatement{"spaceHog", AccumulationStatement::getFactory("$push"), - ExpressionFieldPath::parse("$largeStr", vps)}; - auto groupByExpression = ExpressionFieldPath::parse("$_id", vps); + ExpressionFieldPath::parse(expCtx, "$largeStr", vps)}; + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", vps); auto group = DocumentSourceGroup::create( expCtx, groupByExpression, {pushStatement}, idGen.getIdCount(), maxMemoryUsageBytes); @@ -150,8 +150,8 @@ TEST_F(DocumentSourceGroupTest, ShouldErrorIfNotAllowedToSpillToDiskAndResultSet VariablesParseState vps(&idGen); AccumulationStatement pushStatement{"spaceHog", AccumulationStatement::getFactory("$push"), - ExpressionFieldPath::parse("$largeStr", vps)}; - auto groupByExpression = ExpressionFieldPath::parse("$_id", vps); + ExpressionFieldPath::parse(expCtx, "$largeStr", vps)}; + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", vps); auto group = DocumentSourceGroup::create( expCtx, groupByExpression, {pushStatement}, idGen.getIdCount(), maxMemoryUsageBytes); @@ -173,8 +173,8 @@ TEST_F(DocumentSourceGroupTest, ShouldCorrectlyTrackMemoryUsageBetweenPauses) { VariablesParseState vps(&idGen); AccumulationStatement pushStatement{"spaceHog", AccumulationStatement::getFactory("$push"), - ExpressionFieldPath::parse("$largeStr", vps)}; - auto groupByExpression = ExpressionFieldPath::parse("$_id", vps); + ExpressionFieldPath::parse(expCtx, "$largeStr", vps)}; + auto groupByExpression = ExpressionFieldPath::parse(expCtx, "$_id", vps); auto group = DocumentSourceGroup::create( expCtx, groupByExpression, {pushStatement}, idGen.getIdCount(), maxMemoryUsageBytes); @@ -204,7 +204,8 @@ public: Base() : _queryServiceContext(stdx::make_unique()), _opCtx(_queryServiceContext->makeOperationContext()), - _ctx(new ExpressionContext(_opCtx.get(), AggregationRequest(NamespaceString(ns), {}))), + _ctx(new ExpressionContextForTest(_opCtx.get(), + AggregationRequest(NamespaceString(ns), {}))), _tempDir("DocumentSourceGroupTest") {} protected: @@ -212,15 +213,14 @@ protected: BSONObj namedSpec = BSON("$group" << spec); BSONElement specElement = namedSpec.firstElement(); - intrusive_ptr expressionContext = - new ExpressionContext(_opCtx.get(), AggregationRequest(NamespaceString(ns), {})); + intrusive_ptr expressionContext = + new ExpressionContextForTest(_opCtx.get(), AggregationRequest(NamespaceString(ns), {})); expressionContext->inShard = inShard; expressionContext->inRouter = inRouter; // Won't spill to disk properly if it needs to. expressionContext->tempDir = _tempDir.path(); _group = DocumentSourceGroup::createFromBson(specElement, expressionContext); - _group->injectExpressionContext(expressionContext); assertRoundTrips(_group); } DocumentSourceGroup* group() { @@ -234,7 +234,7 @@ protected: ASSERT(source->getNext().isEOF()); } - intrusive_ptr ctx() const { + intrusive_ptr ctx() const { return _ctx; } @@ -251,7 +251,7 @@ private: } std::unique_ptr _queryServiceContext; ServiceContext::UniqueOperationContext _opCtx; - intrusive_ptr _ctx; + intrusive_ptr _ctx; intrusive_ptr _group; TempDir _tempDir; }; diff --git a/src/mongo/db/pipeline/document_source_limit.cpp b/src/mongo/db/pipeline/document_source_limit.cpp index 634d635b22d..797c6ed6652 100644 --- a/src/mongo/db/pipeline/document_source_limit.cpp +++ b/src/mongo/db/pipeline/document_source_limit.cpp @@ -93,7 +93,6 @@ intrusive_ptr DocumentSourceLimit::create( const intrusive_ptr& pExpCtx, long long limit) { uassert(15958, "the limit must be positive", limit > 0); intrusive_ptr source(new DocumentSourceLimit(pExpCtx, limit)); - source->injectExpressionContext(pExpCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_lookup.cpp b/src/mongo/db/pipeline/document_source_lookup.cpp index 5371542505c..a0923db276e 100644 --- a/src/mongo/db/pipeline/document_source_lookup.cpp +++ b/src/mongo/db/pipeline/document_source_lookup.cpp @@ -54,7 +54,16 @@ DocumentSourceLookUp::DocumentSourceLookUp(NamespaceString fromNs, _as(std::move(as)), _localField(std::move(localField)), _foreignField(foreignField), - _foreignFieldFieldName(std::move(foreignField)) {} + _foreignFieldFieldName(std::move(foreignField)) { + const auto& resolvedNamespace = pExpCtx->getResolvedNamespace(_fromNs); + _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); + _fromPipeline = resolvedNamespace.pipeline; + + // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage + // we'll eventually construct from the input document. + _fromPipeline.reserve(_fromPipeline.size() + 1); + _fromPipeline.push_back(BSONObj()); +} std::unique_ptr DocumentSourceLookUp::liteParse( const AggregationRequest& request, const BSONElement& spec) { @@ -454,19 +463,6 @@ DocumentSource::GetDepsReturn DocumentSourceLookUp::getDependencies(DepsTracker* return SEE_NEXT; } -void DocumentSourceLookUp::doInjectExpressionContext() { - auto it = pExpCtx->resolvedNamespaces.find(_fromNs.coll()); - invariant(it != pExpCtx->resolvedNamespaces.end()); - const auto& resolvedNamespace = it->second; - _fromExpCtx = pExpCtx->copyWith(resolvedNamespace.ns); - _fromPipeline = resolvedNamespace.pipeline; - - // We append an additional BSONObj to '_fromPipeline' as a placeholder for the $match stage - // we'll eventually construct from the input document. - _fromPipeline.reserve(_fromPipeline.size() + 1); - _fromPipeline.push_back(BSONObj()); -} - void DocumentSourceLookUp::doDetachFromOperationContext() { if (_pipeline) { // We have a pipeline we're going to be executing across multiple calls to getNext(), so we diff --git a/src/mongo/db/pipeline/document_source_lookup.h b/src/mongo/db/pipeline/document_source_lookup.h index 13118828b65..e179abfe73e 100644 --- a/src/mongo/db/pipeline/document_source_lookup.h +++ b/src/mongo/db/pipeline/document_source_lookup.h @@ -113,9 +113,6 @@ public: _handlingUnwind = true; } -protected: - void doInjectExpressionContext() final; - private: DocumentSourceLookUp(NamespaceString fromNs, std::string as, diff --git a/src/mongo/db/pipeline/document_source_lookup_test.cpp b/src/mongo/db/pipeline/document_source_lookup_test.cpp index 0f54e3741c4..9da414ec6f6 100644 --- a/src/mongo/db/pipeline/document_source_lookup_test.cpp +++ b/src/mongo/db/pipeline/document_source_lookup_test.cpp @@ -54,6 +54,10 @@ using std::vector; using DocumentSourceLookUpTest = AggregationContextFixture; TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnAsField) { + auto expCtx = getExpCtx(); + NamespaceString fromNs("test", "a"); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); + intrusive_ptr source = DocumentSourceMock::create(); source->sorts = {BSON("a" << 1 << "d.e" << 1 << "c" << 1)}; auto lookup = DocumentSourceLookUp::createFromBson(Document{{"$lookup", @@ -63,7 +67,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnAsField) { {"as", "d.e"_sd}}}} .toBson() .firstElement(), - getExpCtx()); + expCtx); lookup->setSource(source.get()); BSONObjSet outputSort = lookup->getOutputSorts(); @@ -73,6 +77,10 @@ TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnAsField) { } TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnSuffixOfAsField) { + auto expCtx = getExpCtx(); + NamespaceString fromNs("test", "a"); + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); + intrusive_ptr source = DocumentSourceMock::create(); source->sorts = {BSON("a" << 1 << "d.e" << 1 << "c" << 1)}; auto lookup = DocumentSourceLookUp::createFromBson(Document{{"$lookup", @@ -82,7 +90,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldTruncateOutputSortOnSuffixOfAsField) { {"as", "d"_sd}}}} .toBson() .firstElement(), - getExpCtx()); + expCtx); lookup->setSource(source.get()); BSONObjSet outputSort = lookup->getOutputSorts(); @@ -158,7 +166,6 @@ public: } pipeline.getValue()->addInitialSource(DocumentSourceMock::create(_mockResults)); - pipeline.getValue()->injectExpressionContext(expCtx); pipeline.getValue()->optimizePipeline(); return pipeline; @@ -171,7 +178,7 @@ private: TEST_F(DocumentSourceLookUpTest, ShouldPropagatePauses) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", @@ -191,7 +198,6 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePauses) { DocumentSource::GetNextResult::makePauseExecution()}); lookup->setSource(mockLocalSource.get()); - lookup->injectExpressionContext(expCtx); // Mock out the foreign collection. deque mockForeignContents{Document{{"_id", 0}}, @@ -222,7 +228,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePauses) { TEST_F(DocumentSourceLookUpTest, ShouldPropagatePausesWhileUnwinding) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", @@ -247,8 +253,6 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePausesWhileUnwinding) { DocumentSource::GetNextResult::makePauseExecution()}); lookup->setSource(mockLocalSource.get()); - lookup->injectExpressionContext(expCtx); - // Mock out the foreign collection. deque mockForeignContents{Document{{"_id", 0}}, Document{{"_id", 1}}}; @@ -276,7 +280,7 @@ TEST_F(DocumentSourceLookUpTest, ShouldPropagatePausesWhileUnwinding) { TEST_F(DocumentSourceLookUpTest, LookupReportsAsFieldIsModified) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", @@ -297,7 +301,7 @@ TEST_F(DocumentSourceLookUpTest, LookupReportsAsFieldIsModified) { TEST_F(DocumentSourceLookUpTest, LookupReportsFieldsModifiedByAbsorbedUnwind) { auto expCtx = getExpCtx(); NamespaceString fromNs("test", "foreign"); - expCtx->resolvedNamespaces[fromNs.coll()] = {fromNs, std::vector{}}; + expCtx->setResolvedNamespace(fromNs, {fromNs, std::vector{}}); // Set up the $lookup stage. auto lookupSpec = Document{{"$lookup", diff --git a/src/mongo/db/pipeline/document_source_match.cpp b/src/mongo/db/pipeline/document_source_match.cpp index 18cb22db71d..494d6589099 100644 --- a/src/mongo/db/pipeline/document_source_match.cpp +++ b/src/mongo/db/pipeline/document_source_match.cpp @@ -401,7 +401,7 @@ DocumentSourceMatch::splitSourceBy(const std::set& fields) { boost::intrusive_ptr DocumentSourceMatch::descendMatchOnPath( MatchExpression* matchExpr, const std::string& descendOn, - intrusive_ptr expCtx) { + const intrusive_ptr& expCtx) { expression::mapOver(matchExpr, [&descendOn](MatchExpression* node, std::string path) -> void { // Cannot call this method on a $match including a $elemMatch. invariant(node->matchType() != MatchExpression::ELEM_MATCH_OBJECT && @@ -453,7 +453,6 @@ intrusive_ptr DocumentSourceMatch::create( BSONObj filter, const intrusive_ptr& expCtx) { uassertNoDisallowedClauses(filter); intrusive_ptr match(new DocumentSourceMatch(filter, expCtx)); - match->injectExpressionContext(expCtx); return match; } @@ -490,10 +489,6 @@ void DocumentSourceMatch::addDependencies(DepsTracker* deps) const { }); } -void DocumentSourceMatch::doInjectExpressionContext() { - _expression->setCollator(pExpCtx->getCollator()); -} - DocumentSourceMatch::DocumentSourceMatch(const BSONObj& query, const intrusive_ptr& pExpCtx) : DocumentSource(pExpCtx), _predicate(query.getOwned()), _isTextQuery(isTextQuery(query)) { diff --git a/src/mongo/db/pipeline/document_source_match.h b/src/mongo/db/pipeline/document_source_match.h index e7f3d4227d9..badfe1ed8b9 100644 --- a/src/mongo/db/pipeline/document_source_match.h +++ b/src/mongo/db/pipeline/document_source_match.h @@ -140,9 +140,7 @@ public: static boost::intrusive_ptr descendMatchOnPath( MatchExpression* matchExpr, const std::string& path, - boost::intrusive_ptr expCtx); - - void doInjectExpressionContext(); + const boost::intrusive_ptr& expCtx); private: DocumentSourceMatch(const BSONObj& query, diff --git a/src/mongo/db/pipeline/document_source_merge_cursors.cpp b/src/mongo/db/pipeline/document_source_merge_cursors.cpp index 866259c663e..b03db6560df 100644 --- a/src/mongo/db/pipeline/document_source_merge_cursors.cpp +++ b/src/mongo/db/pipeline/document_source_merge_cursors.cpp @@ -57,7 +57,6 @@ intrusive_ptr DocumentSourceMergeCursors::create( const intrusive_ptr& pExpCtx) { intrusive_ptr source( new DocumentSourceMergeCursors(std::move(cursorDescriptors), pExpCtx)); - source->injectExpressionContext(pExpCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_mock.cpp b/src/mongo/db/pipeline/document_source_mock.cpp index 686f4ef6f2e..3918bbd5ad7 100644 --- a/src/mongo/db/pipeline/document_source_mock.cpp +++ b/src/mongo/db/pipeline/document_source_mock.cpp @@ -31,6 +31,8 @@ #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" namespace mongo { @@ -38,7 +40,7 @@ using boost::intrusive_ptr; using std::deque; DocumentSourceMock::DocumentSourceMock(deque results) - : DocumentSource(nullptr), + : DocumentSource(new ExpressionContextForTest()), queue(std::move(results)), sorts(SimpleBSONObjComparator::kInstance.makeBSONObjSet()) {} diff --git a/src/mongo/db/pipeline/document_source_mock.h b/src/mongo/db/pipeline/document_source_mock.h index 7a600483661..7b966f5c418 100644 --- a/src/mongo/db/pipeline/document_source_mock.h +++ b/src/mongo/db/pipeline/document_source_mock.h @@ -79,10 +79,6 @@ public: return this; } - void doInjectExpressionContext() override { - isExpCtxInjected = true; - } - // Return documents from front of queue. std::deque queue; diff --git a/src/mongo/db/pipeline/document_source_project.cpp b/src/mongo/db/pipeline/document_source_project.cpp index fd44d534dde..a0ee096f36e 100644 --- a/src/mongo/db/pipeline/document_source_project.cpp +++ b/src/mongo/db/pipeline/document_source_project.cpp @@ -49,8 +49,7 @@ REGISTER_DOCUMENT_SOURCE(project, intrusive_ptr DocumentSourceProject::create( BSONObj projectSpec, const intrusive_ptr& expCtx) { intrusive_ptr project(new DocumentSourceSingleDocumentTransformation( - expCtx, ParsedAggregationProjection::create(projectSpec), "$project")); - project->injectExpressionContext(expCtx); + expCtx, ParsedAggregationProjection::create(expCtx, projectSpec), "$project")); return project; } diff --git a/src/mongo/db/pipeline/document_source_project.h b/src/mongo/db/pipeline/document_source_project.h index 869dce512b1..4ed9db432d8 100644 --- a/src/mongo/db/pipeline/document_source_project.h +++ b/src/mongo/db/pipeline/document_source_project.h @@ -53,6 +53,8 @@ public: BSONElement elem, const boost::intrusive_ptr& pExpCtx); private: + // It is illegal to construct a DocumentSourceProject directly, use create() or createFromBson() + // instead. DocumentSourceProject() = default; }; diff --git a/src/mongo/db/pipeline/document_source_redact.cpp b/src/mongo/db/pipeline/document_source_redact.cpp index 86b1c00960d..0096017a2f8 100644 --- a/src/mongo/db/pipeline/document_source_redact.cpp +++ b/src/mongo/db/pipeline/document_source_redact.cpp @@ -163,10 +163,6 @@ intrusive_ptr DocumentSourceRedact::optimize() { return this; } -void DocumentSourceRedact::doInjectExpressionContext() { - _expression->injectExpressionContext(pExpCtx); -} - Value DocumentSourceRedact::serialize(bool explain) const { return Value(DOC(getSourceName() << _expression.get()->serialize(explain))); } @@ -179,7 +175,7 @@ intrusive_ptr DocumentSourceRedact::createFromBson( Variables::Id decendId = vps.defineVariable("DESCEND"); Variables::Id pruneId = vps.defineVariable("PRUNE"); Variables::Id keepId = vps.defineVariable("KEEP"); - intrusive_ptr expression = Expression::parseOperand(elem, vps); + intrusive_ptr expression = Expression::parseOperand(expCtx, elem, vps); intrusive_ptr source = new DocumentSourceRedact(expCtx, expression); // TODO figure out how much of this belongs in constructor and how much here. diff --git a/src/mongo/db/pipeline/document_source_redact.h b/src/mongo/db/pipeline/document_source_redact.h index f056411f765..b9cf6242d42 100644 --- a/src/mongo/db/pipeline/document_source_redact.h +++ b/src/mongo/db/pipeline/document_source_redact.h @@ -47,8 +47,6 @@ public: Pipeline::SourceContainer::iterator doOptimizeAt(Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) final; - void doInjectExpressionContext() final; - static boost::intrusive_ptr createFromBson( BSONElement elem, const boost::intrusive_ptr& expCtx); diff --git a/src/mongo/db/pipeline/document_source_replace_root.cpp b/src/mongo/db/pipeline/document_source_replace_root.cpp index 6f1844ed281..1ae82d5cf3a 100644 --- a/src/mongo/db/pipeline/document_source_replace_root.cpp +++ b/src/mongo/db/pipeline/document_source_replace_root.cpp @@ -87,17 +87,14 @@ public: return DocumentSource::EXHAUSTIVE_FIELDS; } - void injectExpressionContext(const boost::intrusive_ptr& pExpCtx) final { - _newRoot->injectExpressionContext(pExpCtx); - } - DocumentSource::GetModPathsReturn getModifiedPaths() const final { // Replaces the entire root, so all paths are modified. return {DocumentSource::GetModPathsReturn::Type::kAllPaths, std::set{}}; } // Create the replaceRoot transformer. Uasserts on invalid input. - static std::unique_ptr create(const BSONElement& spec) { + static std::unique_ptr create( + const boost::intrusive_ptr& expCtx, const BSONElement& spec) { // Confirm that the stage was called with an object. uassert(40229, @@ -108,12 +105,12 @@ public: // Create the pointer, parse the stage, and return. std::unique_ptr parsedReplaceRoot = stdx::make_unique(); - parsedReplaceRoot->parse(spec); + parsedReplaceRoot->parse(expCtx, spec); return parsedReplaceRoot; } // Check for valid replaceRoot options, and populate internal state variables. - void parse(const BSONElement& spec) { + void parse(const boost::intrusive_ptr& expCtx, const BSONElement& spec) { // We need a VariablesParseState in order to parse the 'newRoot' expression. VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); @@ -124,7 +121,7 @@ public: if (argName == "newRoot") { // Allows for field path, object, and other expressions. - _newRoot = Expression::parseOperand(argument, vps); + _newRoot = Expression::parseOperand(expCtx, argument, vps); } else { uasserted(40230, str::stream() << "unrecognized option to $replaceRoot stage: " << argName @@ -150,7 +147,7 @@ intrusive_ptr DocumentSourceReplaceRoot::createFromBson( BSONElement elem, const intrusive_ptr& pExpCtx) { return new DocumentSourceSingleDocumentTransformation( - pExpCtx, ReplaceRootTransformation::create(elem), "$replaceRoot"); + pExpCtx, ReplaceRootTransformation::create(pExpCtx, elem), "$replaceRoot"); } } // namespace mongo diff --git a/src/mongo/db/pipeline/document_source_replace_root.h b/src/mongo/db/pipeline/document_source_replace_root.h index f3313ec5f7a..4656a6cdc2d 100644 --- a/src/mongo/db/pipeline/document_source_replace_root.h +++ b/src/mongo/db/pipeline/document_source_replace_root.h @@ -48,6 +48,8 @@ public: BSONElement elem, const boost::intrusive_ptr& pExpCtx); private: + // It is illegal to construct a DocumentSourceReplaceRoot directly, use createFromBson() + // instead. DocumentSourceReplaceRoot() = default; }; diff --git a/src/mongo/db/pipeline/document_source_sample.cpp b/src/mongo/db/pipeline/document_source_sample.cpp index a436d3e8a53..c7d98db7260 100644 --- a/src/mongo/db/pipeline/document_source_sample.cpp +++ b/src/mongo/db/pipeline/document_source_sample.cpp @@ -87,10 +87,6 @@ Value DocumentSourceSample::serialize(bool explain) const { return Value(DOC(getSourceName() << DOC("size" << _size))); } -void DocumentSourceSample::doInjectExpressionContext() { - _sortStage->injectExpressionContext(pExpCtx); -} - namespace { const BSONObj randSortSpec = BSON("$rand" << BSON("$meta" << "randVal")); diff --git a/src/mongo/db/pipeline/document_source_sample.h b/src/mongo/db/pipeline/document_source_sample.h index ed0eec9c180..28f1dd97b05 100644 --- a/src/mongo/db/pipeline/document_source_sample.h +++ b/src/mongo/db/pipeline/document_source_sample.h @@ -50,8 +50,6 @@ public: return _size; } - void doInjectExpressionContext() final; - static boost::intrusive_ptr createFromBson( BSONElement elem, const boost::intrusive_ptr& expCtx); diff --git a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp index f4b2299c9e8..9fc3b5bf105 100644 --- a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp +++ b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp @@ -52,6 +52,7 @@ DocumentSourceSampleFromRandomCursor::DocumentSourceSampleFromRandomCursor( : DocumentSource(pExpCtx), _size(size), _idField(std::move(idField)), + _seenDocs(pExpCtx->getValueComparator().makeUnorderedValueSet()), _nDocsInColl(nDocsInCollection) {} const char* DocumentSourceSampleFromRandomCursor::getSourceName() const { @@ -76,7 +77,7 @@ double smallestFromSampleOfUniform(PseudoRandom* prng, size_t N) { DocumentSource::GetNextResult DocumentSourceSampleFromRandomCursor::getNext() { pExpCtx->checkForInterrupt(); - if (_seenDocs->size() >= static_cast(_size)) + if (_seenDocs.size() >= static_cast(_size)) return GetNextResult::makeEOF(); auto nextResult = getNextNonDuplicateDocument(); @@ -114,7 +115,7 @@ DocumentSource::GetNextResult DocumentSourceSampleFromRandomCursor::getNextNonDu << nextInput.getDocument().toString(), !idField.missing()); - if (_seenDocs->insert(std::move(idField)).second) { + if (_seenDocs.insert(std::move(idField)).second) { return nextInput; } LOG(1) << "$sample encountered duplicate document: " @@ -147,10 +148,6 @@ DocumentSource::GetDepsReturn DocumentSourceSampleFromRandomCursor::getDependenc return SEE_NEXT; } -void DocumentSourceSampleFromRandomCursor::doInjectExpressionContext() { - _seenDocs = pExpCtx->getValueComparator().makeUnorderedValueSet(); -} - intrusive_ptr DocumentSourceSampleFromRandomCursor::create( const intrusive_ptr& expCtx, long long size, @@ -158,7 +155,6 @@ intrusive_ptr DocumentSourceSampleFromRand long long nDocsInCollection) { intrusive_ptr source( new DocumentSourceSampleFromRandomCursor(expCtx, size, idField, nDocsInCollection)); - source->injectExpressionContext(expCtx); return source; } } // mongo diff --git a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h index 6e8a2ae1d39..0d0ac39ca49 100644 --- a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h +++ b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.h @@ -44,8 +44,6 @@ public: Value serialize(bool explain = false) const final; GetDepsReturn getDependencies(DepsTracker* deps) const final; - void doInjectExpressionContext() final; - static boost::intrusive_ptr create( const boost::intrusive_ptr& expCtx, long long size, @@ -71,9 +69,8 @@ private: std::string _idField; // Keeps track of the documents that have been returned, since a random cursor is allowed to - // return duplicates. We use boost::optional to defer initialization until the ExpressionContext - // containing the correct comparator is injected. - boost::optional _seenDocs; + // return duplicates. + ValueUnorderedSet _seenDocs; // The approximate number of documents in the collection (includes orphans). const long long _nDocsInColl; diff --git a/src/mongo/db/pipeline/document_source_single_document_transformation.cpp b/src/mongo/db/pipeline/document_source_single_document_transformation.cpp index 67dd7338395..35204d272a4 100644 --- a/src/mongo/db/pipeline/document_source_single_document_transformation.cpp +++ b/src/mongo/db/pipeline/document_source_single_document_transformation.cpp @@ -98,10 +98,6 @@ DocumentSource::GetDepsReturn DocumentSourceSingleDocumentTransformation::getDep return _parsedTransform->addDependencies(deps); } -void DocumentSourceSingleDocumentTransformation::doInjectExpressionContext() { - _parsedTransform->injectExpressionContext(pExpCtx); -} - DocumentSource::GetModPathsReturn DocumentSourceSingleDocumentTransformation::getModifiedPaths() const { return _parsedTransform->getModifiedPaths(); diff --git a/src/mongo/db/pipeline/document_source_single_document_transformation.h b/src/mongo/db/pipeline/document_source_single_document_transformation.h index b2db5b0c3ae..1251cb454a4 100644 --- a/src/mongo/db/pipeline/document_source_single_document_transformation.h +++ b/src/mongo/db/pipeline/document_source_single_document_transformation.h @@ -57,8 +57,6 @@ public: virtual void optimize() = 0; virtual Document serialize(bool explain) const = 0; virtual DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const = 0; - virtual void injectExpressionContext( - const boost::intrusive_ptr& pExpCtx) = 0; virtual GetModPathsReturn getModifiedPaths() const = 0; }; @@ -75,7 +73,6 @@ public: Value serialize(bool explain) const final; Pipeline::SourceContainer::iterator doOptimizeAt(Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) final; - void doInjectExpressionContext() final; DocumentSource::GetDepsReturn getDependencies(DepsTracker* deps) const final; GetModPathsReturn getModifiedPaths() const final; diff --git a/src/mongo/db/pipeline/document_source_skip.cpp b/src/mongo/db/pipeline/document_source_skip.cpp index 27bb6d1d2a6..ff528f0634e 100644 --- a/src/mongo/db/pipeline/document_source_skip.cpp +++ b/src/mongo/db/pipeline/document_source_skip.cpp @@ -96,7 +96,6 @@ Pipeline::SourceContainer::iterator DocumentSourceSkip::doOptimizeAt( intrusive_ptr DocumentSourceSkip::create( const intrusive_ptr& pExpCtx, long long nToSkip) { intrusive_ptr skip(new DocumentSourceSkip(pExpCtx, nToSkip)); - skip->injectExpressionContext(pExpCtx); return skip; } diff --git a/src/mongo/db/pipeline/document_source_sort.cpp b/src/mongo/db/pipeline/document_source_sort.cpp index 4f4f17da4f2..6554a892016 100644 --- a/src/mongo/db/pipeline/document_source_sort.cpp +++ b/src/mongo/db/pipeline/document_source_sort.cpp @@ -113,7 +113,7 @@ long long DocumentSourceSort::getLimit() const { void DocumentSourceSort::addKey(StringData fieldPath, bool ascending) { VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - vSortKey.push_back(ExpressionFieldPath::parse("$$ROOT." + fieldPath.toString(), vps)); + vSortKey.push_back(ExpressionFieldPath::parse(pExpCtx, "$$ROOT." + fieldPath.toString(), vps)); vAscending.push_back(ascending); } @@ -176,7 +176,6 @@ intrusive_ptr DocumentSourceSort::create( uint64_t maxMemoryUsageBytes) { intrusive_ptr pSort(new DocumentSourceSort(pExpCtx)); pSort->_maxMemoryUsageBytes = maxMemoryUsageBytes; - pSort->injectExpressionContext(pExpCtx); pSort->_sort = sortOrder.getOwned(); for (auto&& keyField : sortOrder) { @@ -197,7 +196,7 @@ intrusive_ptr DocumentSourceSort::create( VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - pSort->vSortKey.push_back(ExpressionMeta::parse(metaDoc.firstElement(), vps)); + pSort->vSortKey.push_back(ExpressionMeta::parse(pExpCtx, metaDoc.firstElement(), vps)); // If sorting by textScore, sort highest scores first. If sorting by randVal, order // doesn't matter, so just always use descending. diff --git a/src/mongo/db/pipeline/document_source_sort_by_count.h b/src/mongo/db/pipeline/document_source_sort_by_count.h index 2444d048259..69bf8d2c5e0 100644 --- a/src/mongo/db/pipeline/document_source_sort_by_count.h +++ b/src/mongo/db/pipeline/document_source_sort_by_count.h @@ -44,6 +44,8 @@ public: BSONElement elem, const boost::intrusive_ptr& pExpCtx); private: + // It is illegal to construct a DocumentSourceSortByCount directly, use createFromBson() + // instead. DocumentSourceSortByCount() = default; }; diff --git a/src/mongo/db/pipeline/document_source_tee_consumer.h b/src/mongo/db/pipeline/document_source_tee_consumer.h index 928ac8cba30..bfdbfbda1cb 100644 --- a/src/mongo/db/pipeline/document_source_tee_consumer.h +++ b/src/mongo/db/pipeline/document_source_tee_consumer.h @@ -38,7 +38,7 @@ namespace mongo { class Document; -struct ExpressionContext; +class ExpressionContext; class Value; /** diff --git a/src/mongo/db/pipeline/document_source_unwind.cpp b/src/mongo/db/pipeline/document_source_unwind.cpp index 368b29afbd4..683069597e0 100644 --- a/src/mongo/db/pipeline/document_source_unwind.cpp +++ b/src/mongo/db/pipeline/document_source_unwind.cpp @@ -183,7 +183,6 @@ intrusive_ptr DocumentSourceUnwind::create( FieldPath(unwindPath), preserveNullAndEmptyArrays, indexPath ? FieldPath(*indexPath) : boost::optional())); - source->injectExpressionContext(expCtx); return source; } diff --git a/src/mongo/db/pipeline/document_source_unwind_test.cpp b/src/mongo/db/pipeline/document_source_unwind_test.cpp index 48df219770a..8dcf1ed9a32 100644 --- a/src/mongo/db/pipeline/document_source_unwind_test.cpp +++ b/src/mongo/db/pipeline/document_source_unwind_test.cpp @@ -42,7 +42,7 @@ #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document_source_unwind.h" #include "mongo/db/pipeline/document_value_test_util.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value_comparator.h" #include "mongo/db/query/query_test_service_context.h" #include "mongo/db/service_context.h" @@ -70,7 +70,8 @@ public: CheckResultsBase() : _queryServiceContext(stdx::make_unique()), _opCtx(_queryServiceContext->makeOperationContext()), - _ctx(new ExpressionContext(_opCtx.get(), AggregationRequest(NamespaceString(ns), {}))) {} + _ctx(new ExpressionContextForTest(_opCtx.get(), + AggregationRequest(NamespaceString(ns), {}))) {} virtual ~CheckResultsBase() {} @@ -141,7 +142,7 @@ protected: return expectedIndexedResultSetString(); } - intrusive_ptr ctx() const { + intrusive_ptr ctx() const { return _ctx; } @@ -248,7 +249,7 @@ private: unique_ptr _queryServiceContext; ServiceContext::UniqueOperationContext _opCtx; - intrusive_ptr _ctx; + intrusive_ptr _ctx; intrusive_ptr _unwind; }; diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index e9de0eb3472..d9a087e08c1 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -189,17 +189,20 @@ string Expression::removeFieldPrefix(const string& prefixedField) { return string(pPrefixedField + 1); } -intrusive_ptr Expression::parseObject(BSONObj obj, const VariablesParseState& vps) { +intrusive_ptr Expression::parseObject( + const boost::intrusive_ptr& expCtx, + BSONObj obj, + const VariablesParseState& vps) { if (obj.isEmpty()) { - return ExpressionObject::create({}); + return ExpressionObject::create(expCtx, {}); } if (obj.firstElementFieldName()[0] == '$') { // Assume this is an expression like {$add: [...]}. - return parseExpression(obj, vps); + return parseExpression(expCtx, obj, vps); } - return ExpressionObject::parse(obj, vps); + return ExpressionObject::parse(expCtx, obj, vps); } namespace { @@ -214,7 +217,10 @@ void Expression::registerExpression(string key, Parser parser) { parserMap[key] = parser; } -intrusive_ptr Expression::parseExpression(BSONObj obj, const VariablesParseState& vps) { +intrusive_ptr Expression::parseExpression( + const boost::intrusive_ptr& expCtx, + BSONObj obj, + const VariablesParseState& vps) { uassert(15983, str::stream() << "An object representing an expression must have exactly one " "field: " @@ -227,36 +233,40 @@ intrusive_ptr Expression::parseExpression(BSONObj obj, const Variabl uassert(ErrorCodes::InvalidPipelineOperator, str::stream() << "Unrecognized expression '" << opName << "'", op != parserMap.end()); - return op->second(obj.firstElement(), vps); + return op->second(expCtx, obj.firstElement(), vps); } -Expression::ExpressionVector ExpressionNary::parseArguments(BSONElement exprElement, - const VariablesParseState& vps) { +Expression::ExpressionVector ExpressionNary::parseArguments( + const boost::intrusive_ptr& expCtx, + BSONElement exprElement, + const VariablesParseState& vps) { ExpressionVector out; if (exprElement.type() == Array) { BSONForEach(elem, exprElement.Obj()) { - out.push_back(Expression::parseOperand(elem, vps)); + out.push_back(Expression::parseOperand(expCtx, elem, vps)); } } else { // Assume it's an operand that accepts a single argument. - out.push_back(Expression::parseOperand(exprElement, vps)); + out.push_back(Expression::parseOperand(expCtx, exprElement, vps)); } return out; } -intrusive_ptr Expression::parseOperand(BSONElement exprElement, - const VariablesParseState& vps) { +intrusive_ptr Expression::parseOperand( + const boost::intrusive_ptr& expCtx, + BSONElement exprElement, + const VariablesParseState& vps) { BSONType type = exprElement.type(); if (type == String && exprElement.valuestr()[0] == '$') { /* if we got here, this is a field path expression */ - return ExpressionFieldPath::parse(exprElement.str(), vps); + return ExpressionFieldPath::parse(expCtx, exprElement.str(), vps); } else if (type == Object) { - return Expression::parseObject(exprElement.Obj(), vps); + return Expression::parseObject(expCtx, exprElement.Obj(), vps); } else if (type == Array) { - return ExpressionArray::parse(exprElement, vps); + return ExpressionArray::parse(expCtx, exprElement, vps); } else { - return ExpressionConstant::parse(exprElement, vps); + return ExpressionConstant::parse(expCtx, exprElement, vps); } } @@ -615,13 +625,13 @@ const char* ExpressionCeil::getOpName() const { intrusive_ptr ExpressionCoerceToBool::create( const intrusive_ptr& expCtx, const intrusive_ptr& pExpression) { - intrusive_ptr pNew(new ExpressionCoerceToBool(pExpression)); - pNew->injectExpressionContext(expCtx); + intrusive_ptr pNew(new ExpressionCoerceToBool(expCtx, pExpression)); return pNew; } -ExpressionCoerceToBool::ExpressionCoerceToBool(const intrusive_ptr& pTheExpression) - : Expression(), pExpression(pTheExpression) {} +ExpressionCoerceToBool::ExpressionCoerceToBool(const intrusive_ptr& expCtx, + const intrusive_ptr& pTheExpression) + : Expression(expCtx), pExpression(pTheExpression) {} intrusive_ptr ExpressionCoerceToBool::optimize() { /* optimize the operand */ @@ -656,65 +666,68 @@ Value ExpressionCoerceToBool::serialize(bool explain) const { return Value(DOC(name << DOC_ARRAY(pExpression->serialize(explain)))); } -void ExpressionCoerceToBool::doInjectExpressionContext() { - // Inject our ExpressionContext into the operand. - pExpression->injectExpressionContext(getExpressionContext()); -} - /* ----------------------- ExpressionCompare --------------------------- */ REGISTER_EXPRESSION(cmp, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::CMP)); REGISTER_EXPRESSION(eq, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::EQ)); REGISTER_EXPRESSION(gt, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::GT)); REGISTER_EXPRESSION(gte, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::GTE)); REGISTER_EXPRESSION(lt, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::LT)); REGISTER_EXPRESSION(lte, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::LTE)); REGISTER_EXPRESSION(ne, stdx::bind(ExpressionCompare::parse, stdx::placeholders::_1, stdx::placeholders::_2, + stdx::placeholders::_3, ExpressionCompare::NE)); -intrusive_ptr ExpressionCompare::parse(BSONElement bsonExpr, - const VariablesParseState& vps, - CmpOp op) { - intrusive_ptr expr = new ExpressionCompare(op); - ExpressionVector args = parseArguments(bsonExpr, vps); +intrusive_ptr ExpressionCompare::parse( + const boost::intrusive_ptr& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps, + CmpOp op) { + intrusive_ptr expr = new ExpressionCompare(expCtx, op); + ExpressionVector args = parseArguments(expCtx, bsonExpr, vps); expr->validateArguments(args); expr->vpOperand = args; return expr; } -ExpressionCompare::ExpressionCompare(CmpOp theCmpOp) : cmpOp(theCmpOp) {} - boost::intrusive_ptr ExpressionCompare::create( + const boost::intrusive_ptr& expCtx, CmpOp cmpOp, const boost::intrusive_ptr& exprLeft, const boost::intrusive_ptr& exprRight) { - boost::intrusive_ptr expr = new ExpressionCompare(cmpOp); + boost::intrusive_ptr expr = new ExpressionCompare(expCtx, cmpOp); expr->vpOperand = {exprLeft, exprRight}; return expr; } @@ -828,23 +841,26 @@ Value ExpressionCond::evaluateInternal(Variables* vars) const { return vpOperand[idx]->evaluateInternal(vars); } -intrusive_ptr ExpressionCond::parse(BSONElement expr, const VariablesParseState& vps) { +intrusive_ptr ExpressionCond::parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps) { if (expr.type() != Object) { - return Base::parse(expr, vps); + return Base::parse(expCtx, expr, vps); } verify(str::equals(expr.fieldName(), "$cond")); - intrusive_ptr ret = new ExpressionCond(); + intrusive_ptr ret = new ExpressionCond(expCtx); ret->vpOperand.resize(3); const BSONObj args = expr.embeddedObject(); BSONForEach(arg, args) { if (str::equals(arg.fieldName(), "if")) { - ret->vpOperand[0] = parseOperand(arg, vps); + ret->vpOperand[0] = parseOperand(expCtx, arg, vps); } else if (str::equals(arg.fieldName(), "then")) { - ret->vpOperand[1] = parseOperand(arg, vps); + ret->vpOperand[1] = parseOperand(expCtx, arg, vps); } else if (str::equals(arg.fieldName(), "else")) { - ret->vpOperand[2] = parseOperand(arg, vps); + ret->vpOperand[2] = parseOperand(expCtx, arg, vps); } else { uasserted(17083, str::stream() << "Unrecognized parameter to $cond: " << arg.fieldName()); @@ -865,20 +881,23 @@ const char* ExpressionCond::getOpName() const { /* ---------------------- ExpressionConstant --------------------------- */ -intrusive_ptr ExpressionConstant::parse(BSONElement exprElement, - const VariablesParseState& vps) { - return new ExpressionConstant(Value(exprElement)); +intrusive_ptr ExpressionConstant::parse( + const boost::intrusive_ptr& expCtx, + BSONElement exprElement, + const VariablesParseState& vps) { + return new ExpressionConstant(expCtx, Value(exprElement)); } intrusive_ptr ExpressionConstant::create( const intrusive_ptr& expCtx, const Value& pValue) { - intrusive_ptr pEC(new ExpressionConstant(pValue)); - pEC->injectExpressionContext(expCtx); + intrusive_ptr pEC(new ExpressionConstant(expCtx, pValue)); return pEC; } -ExpressionConstant::ExpressionConstant(const Value& pTheValue) : pValue(pTheValue) {} +ExpressionConstant::ExpressionConstant(const boost::intrusive_ptr& expCtx, + const Value& pTheValue) + : Expression(expCtx), pValue(pTheValue) {} intrusive_ptr ExpressionConstant::optimize() { @@ -907,8 +926,10 @@ const char* ExpressionConstant::getOpName() const { /* ---------------------- ExpressionDateToString ----------------------- */ REGISTER_EXPRESSION(dateToString, ExpressionDateToString::parse); -intrusive_ptr ExpressionDateToString::parse(BSONElement expr, - const VariablesParseState& vps) { +intrusive_ptr ExpressionDateToString::parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps) { verify(str::equals(expr.fieldName(), "$dateToString")); uassert(18629, "$dateToString only supports an object as its argument", expr.type() == Object); @@ -939,11 +960,14 @@ intrusive_ptr ExpressionDateToString::parse(BSONElement expr, validateFormat(format); - return new ExpressionDateToString(format, parseOperand(dateElem, vps)); + return new ExpressionDateToString(expCtx, format, parseOperand(expCtx, dateElem, vps)); } -ExpressionDateToString::ExpressionDateToString(const string& format, intrusive_ptr date) - : _format(format), _date(date) {} +ExpressionDateToString::ExpressionDateToString( + const boost::intrusive_ptr& expCtx, + const string& format, + intrusive_ptr date) + : Expression(expCtx), _format(format), _date(date) {} intrusive_ptr ExpressionDateToString::optimize() { _date = _date->optimize(); @@ -1103,10 +1127,6 @@ void ExpressionDateToString::addDependencies(DepsTracker* deps) const { _date->addDependencies(deps); } -void ExpressionDateToString::doInjectExpressionContext() { - _date->injectExpressionContext(getExpressionContext()); -} - /* ---------------------- ExpressionDayOfMonth ------------------------- */ Value ExpressionDayOfMonth::evaluateInternal(Variables* vars) const { @@ -1198,16 +1218,20 @@ const char* ExpressionExp::getOpName() const { /* ---------------------- ExpressionObject --------------------------- */ -ExpressionObject::ExpressionObject(vector>>&& expressions) - : _expressions(std::move(expressions)) {} +ExpressionObject::ExpressionObject(const boost::intrusive_ptr& expCtx, + vector>>&& expressions) + : Expression(expCtx), _expressions(std::move(expressions)) {} intrusive_ptr ExpressionObject::create( + const boost::intrusive_ptr& expCtx, vector>>&& expressions) { - return new ExpressionObject(std::move(expressions)); + return new ExpressionObject(expCtx, std::move(expressions)); } -intrusive_ptr ExpressionObject::parse(BSONObj obj, - const VariablesParseState& vps) { +intrusive_ptr ExpressionObject::parse( + const boost::intrusive_ptr& expCtx, + BSONObj obj, + const VariablesParseState& vps) { // Make sure we don't have any duplicate field names. stdx::unordered_set specifiedFields; @@ -1223,10 +1247,10 @@ intrusive_ptr ExpressionObject::parse(BSONObj obj, << obj.toString(), specifiedFields.find(fieldName) == specifiedFields.end()); specifiedFields.insert(fieldName); - expressions.emplace_back(fieldName, parseOperand(elem, vps)); + expressions.emplace_back(fieldName, parseOperand(expCtx, elem, vps)); } - return new ExpressionObject{std::move(expressions)}; + return new ExpressionObject{expCtx, std::move(expressions)}; } intrusive_ptr ExpressionObject::optimize() { @@ -1258,22 +1282,19 @@ Value ExpressionObject::serialize(bool explain) const { return outputDoc.freezeToValue(); } -void ExpressionObject::doInjectExpressionContext() { - for (auto&& pair : _expressions) { - pair.second->injectExpressionContext(getExpressionContext()); - } -} - /* --------------------- ExpressionFieldPath --------------------------- */ // this is the old deprecated version only used by tests not using variables -intrusive_ptr ExpressionFieldPath::create(const string& fieldPath) { - return new ExpressionFieldPath("CURRENT." + fieldPath, Variables::ROOT_ID); +intrusive_ptr ExpressionFieldPath::create( + const boost::intrusive_ptr& expCtx, const string& fieldPath) { + return new ExpressionFieldPath(expCtx, "CURRENT." + fieldPath, Variables::ROOT_ID); } // this is the new version that supports every syntax -intrusive_ptr ExpressionFieldPath::parse(const string& raw, - const VariablesParseState& vps) { +intrusive_ptr ExpressionFieldPath::parse( + const boost::intrusive_ptr& expCtx, + const string& raw, + const VariablesParseState& vps) { uassert(16873, str::stream() << "FieldPath '" << raw << "' doesn't start with $", raw.c_str()[0] == '$'); // c_str()[0] is always a valid reference. @@ -1287,16 +1308,18 @@ intrusive_ptr ExpressionFieldPath::parse(const string& raw, const StringData fieldPath = rawSD.substr(2); // strip off $$ const StringData varName = fieldPath.substr(0, fieldPath.find('.')); Variables::uassertValidNameForUserRead(varName); - return new ExpressionFieldPath(fieldPath.toString(), vps.getVariable(varName)); + return new ExpressionFieldPath(expCtx, fieldPath.toString(), vps.getVariable(varName)); } else { - return new ExpressionFieldPath("CURRENT." + raw.substr(1), // strip the "$" prefix + return new ExpressionFieldPath(expCtx, + "CURRENT." + raw.substr(1), // strip the "$" prefix vps.getVariable("CURRENT")); } } - -ExpressionFieldPath::ExpressionFieldPath(const string& theFieldPath, Variables::Id variable) - : _fieldPath(theFieldPath), _variable(variable) {} +ExpressionFieldPath::ExpressionFieldPath(const boost::intrusive_ptr& expCtx, + const string& theFieldPath, + Variables::Id variable) + : Expression(expCtx), _fieldPath(theFieldPath), _variable(variable) {} intrusive_ptr ExpressionFieldPath::optimize() { /* nothing can be done for these */ @@ -1384,8 +1407,10 @@ Value ExpressionFieldPath::serialize(bool explain) const { /* ------------------------- ExpressionFilter ----------------------------- */ REGISTER_EXPRESSION(filter, ExpressionFilter::parse); -intrusive_ptr ExpressionFilter::parse(BSONElement expr, - const VariablesParseState& vpsIn) { +intrusive_ptr ExpressionFilter::parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { verify(str::equals(expr.fieldName(), "$filter")); uassert(28646, "$filter only supports an object as its argument", expr.type() == Object); @@ -1411,7 +1436,7 @@ intrusive_ptr ExpressionFilter::parse(BSONElement expr, uassert(28650, "Missing 'cond' parameter to $filter", !condElem.eoo()); // Parse "input", only has outer variables. - intrusive_ptr input = parseOperand(inputElem, vpsIn); + intrusive_ptr input = parseOperand(expCtx, inputElem, vpsIn); // Parse "as". VariablesParseState vpsSub(vpsIn); // vpsSub gets our variable, vpsIn doesn't. @@ -1423,16 +1448,19 @@ intrusive_ptr ExpressionFilter::parse(BSONElement expr, Variables::Id varId = vpsSub.defineVariable(varName); // Parse "cond", has access to "as" variable. - intrusive_ptr cond = parseOperand(condElem, vpsSub); + intrusive_ptr cond = parseOperand(expCtx, condElem, vpsSub); - return new ExpressionFilter(std::move(varName), varId, std::move(input), std::move(cond)); + return new ExpressionFilter( + expCtx, std::move(varName), varId, std::move(input), std::move(cond)); } -ExpressionFilter::ExpressionFilter(string varName, +ExpressionFilter::ExpressionFilter(const boost::intrusive_ptr& expCtx, + string varName, Variables::Id varId, intrusive_ptr input, intrusive_ptr filter) - : _varName(std::move(varName)), + : Expression(expCtx), + _varName(std::move(varName)), _varId(varId), _input(std::move(input)), _filter(std::move(filter)) {} @@ -1483,11 +1511,6 @@ void ExpressionFilter::addDependencies(DepsTracker* deps) const { _filter->addDependencies(deps); } -void ExpressionFilter::doInjectExpressionContext() { - _input->injectExpressionContext(getExpressionContext()); - _filter->injectExpressionContext(getExpressionContext()); -} - /* ------------------------- ExpressionFloor -------------------------- */ Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const { @@ -1512,7 +1535,10 @@ const char* ExpressionFloor::getOpName() const { /* ------------------------- ExpressionLet ----------------------------- */ REGISTER_EXPRESSION(let, ExpressionLet::parse); -intrusive_ptr ExpressionLet::parse(BSONElement expr, const VariablesParseState& vpsIn) { +intrusive_ptr ExpressionLet::parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { verify(str::equals(expr.fieldName(), "$let")); uassert(16874, "$let only supports an object as its argument", expr.type() == Object); @@ -1543,17 +1569,20 @@ intrusive_ptr ExpressionLet::parse(BSONElement expr, const Variables Variables::uassertValidNameForUserWrite(varName); Variables::Id id = vpsSub.defineVariable(varName); - vars[id] = NameAndExpression(varName, parseOperand(varElem, vpsIn)); // only has outer vars + vars[id] = NameAndExpression(varName, + parseOperand(expCtx, varElem, vpsIn)); // only has outer vars } // parse "in" - intrusive_ptr subExpression = parseOperand(inElem, vpsSub); // has our vars + intrusive_ptr subExpression = parseOperand(expCtx, inElem, vpsSub); // has our vars - return new ExpressionLet(vars, subExpression); + return new ExpressionLet(expCtx, vars, subExpression); } -ExpressionLet::ExpressionLet(const VariableMap& vars, intrusive_ptr subExpression) - : _variables(vars), _subExpression(subExpression) {} +ExpressionLet::ExpressionLet(const boost::intrusive_ptr& expCtx, + const VariableMap& vars, + intrusive_ptr subExpression) + : Expression(expCtx), _variables(vars), _subExpression(subExpression) {} intrusive_ptr ExpressionLet::optimize() { if (_variables.empty()) { @@ -1603,18 +1632,14 @@ void ExpressionLet::addDependencies(DepsTracker* deps) const { _subExpression->addDependencies(deps); } -void ExpressionLet::doInjectExpressionContext() { - _subExpression->injectExpressionContext(getExpressionContext()); - for (auto&& variable : _variables) { - variable.second.expression->injectExpressionContext(getExpressionContext()); - } -} - /* ------------------------- ExpressionMap ----------------------------- */ REGISTER_EXPRESSION(map, ExpressionMap::parse); -intrusive_ptr ExpressionMap::parse(BSONElement expr, const VariablesParseState& vpsIn) { +intrusive_ptr ExpressionMap::parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { verify(str::equals(expr.fieldName(), "$map")); uassert(16878, "$map only supports an object as its argument", expr.type() == Object); @@ -1641,7 +1666,8 @@ intrusive_ptr ExpressionMap::parse(BSONElement expr, const Variables uassert(16882, "Missing 'in' parameter to $map", !inElem.eoo()); // parse "input" - intrusive_ptr input = parseOperand(inputElem, vpsIn); // only has outer vars + intrusive_ptr input = + parseOperand(expCtx, inputElem, vpsIn); // only has outer vars // parse "as" VariablesParseState vpsSub(vpsIn); // vpsSub gets our vars, vpsIn doesn't. @@ -1653,16 +1679,18 @@ intrusive_ptr ExpressionMap::parse(BSONElement expr, const Variables Variables::Id varId = vpsSub.defineVariable(varName); // parse "in" - intrusive_ptr in = parseOperand(inElem, vpsSub); // has access to map variable + intrusive_ptr in = + parseOperand(expCtx, inElem, vpsSub); // has access to map variable - return new ExpressionMap(varName, varId, input, in); + return new ExpressionMap(expCtx, varName, varId, input, in); } -ExpressionMap::ExpressionMap(const string& varName, +ExpressionMap::ExpressionMap(const boost::intrusive_ptr& expCtx, + const string& varName, Variables::Id varId, intrusive_ptr input, intrusive_ptr each) - : _varName(varName), _varId(varId), _input(input), _each(each) {} + : Expression(expCtx), _varName(varName), _varId(varId), _input(input), _each(each) {} intrusive_ptr ExpressionMap::optimize() { // TODO handle when _input is constant @@ -1711,27 +1739,26 @@ void ExpressionMap::addDependencies(DepsTracker* deps) const { _each->addDependencies(deps); } -void ExpressionMap::doInjectExpressionContext() { - _input->injectExpressionContext(getExpressionContext()); - _each->injectExpressionContext(getExpressionContext()); -} - /* ------------------------- ExpressionMeta ----------------------------- */ REGISTER_EXPRESSION(meta, ExpressionMeta::parse); -intrusive_ptr ExpressionMeta::parse(BSONElement expr, - const VariablesParseState& vpsIn) { +intrusive_ptr ExpressionMeta::parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn) { uassert(17307, "$meta only supports string arguments", expr.type() == String); if (expr.valueStringData() == "textScore") { - return new ExpressionMeta(MetaType::TEXT_SCORE); + return new ExpressionMeta(expCtx, MetaType::TEXT_SCORE); } else if (expr.valueStringData() == "randVal") { - return new ExpressionMeta(MetaType::RAND_VAL); + return new ExpressionMeta(expCtx, MetaType::RAND_VAL); } else { uasserted(17308, "Unsupported argument to $meta: " + expr.String()); } } -ExpressionMeta::ExpressionMeta(MetaType metaType) : _metaType(metaType) {} +ExpressionMeta::ExpressionMeta(const boost::intrusive_ptr& expCtx, + MetaType metaType) + : Expression(expCtx), _metaType(metaType) {} Value ExpressionMeta::serialize(bool explain) const { switch (_metaType) { @@ -2399,12 +2426,6 @@ Value ExpressionNary::serialize(bool explain) const { return Value(DOC(getOpName() << array)); } -void ExpressionNary::doInjectExpressionContext() { - for (auto&& operand : vpOperand) { - operand->injectExpressionContext(getExpressionContext()); - } -} - /* ------------------------- ExpressionNot ----------------------------- */ Value ExpressionNot::evaluateInternal(Variables* vars) const { @@ -2491,8 +2512,9 @@ const char* ExpressionOr::getOpName() const { /* ----------------------- ExpressionPow ---------------------------- */ -intrusive_ptr ExpressionPow::create(Value base, Value exp) { - intrusive_ptr expr(new ExpressionPow()); +intrusive_ptr ExpressionPow::create( + const boost::intrusive_ptr& expCtx, Value base, Value exp) { + intrusive_ptr expr(new ExpressionPow(expCtx)); expr->vpOperand.push_back( ExpressionConstant::create(expr->getExpressionContext(), std::move(base))); expr->vpOperand.push_back( @@ -2723,14 +2745,16 @@ const char* ExpressionRange::getOpName() const { /* ------------------------ ExpressionReduce ------------------------------ */ REGISTER_EXPRESSION(reduce, ExpressionReduce::parse); -intrusive_ptr ExpressionReduce::parse(BSONElement expr, - const VariablesParseState& vps) { +intrusive_ptr ExpressionReduce::parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps) { uassert(40075, str::stream() << "$reduce requires an object as an argument, found: " << typeName(expr.type()), expr.type() == Object); - intrusive_ptr reduce(new ExpressionReduce()); + intrusive_ptr reduce(new ExpressionReduce(expCtx)); // vpsSub is used only to parse 'in', which must have access to $$this and $$value. VariablesParseState vpsSub(vps); @@ -2741,11 +2765,11 @@ intrusive_ptr ExpressionReduce::parse(BSONElement expr, auto field = elem.fieldNameStringData(); if (field == "input") { - reduce->_input = parseOperand(elem, vps); + reduce->_input = parseOperand(expCtx, elem, vps); } else if (field == "initialValue") { - reduce->_initial = parseOperand(elem, vps); + reduce->_initial = parseOperand(expCtx, elem, vps); } else if (field == "in") { - reduce->_in = parseOperand(elem, vpsSub); + reduce->_in = parseOperand(expCtx, elem, vpsSub); } else { uasserted(40076, str::stream() << "$reduce found an unknown argument: " << field); } @@ -2802,12 +2826,6 @@ Value ExpressionReduce::serialize(bool explain) const { {"in", _in->serialize(explain)}}}}); } -void ExpressionReduce::doInjectExpressionContext() { - _input->injectExpressionContext(getExpressionContext()); - _initial->injectExpressionContext(getExpressionContext()); - _in->injectExpressionContext(getExpressionContext()); -} - /* ------------------------ ExpressionReverseArray ------------------------ */ Value ExpressionReverseArray::evaluateInternal(Variables* vars) const { @@ -3031,8 +3049,10 @@ Value ExpressionSetIsSubset::evaluateInternal(Variables* vars) const { */ class ExpressionSetIsSubset::Optimized : public ExpressionSetIsSubset { public: - Optimized(const ValueSet& cachedRhsSet, const ExpressionVector& operands) - : _cachedRhsSet(cachedRhsSet) { + Optimized(const boost::intrusive_ptr& expCtx, + const ValueSet& cachedRhsSet, + const ExpressionVector& operands) + : ExpressionSetIsSubset(expCtx), _cachedRhsSet(cachedRhsSet) { vpOperand = operands; } @@ -3068,9 +3088,10 @@ intrusive_ptr ExpressionSetIsSubset::optimize() { << typeName(rhs.getType()), rhs.isArray()); - intrusive_ptr optimizedWithConstant(new Optimized( - arrayToSet(rhs, getExpressionContext()->getValueComparator()), vpOperand)); - optimizedWithConstant->injectExpressionContext(getExpressionContext()); + intrusive_ptr optimizedWithConstant( + new Optimized(this->getExpressionContext(), + arrayToSet(rhs, getExpressionContext()->getValueComparator()), + vpOperand)); return optimizedWithConstant; } return optimized; @@ -3579,14 +3600,16 @@ Value ExpressionSwitch::evaluateInternal(Variables* vars) const { return _default->evaluateInternal(vars); } -boost::intrusive_ptr ExpressionSwitch::parse(BSONElement expr, - const VariablesParseState& vps) { +boost::intrusive_ptr ExpressionSwitch::parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps) { uassert(40060, str::stream() << "$switch requires an object as an argument, found: " << typeName(expr.type()), expr.type() == Object); - intrusive_ptr expression(new ExpressionSwitch()); + intrusive_ptr expression(new ExpressionSwitch(expCtx)); for (auto&& elem : expr.Obj()) { auto field = elem.fieldNameStringData(); @@ -3610,9 +3633,9 @@ boost::intrusive_ptr ExpressionSwitch::parse(BSONElement expr, auto branchField = branchElement.fieldNameStringData(); if (branchField == "case") { - branchExpression.first = parseOperand(branchElement, vps); + branchExpression.first = parseOperand(expCtx, branchElement, vps); } else if (branchField == "then") { - branchExpression.second = parseOperand(branchElement, vps); + branchExpression.second = parseOperand(expCtx, branchElement, vps); } else { uasserted(40063, str::stream() << "$switch found an unknown argument to a branch: " @@ -3631,7 +3654,7 @@ boost::intrusive_ptr ExpressionSwitch::parse(BSONElement expr, } } else if (field == "default") { // Optional, arbitrary expression. - expression->_default = parseOperand(elem, vps); + expression->_default = parseOperand(expCtx, elem, vps); } else { uasserted(40067, str::stream() << "$switch found an unknown argument: " << field); } @@ -3686,17 +3709,6 @@ Value ExpressionSwitch::serialize(bool explain) const { return Value(Document{{"$switch", Document{{"branches", Value(serializedBranches)}}}}); } -void ExpressionSwitch::doInjectExpressionContext() { - if (_default) { - _default->injectExpressionContext(getExpressionContext()); - } - - for (auto&& pair : _branches) { - pair.first->injectExpressionContext(getExpressionContext()); - pair.second->injectExpressionContext(getExpressionContext()); - } -} - /* ------------------------- ExpressionToLower ----------------------------- */ Value ExpressionToLower::evaluateInternal(Variables* vars) const { @@ -3956,13 +3968,16 @@ const char* ExpressionYear::getOpName() const { /* -------------------------- ExpressionZip ------------------------------ */ REGISTER_EXPRESSION(zip, ExpressionZip::parse); -intrusive_ptr ExpressionZip::parse(BSONElement expr, const VariablesParseState& vps) { +intrusive_ptr ExpressionZip::parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps) { uassert(34460, str::stream() << "$zip only supports an object as an argument, found " << typeName(expr.type()), expr.type() == Object); - intrusive_ptr newZip(new ExpressionZip()); + intrusive_ptr newZip(new ExpressionZip(expCtx)); for (auto&& elem : expr.Obj()) { const auto field = elem.fieldNameStringData(); @@ -3972,7 +3987,7 @@ intrusive_ptr ExpressionZip::parse(BSONElement expr, const Variables << typeName(elem.type()), elem.type() == Array); for (auto&& subExpr : elem.Array()) { - newZip->_inputs.push_back(parseOperand(subExpr, vps)); + newZip->_inputs.push_back(parseOperand(expCtx, subExpr, vps)); } } else if (field == "defaults") { uassert(34462, @@ -3980,7 +3995,7 @@ intrusive_ptr ExpressionZip::parse(BSONElement expr, const Variables << typeName(elem.type()), elem.type() == Array); for (auto&& subExpr : elem.Array()) { - newZip->_defaults.push_back(parseOperand(subExpr, vps)); + newZip->_defaults.push_back(parseOperand(expCtx, subExpr, vps)); } } else if (field == "useLongestLength") { uassert(34463, @@ -4123,14 +4138,4 @@ void ExpressionZip::addDependencies(DepsTracker* deps) const { }); } -void ExpressionZip::doInjectExpressionContext() { - for (auto&& expr : _inputs) { - expr->injectExpressionContext(getExpressionContext()); - } - - for (auto&& expr : _defaults) { - expr->injectExpressionContext(getExpressionContext()); - } -} - } // namespace mongo diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 00e771784c2..ef9a3ae3f19 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -172,8 +172,8 @@ private: class Expression : public IntrusiveCounterUnsigned { public: - using Parser = - stdx::function(BSONElement, const VariablesParseState&)>; + using Parser = stdx::function( + const boost::intrusive_ptr&, BSONElement, const VariablesParseState&)>; virtual ~Expression(){}; @@ -235,8 +235,10 @@ public: * Calls parseExpression() on any sub-document (including possibly the entire document) which * consists of a single field name starting with a '$'. */ - static boost::intrusive_ptr parseObject(BSONObj obj, - const VariablesParseState& vps); + static boost::intrusive_ptr parseObject( + const boost::intrusive_ptr& expCtx, + BSONObj obj, + const VariablesParseState& vps); /** * Parses a BSONObj which has already been determined to be a functional expression. @@ -244,8 +246,10 @@ public: * Throws an error if 'obj' does not contain exactly one field, or if that field's name does not * match a registered expression name. */ - static boost::intrusive_ptr parseExpression(BSONObj obj, - const VariablesParseState& vps); + static boost::intrusive_ptr parseExpression( + const boost::intrusive_ptr& expCtx, + BSONObj obj, + const VariablesParseState& vps); /** * Parses a BSONElement which is an argument to an Expression. @@ -254,8 +258,10 @@ public: * parseObject(), ExpressionFieldPath::parse(), ExpressionArray::parse(), or * ExpressionConstant::parse() as necessary. */ - static boost::intrusive_ptr parseOperand(BSONElement exprElement, - const VariablesParseState& vps); + static boost::intrusive_ptr parseOperand( + const boost::intrusive_ptr& expCtx, + BSONElement exprElement, + const VariablesParseState& vps); /* Produce a field path std::string with the field prefix removed. @@ -282,24 +288,10 @@ public: */ static void registerExpression(std::string key, Parser parser); - /** - * Injects the ExpressionContext so that it may be used during evaluation of the Expression. - * Construction of expressions is done at parse time, but the ExpressionContext isn't finalized - * until later, at which point it is injected using this method. - */ - void injectExpressionContext(const boost::intrusive_ptr& expCtx) { - _expCtx = expCtx; - doInjectExpressionContext(); - } - protected: - typedef std::vector> ExpressionVector; + Expression(const boost::intrusive_ptr& expCtx) : _expCtx(expCtx) {} - /** - * Expressions which need to update their internal state when attaching to a new - * ExpressionContext should override this method. - */ - virtual void doInjectExpressionContext() {} + typedef std::vector> ExpressionVector; const boost::intrusive_ptr& getExpressionContext() const { return _expCtx; @@ -344,12 +336,13 @@ public: /// Allow subclasses the opportunity to validate arguments at parse time. virtual void validateArguments(const ExpressionVector& args) const {} - static ExpressionVector parseArguments(BSONElement bsonExpr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static ExpressionVector parseArguments(const boost::intrusive_ptr& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps); protected: - ExpressionNary() {} + explicit ExpressionNary(const boost::intrusive_ptr& expCtx) + : Expression(expCtx) {} ExpressionVector vpOperand; }; @@ -358,19 +351,29 @@ protected: template class ExpressionNaryBase : public ExpressionNary { public: - static boost::intrusive_ptr parse(BSONElement bsonExpr, - const VariablesParseState& vps) { - boost::intrusive_ptr expr = new SubClass(); - ExpressionVector args = parseArguments(bsonExpr, vps); + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps) { + boost::intrusive_ptr expr = new SubClass(expCtx); + ExpressionVector args = parseArguments(expCtx, bsonExpr, vps); expr->validateArguments(args); expr->vpOperand = args; return expr; } + +protected: + explicit ExpressionNaryBase(const boost::intrusive_ptr& expCtx) + : ExpressionNary(expCtx) {} }; /// Inherit from this class if your expression takes a variable number of arguments. template -class ExpressionVariadic : public ExpressionNaryBase {}; +class ExpressionVariadic : public ExpressionNaryBase { +public: + explicit ExpressionVariadic(const boost::intrusive_ptr& expCtx) + : ExpressionNaryBase(expCtx) {} +}; /** * Inherit from this class if your expression can take a range of arguments, e.g. if it has some @@ -379,6 +382,9 @@ class ExpressionVariadic : public ExpressionNaryBase {}; template class ExpressionRangedArity : public ExpressionNaryBase { public: + explicit ExpressionRangedArity(const boost::intrusive_ptr& expCtx) + : ExpressionNaryBase(expCtx) {} + void validateArguments(const Expression::ExpressionVector& args) const override { uassert(28667, mongoutils::str::stream() << "Expression " << this->getOpName() @@ -397,6 +403,9 @@ public: template class ExpressionFixedArity : public ExpressionNaryBase { public: + explicit ExpressionFixedArity(const boost::intrusive_ptr& expCtx) + : ExpressionNaryBase(expCtx) {} + void validateArguments(const Expression::ExpressionVector& args) const override { uassert(16020, mongoutils::str::stream() << "Expression " << this->getOpName() << " takes exactly " @@ -416,9 +425,11 @@ template class ExpressionFromAccumulator : public ExpressionVariadic> { public: + explicit ExpressionFromAccumulator(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic>(expCtx) {} + Value evaluateInternal(Variables* vars) const final { - Accumulator accum; - accum.injectExpressionContext(this->getExpressionContext()); + Accumulator accum(this->getExpressionContext()); const size_t n = this->vpOperand.size(); // If a single array arg is given, loop through it passing each member to the accumulator. // If a single, non-array arg is given, pass it directly to the accumulator. @@ -446,15 +457,15 @@ public: if (this->vpOperand.size() == 1) { return false; } - return Accumulator().isAssociative(); + return Accumulator(this->getExpressionContext()).isAssociative(); } bool isCommutative() const final { - return Accumulator().isCommutative(); + return Accumulator(this->getExpressionContext()).isCommutative(); } const char* getOpName() const final { - return Accumulator().getOpName(); + return Accumulator(this->getExpressionContext()).getOpName(); } }; @@ -464,6 +475,9 @@ public: template class ExpressionSingleNumericArg : public ExpressionFixedArity { public: + explicit ExpressionSingleNumericArg(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + virtual ~ExpressionSingleNumericArg() {} Value evaluateInternal(Variables* vars) const final { @@ -484,6 +498,10 @@ public: class ExpressionAbs final : public ExpressionSingleNumericArg { +public: + explicit ExpressionAbs(const boost::intrusive_ptr& expCtx) + : ExpressionSingleNumericArg(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -491,6 +509,9 @@ class ExpressionAbs final : public ExpressionSingleNumericArg { class ExpressionAdd final : public ExpressionVariadic { public: + explicit ExpressionAdd(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -506,6 +527,9 @@ public: class ExpressionAllElementsTrue final : public ExpressionFixedArity { public: + explicit ExpressionAllElementsTrue(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -513,6 +537,9 @@ public: class ExpressionAnd final : public ExpressionVariadic { public: + explicit ExpressionAnd(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic(expCtx) {} + boost::intrusive_ptr optimize() final; Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -529,6 +556,9 @@ public: class ExpressionAnyElementTrue final : public ExpressionFixedArity { public: + explicit ExpressionAnyElementTrue(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -536,7 +566,9 @@ public: class ExpressionArray final : public ExpressionVariadic { public: - // virtuals from ExpressionNary + explicit ExpressionArray(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic(expCtx) {} + Value evaluateInternal(Variables* vars) const final; Value serialize(bool explain) const final; const char* getOpName() const final; @@ -545,6 +577,9 @@ public: class ExpressionArrayElemAt final : public ExpressionFixedArity { public: + explicit ExpressionArrayElemAt(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -552,6 +587,9 @@ public: class ExpressionCeil final : public ExpressionSingleNumericArg { public: + explicit ExpressionCeil(const boost::intrusive_ptr& expCtx) + : ExpressionSingleNumericArg(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -568,10 +606,9 @@ public: const boost::intrusive_ptr& expCtx, const boost::intrusive_ptr& pExpression); - void doInjectExpressionContext() final; - private: - explicit ExpressionCoerceToBool(const boost::intrusive_ptr& pExpression); + ExpressionCoerceToBool(const boost::intrusive_ptr& expCtx, + const boost::intrusive_ptr& pExpression); boost::intrusive_ptr pExpression; }; @@ -593,16 +630,20 @@ public: CMP = 6, // return -1, 0, 1 for a < b, a == b, a > b }; + ExpressionCompare(const boost::intrusive_ptr& expCtx, CmpOp cmpOp) + : ExpressionFixedArity(expCtx), cmpOp(cmpOp) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; - static boost::intrusive_ptr parse(BSONElement bsonExpr, - const VariablesParseState& vps, - CmpOp cmpOp); - - explicit ExpressionCompare(CmpOp cmpOp); + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps, + CmpOp cmpOp); static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx, CmpOp cmpOp, const boost::intrusive_ptr& exprLeft, const boost::intrusive_ptr& exprRight); @@ -614,6 +655,9 @@ private: class ExpressionConcat final : public ExpressionVariadic { public: + explicit ExpressionConcat(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -625,6 +669,9 @@ public: class ExpressionConcatArrays final : public ExpressionVariadic { public: + explicit ExpressionConcatArrays(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -635,13 +682,19 @@ public: class ExpressionCond final : public ExpressionFixedArity { - typedef ExpressionFixedArity Base; - public: + explicit ExpressionCond(const boost::intrusive_ptr& expCtx) : Base(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; - static boost::intrusive_ptr parse(BSONElement expr, const VariablesParseState& vps); + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps); + +private: + typedef ExpressionFixedArity Base; }; @@ -657,8 +710,10 @@ public: static boost::intrusive_ptr create( const boost::intrusive_ptr& expCtx, const Value& pValue); - static boost::intrusive_ptr parse(BSONElement bsonExpr, - const VariablesParseState& vps); + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement bsonExpr, + const VariablesParseState& vps); /* Get the constant value represented by this Expression. @@ -670,7 +725,7 @@ public: } private: - explicit ExpressionConstant(const Value& pValue); + ExpressionConstant(const boost::intrusive_ptr& expCtx, const Value& pValue); Value pValue; }; @@ -682,12 +737,14 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: - ExpressionDateToString(const std::string& format, // the format string + ExpressionDateToString(const boost::intrusive_ptr& expCtx, + const std::string& format, // the format string boost::intrusive_ptr date); // the date to format // Will uassert on invalid data @@ -705,6 +762,9 @@ private: class ExpressionDayOfMonth final : public ExpressionFixedArity { public: + explicit ExpressionDayOfMonth(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -716,6 +776,9 @@ public: class ExpressionDayOfWeek final : public ExpressionFixedArity { public: + explicit ExpressionDayOfWeek(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -728,6 +791,9 @@ public: class ExpressionDayOfYear final : public ExpressionFixedArity { public: + explicit ExpressionDayOfYear(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -740,12 +806,19 @@ public: class ExpressionDivide final : public ExpressionFixedArity { public: + explicit ExpressionDivide(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionExp final : public ExpressionSingleNumericArg { +public: + explicit ExpressionExp(const boost::intrusive_ptr& expCtx) + : ExpressionSingleNumericArg(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -771,18 +844,23 @@ public: indicator @returns the newly created field path expression */ - static boost::intrusive_ptr create(const std::string& fieldPath); + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx, const std::string& fieldPath); /// Like create(), but works with the raw std::string from the user with the "$" prefixes. - static boost::intrusive_ptr parse(const std::string& raw, - const VariablesParseState& vps); + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + const std::string& raw, + const VariablesParseState& vps); const FieldPath& getFieldPath() const { return _fieldPath; } private: - ExpressionFieldPath(const std::string& fieldPath, Variables::Id variable); + ExpressionFieldPath(const boost::intrusive_ptr& expCtx, + const std::string& fieldPath, + Variables::Id variable); /* Internal implementation of evaluateInternal(), used recursively. @@ -814,12 +892,14 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: - ExpressionFilter(std::string varName, + ExpressionFilter(const boost::intrusive_ptr& expCtx, + std::string varName, Variables::Id varId, boost::intrusive_ptr input, boost::intrusive_ptr filter); @@ -837,6 +917,9 @@ private: class ExpressionFloor final : public ExpressionSingleNumericArg { public: + explicit ExpressionFloor(const boost::intrusive_ptr& expCtx) + : ExpressionSingleNumericArg(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -844,6 +927,9 @@ public: class ExpressionHour final : public ExpressionFixedArity { public: + explicit ExpressionHour(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -855,6 +941,9 @@ public: class ExpressionIfNull final : public ExpressionFixedArity { public: + explicit ExpressionIfNull(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -862,6 +951,9 @@ public: class ExpressionIn final : public ExpressionFixedArity { public: + explicit ExpressionIn(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -869,6 +961,9 @@ public: class ExpressionIndexOfArray final : public ExpressionRangedArity { public: + explicit ExpressionIndexOfArray(const boost::intrusive_ptr& expCtx) + : ExpressionRangedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -876,6 +971,9 @@ public: class ExpressionIndexOfBytes final : public ExpressionRangedArity { public: + explicit ExpressionIndexOfBytes(const boost::intrusive_ptr& expCtx) + : ExpressionRangedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -886,6 +984,9 @@ public: */ class ExpressionIndexOfCP final : public ExpressionRangedArity { public: + explicit ExpressionIndexOfCP(const boost::intrusive_ptr& expCtx) + : ExpressionRangedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -898,9 +999,10 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps); struct NameAndExpression { NameAndExpression() {} @@ -914,23 +1016,37 @@ public: typedef std::map VariableMap; private: - ExpressionLet(const VariableMap& vars, boost::intrusive_ptr subExpression); + ExpressionLet(const boost::intrusive_ptr& expCtx, + const VariableMap& vars, + boost::intrusive_ptr subExpression); VariableMap _variables; boost::intrusive_ptr _subExpression; }; class ExpressionLn final : public ExpressionSingleNumericArg { +public: + explicit ExpressionLn(const boost::intrusive_ptr& expCtx) + : ExpressionSingleNumericArg(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; class ExpressionLog final : public ExpressionFixedArity { +public: + explicit ExpressionLog(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionLog10 final : public ExpressionSingleNumericArg { +public: + explicit ExpressionLog10(const boost::intrusive_ptr& expCtx) + : ExpressionSingleNumericArg(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -942,12 +1058,14 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr parse(BSONElement expr, const VariablesParseState& vps); - - void doInjectExpressionContext() final; + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: ExpressionMap( + const boost::intrusive_ptr& expCtx, const std::string& varName, // name of variable to set Variables::Id varId, // id of variable to set boost::intrusive_ptr input, // yields array to iterate @@ -965,7 +1083,10 @@ public: Value evaluateInternal(Variables* vars) const final; void addDependencies(DepsTracker* deps) const final; - static boost::intrusive_ptr parse(BSONElement expr, const VariablesParseState& vps); + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vps); private: enum MetaType { @@ -973,13 +1094,16 @@ private: RAND_VAL, }; - ExpressionMeta(MetaType metaType); + ExpressionMeta(const boost::intrusive_ptr& expCtx, MetaType metaType); MetaType _metaType; }; class ExpressionMillisecond final : public ExpressionFixedArity { public: + explicit ExpressionMillisecond(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -989,6 +1113,9 @@ public: class ExpressionMinute final : public ExpressionFixedArity { public: + explicit ExpressionMinute(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1000,6 +1127,9 @@ public: class ExpressionMod final : public ExpressionFixedArity { public: + explicit ExpressionMod(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1007,6 +1137,9 @@ public: class ExpressionMultiply final : public ExpressionVariadic { public: + explicit ExpressionMultiply(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1022,6 +1155,9 @@ public: class ExpressionMonth final : public ExpressionFixedArity { public: + explicit ExpressionMonth(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1034,6 +1170,9 @@ public: class ExpressionNot final : public ExpressionFixedArity { public: + explicit ExpressionNot(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1055,13 +1194,16 @@ public: Value serialize(bool explain) const final; static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx, std::vector>>&& expressions); /** * Parses and constructs an ExpressionObject from 'obj'. */ - static boost::intrusive_ptr parse(BSONObj obj, - const VariablesParseState& vps); + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONObj obj, + const VariablesParseState& vps); /** * This ExpressionObject must outlive the returned vector. @@ -1071,10 +1213,9 @@ public: return _expressions; } - void doInjectExpressionContext() final; - private: ExpressionObject( + const boost::intrusive_ptr& expCtx, std::vector>>&& expressions); // The mapping from field name to expression within this object. This needs to respect the order @@ -1085,6 +1226,9 @@ private: class ExpressionOr final : public ExpressionVariadic { public: + explicit ExpressionOr(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic(expCtx) {} + boost::intrusive_ptr optimize() final; Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1100,7 +1244,11 @@ public: class ExpressionPow final : public ExpressionFixedArity { public: - static boost::intrusive_ptr create(Value base, Value exp); + explicit ExpressionPow(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx, Value base, Value exp); private: Value evaluateInternal(Variables* vars) const final; @@ -1109,6 +1257,10 @@ private: class ExpressionRange final : public ExpressionRangedArity { +public: + explicit ExpressionRange(const boost::intrusive_ptr& expCtx) + : ExpressionRangedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1116,15 +1268,18 @@ class ExpressionRange final : public ExpressionRangedArity& expCtx) + : Expression(expCtx) {} + void addDependencies(DepsTracker* deps) const final; Value evaluateInternal(Variables* vars) const final; boost::intrusive_ptr optimize() final; - static boost::intrusive_ptr parse(BSONElement expr, - const VariablesParseState& vpsIn); + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn); Value serialize(bool explain) const final; - void doInjectExpressionContext() final; - private: boost::intrusive_ptr _input; boost::intrusive_ptr _initial; @@ -1137,6 +1292,9 @@ private: class ExpressionSecond final : public ExpressionFixedArity { public: + explicit ExpressionSecond(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1148,6 +1306,9 @@ public: class ExpressionSetDifference final : public ExpressionFixedArity { public: + explicit ExpressionSetDifference(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1155,6 +1316,9 @@ public: class ExpressionSetEquals final : public ExpressionVariadic { public: + explicit ExpressionSetEquals(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; void validateArguments(const ExpressionVector& args) const final; @@ -1163,6 +1327,9 @@ public: class ExpressionSetIntersection final : public ExpressionVariadic { public: + explicit ExpressionSetIntersection(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1179,6 +1346,9 @@ public: // Not final, inherited from for optimizations. class ExpressionSetIsSubset : public ExpressionFixedArity { public: + explicit ExpressionSetIsSubset(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + boost::intrusive_ptr optimize() override; Value evaluateInternal(Variables* vars) const override; const char* getOpName() const final; @@ -1190,7 +1360,9 @@ private: class ExpressionSetUnion final : public ExpressionVariadic { public: - // intrusive_ptr optimize() final; + explicit ExpressionSetUnion(const boost::intrusive_ptr& expCtx) + : ExpressionVariadic(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1206,6 +1378,9 @@ public: class ExpressionSize final : public ExpressionFixedArity { public: + explicit ExpressionSize(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1213,6 +1388,9 @@ public: class ExpressionReverseArray final : public ExpressionFixedArity { public: + explicit ExpressionReverseArray(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1220,6 +1398,9 @@ public: class ExpressionSlice final : public ExpressionRangedArity { public: + explicit ExpressionSlice(const boost::intrusive_ptr& expCtx) + : ExpressionRangedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1227,6 +1408,9 @@ public: class ExpressionIsArray final : public ExpressionFixedArity { public: + explicit ExpressionIsArray(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1234,12 +1418,19 @@ public: class ExpressionSplit final : public ExpressionFixedArity { public: + explicit ExpressionSplit(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionSqrt final : public ExpressionSingleNumericArg { +public: + explicit ExpressionSqrt(const boost::intrusive_ptr& expCtx) + : ExpressionSingleNumericArg(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -1247,6 +1438,9 @@ class ExpressionSqrt final : public ExpressionSingleNumericArg { class ExpressionStrcasecmp final : public ExpressionFixedArity { public: + explicit ExpressionStrcasecmp(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1254,6 +1448,9 @@ public: class ExpressionSubstrBytes : public ExpressionFixedArity { public: + explicit ExpressionSubstrBytes(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const; }; @@ -1261,18 +1458,29 @@ public: class ExpressionSubstrCP final : public ExpressionFixedArity { public: + explicit ExpressionSubstrCP(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionStrLenBytes final : public ExpressionFixedArity { +public: + explicit ExpressionStrLenBytes(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; class ExpressionStrLenCP final : public ExpressionFixedArity { +public: + explicit ExpressionStrLenCP(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1280,6 +1488,9 @@ class ExpressionStrLenCP final : public ExpressionFixedArity { public: + explicit ExpressionSubtract(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1287,15 +1498,18 @@ public: class ExpressionSwitch final : public Expression { public: + explicit ExpressionSwitch(const boost::intrusive_ptr& expCtx) + : Expression(expCtx) {} + void addDependencies(DepsTracker* deps) const final; Value evaluateInternal(Variables* vars) const final; boost::intrusive_ptr optimize() final; - static boost::intrusive_ptr parse(BSONElement expr, - const VariablesParseState& vpsIn); + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn); Value serialize(bool explain) const final; - void doInjectExpressionContext() final; - private: using ExpressionPair = std::pair, boost::intrusive_ptr>; @@ -1307,6 +1521,9 @@ private: class ExpressionToLower final : public ExpressionFixedArity { public: + explicit ExpressionToLower(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1314,6 +1531,9 @@ public: class ExpressionToUpper final : public ExpressionFixedArity { public: + explicit ExpressionToUpper(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1321,6 +1541,9 @@ public: class ExpressionTrunc final : public ExpressionSingleNumericArg { public: + explicit ExpressionTrunc(const boost::intrusive_ptr& expCtx) + : ExpressionSingleNumericArg(expCtx) {} + Value evaluateNumericArg(const Value& numericArg) const final; const char* getOpName() const final; }; @@ -1328,6 +1551,9 @@ public: class ExpressionType final : public ExpressionFixedArity { public: + explicit ExpressionType(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; }; @@ -1335,6 +1561,9 @@ public: class ExpressionWeek final : public ExpressionFixedArity { public: + explicit ExpressionWeek(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1344,6 +1573,9 @@ public: class ExpressionIsoWeekYear final : public ExpressionFixedArity { public: + explicit ExpressionIsoWeekYear(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1353,6 +1585,9 @@ public: class ExpressionIsoDayOfWeek final : public ExpressionFixedArity { public: + explicit ExpressionIsoDayOfWeek(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1362,6 +1597,9 @@ public: class ExpressionIsoWeek final : public ExpressionFixedArity { public: + explicit ExpressionIsoWeek(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1371,6 +1609,9 @@ public: class ExpressionYear final : public ExpressionFixedArity { public: + explicit ExpressionYear(const boost::intrusive_ptr& expCtx) + : ExpressionFixedArity(expCtx) {} + Value evaluateInternal(Variables* vars) const final; const char* getOpName() const final; @@ -1383,15 +1624,18 @@ public: class ExpressionZip final : public Expression { public: + explicit ExpressionZip(const boost::intrusive_ptr& expCtx) + : Expression(expCtx) {} + void addDependencies(DepsTracker* deps) const final; Value evaluateInternal(Variables* vars) const final; boost::intrusive_ptr optimize() final; - static boost::intrusive_ptr parse(BSONElement expr, - const VariablesParseState& vpsIn); + static boost::intrusive_ptr parse( + const boost::intrusive_ptr& expCtx, + BSONElement expr, + const VariablesParseState& vpsIn); Value serialize(bool explain) const final; - void doInjectExpressionContext() final; - private: bool _useLongestLength = false; ExpressionVector _inputs; diff --git a/src/mongo/db/pipeline/expression_context.cpp b/src/mongo/db/pipeline/expression_context.cpp index 865bf08ec6c..5bccef201ae 100644 --- a/src/mongo/db/pipeline/expression_context.cpp +++ b/src/mongo/db/pipeline/expression_context.cpp @@ -39,27 +39,27 @@ ExpressionContext::ResolvedNamespace::ResolvedNamespace(NamespaceString ns, std::vector pipeline) : ns(std::move(ns)), pipeline(std::move(pipeline)) {} -ExpressionContext::ExpressionContext(OperationContext* opCtx, const AggregationRequest& request) +ExpressionContext::ExpressionContext(OperationContext* opCtx, + const AggregationRequest& request, + std::unique_ptr collator, + StringMap resolvedNamespaces) : isExplain(request.isExplain()), inShard(request.isFromRouter()), extSortAllowed(request.shouldAllowDiskUse()), bypassDocumentValidation(request.shouldBypassDocumentValidation()), ns(request.getNamespaceString()), opCtx(opCtx), - collation(request.getCollation()) { - if (!collation.isEmpty()) { - auto statusWithCollator = - CollatorFactoryInterface::get(opCtx->getServiceContext())->makeFromBSON(collation); - uassertStatusOK(statusWithCollator.getStatus()); - setCollator(std::move(statusWithCollator.getValue())); - } -} + collation(request.getCollation()), + _collator(std::move(collator)), + _documentComparator(_collator.get()), + _valueComparator(_collator.get()), + _resolvedNamespaces(std::move(resolvedNamespaces)) {} void ExpressionContext::checkForInterrupt() { // This check could be expensive, at least in relative terms, so don't check every time. - if (--interruptCounter == 0) { + if (--_interruptCounter == 0) { opCtx->checkForInterrupt(); - interruptCounter = kInterruptCheckPeriod; + _interruptCounter = kInterruptCheckPeriod; } } @@ -90,9 +90,9 @@ intrusive_ptr ExpressionContext::copyWith(NamespaceString ns) expCtx->setCollator(_collator->clone()); } - expCtx->resolvedNamespaces = resolvedNamespaces; + expCtx->_resolvedNamespaces = _resolvedNamespaces; - // Note that we intentionally skip copying the value of 'interruptCounter' because 'expCtx' is + // Note that we intentionally skip copying the value of '_interruptCounter' because 'expCtx' is // intended to be used for executing a separate aggregation pipeline. return expCtx; diff --git a/src/mongo/db/pipeline/expression_context.h b/src/mongo/db/pipeline/expression_context.h index 67114fa52b0..af6d7d199c9 100644 --- a/src/mongo/db/pipeline/expression_context.h +++ b/src/mongo/db/pipeline/expression_context.h @@ -45,7 +45,7 @@ namespace mongo { -struct ExpressionContext : public IntrusiveCounterUnsigned { +class ExpressionContext : public RefCountable { public: struct ResolvedNamespace { ResolvedNamespace() = default; @@ -55,9 +55,14 @@ public: std::vector pipeline; }; - ExpressionContext() = default; - - ExpressionContext(OperationContext* opCtx, const AggregationRequest& request); + /** + * Constructs an ExpressionContext to be used for Pipeline parsing and evaluation. + * 'resolvedNamespaces' maps collection names (not full namespaces) to ResolvedNamespaces. + */ + ExpressionContext(OperationContext* opCtx, + const AggregationRequest& request, + std::unique_ptr collator, + StringMap resolvedNamespaces); /** * Used by a pipeline to check for interrupts so that killOp() works. Throws a UserAssertion if @@ -65,8 +70,6 @@ public: */ void checkForInterrupt(); - void setCollator(std::unique_ptr coll); - const CollatorInterface* getCollator() const { return _collator.get(); } @@ -85,6 +88,16 @@ public: */ boost::intrusive_ptr copyWith(NamespaceString ns) const; + /** + * Returns the ResolvedNamespace corresponding to 'nss'. It is an error to call this method on a + * namespace not involved in the pipeline. + */ + const ResolvedNamespace& getResolvedNamespace(const NamespaceString& nss) const { + auto it = _resolvedNamespaces.find(nss.coll()); + invariant(it != _resolvedNamespaces.end()); + return it->second; + }; + bool isExplain = false; bool inShard = false; bool inRouter = false; @@ -100,19 +113,34 @@ public: // collation. BSONObj collation; - StringMap resolvedNamespaces; - +protected: static const int kInterruptCheckPeriod = 128; - int interruptCounter = kInterruptCheckPeriod; // when 0, check interruptStatus -private: - // Collator used to compare elements. 'collator' is initialized from 'collation', except in the - // case where 'collation' is empty and there is a collection default collation. + /** + * Should only be used by 'ExpressionContextForTest'. + */ + ExpressionContext() = default; + + /** + * Sets '_collator' and resets '_documentComparator' and '_valueComparator'. + * + * Use with caution - it is illegal to change the collation once a Pipeline has been parsed with + * this ExpressionContext. + */ + void setCollator(std::unique_ptr collator); + + // Collator used for comparisons. std::unique_ptr _collator; // Used for all comparisons of Document/Value during execution of the aggregation operation. + // Must not be changed after parsing a Pipeline with this ExpressionContext. DocumentComparator _documentComparator; ValueComparator _valueComparator; + + // A map from namespace to the resolved namespace, in case any views are involved. + StringMap _resolvedNamespaces; + + int _interruptCounter = kInterruptCheckPeriod; }; } // namespace mongo diff --git a/src/mongo/db/pipeline/expression_context_for_test.h b/src/mongo/db/pipeline/expression_context_for_test.h new file mode 100644 index 00000000000..d093cfe2bfa --- /dev/null +++ b/src/mongo/db/pipeline/expression_context_for_test.h @@ -0,0 +1,64 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects + * for all of the code used other than as permitted herein. If you modify + * file(s) with this exception, you may extend this exception to your + * version of the file(s), but you are not obligated to do so. If you do not + * wish to do so, delete this exception statement from your version. If you + * delete this exception statement from all source files in the program, + * then also delete it in the license file. + */ + +#pragma once + +#include "mongo/db/pipeline/expression_context.h" + +namespace mongo { + +/** + * An ExpressionContext that can have state like the collation and resolved namespace map + * manipulated after construction. In contrast, a regular ExpressionContext requires the collation + * and resolved namespaces to be provided on construction and does not allow them to be subsequently + * mutated. + */ +class ExpressionContextForTest : public ExpressionContext { +public: + ExpressionContextForTest() = default; + + ExpressionContextForTest(OperationContext* txn, const AggregationRequest& request) + : ExpressionContext(txn, request, nullptr, {}) {} + + /** + * Changes the collation used by this ExpressionContext. Must not be changed after parsing a + * Pipeline with this ExpressionContext. + */ + void setCollator(std::unique_ptr collator) { + ExpressionContext::setCollator(std::move(collator)); + } + + /** + * Sets the resolved definition for an involved namespace. + */ + void setResolvedNamespace(const NamespaceString& nss, ResolvedNamespace resolvedNamespace) { + _resolvedNamespaces[nss.coll()] = std::move(resolvedNamespace); + } +}; + +} // namespace mongo diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 13e96dfef0c..315a2faeac7 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -36,7 +36,7 @@ #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/expression.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value_comparator.h" #include "mongo/dbtests/dbtests.h" #include "mongo/unittest/unittest.h" @@ -62,12 +62,11 @@ static void assertExpectedResults(string expression, initializer_list, Value>> operations) { for (auto&& op : operations) { try { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const BSONObj obj = BSON(expression << Value(op.first)); - auto expression = Expression::parseExpression(obj, vps); - expression->injectExpressionContext(expCtx); + auto expression = Expression::parseExpression(expCtx, obj, vps); Value result = expression->evaluate(Document()); ASSERT_VALUE_EQ(op.second, result); ASSERT_EQUALS(op.second.getType(), result.getType()); @@ -190,7 +189,10 @@ public: private: Testable(bool isAssociative, bool isCommutative) - : _isAssociative(isAssociative), _isCommutative(isCommutative) {} + : ExpressionNary( + boost::intrusive_ptr(new ExpressionContextForTest())), + _isAssociative(isAssociative), + _isCommutative(isCommutative) {} bool _isAssociative; bool _isCommutative; }; @@ -224,12 +226,13 @@ protected: } void addOperandArrayToExpr(const intrusive_ptr& expr, const BSONArray& operands) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); BSONObjIterator i(operands); while (i.more()) { BSONElement element = i.next(); - expr->addOperand(Expression::parseOperand(element, vps)); + expr->addOperand(Expression::parseOperand(expCtx, element, vps)); } } @@ -244,7 +247,7 @@ TEST_F(ExpressionNaryTest, AddedConstantOperandIsSerialized) { } TEST_F(ExpressionNaryTest, AddedFieldPathOperandIsSerialized) { - _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create("ab.c")); + _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create(nullptr, "ab.c")); assertContents(_notAssociativeNorCommutative, BSON_ARRAY("$ab.c")); } @@ -258,7 +261,7 @@ TEST_F(ExpressionNaryTest, ValidateConstantExpressionDependency) { } TEST_F(ExpressionNaryTest, ValidateFieldPathExpressionDependency) { - _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create("ab.c")); + _notAssociativeNorCommutative->addOperand(ExpressionFieldPath::create(nullptr, "ab.c")); assertDependencies(_notAssociativeNorCommutative, BSON_ARRAY("ab.c")); } @@ -267,10 +270,12 @@ TEST_F(ExpressionNaryTest, ValidateObjectExpressionDependency) { << "$x" << "q" << "$r")); + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONElement specElement = spec.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - _notAssociativeNorCommutative->addOperand(Expression::parseObject(specElement.Obj(), vps)); + _notAssociativeNorCommutative->addOperand( + Expression::parseObject(expCtx, specElement.Obj(), vps)); assertDependencies(_notAssociativeNorCommutative, BSON_ARRAY("r" << "x")); @@ -683,7 +688,8 @@ TEST_F(ExpressionNaryTest, FlattenInnerOperandsOptimizationOnCommutativeAndAssoc class ExpressionCeilTest : public ExpressionNaryTestOneArg { public: virtual void assertEvaluates(Value input, Value output) override { - _expr = new ExpressionCeil(); + intrusive_ptr expCtx(new ExpressionContextForTest()); + _expr = new ExpressionCeil(expCtx); ExpressionNaryTestOneArg::assertEvaluates(input, output); } }; @@ -741,7 +747,8 @@ TEST_F(ExpressionCeilTest, NullArg) { class ExpressionFloorTest : public ExpressionNaryTestOneArg { public: virtual void assertEvaluates(Value input, Value output) override { - _expr = new ExpressionFloor(); + intrusive_ptr expCtx(new ExpressionContextForTest()); + _expr = new ExpressionFloor(expCtx); ExpressionNaryTestOneArg::assertEvaluates(input, output); } }; @@ -856,7 +863,8 @@ TEST(ExpressionReverseArrayTest, ReturnsNullWithNullishInput) { class ExpressionTruncTest : public ExpressionNaryTestOneArg { public: virtual void assertEvaluates(Value input, Value output) override { - _expr = new ExpressionTrunc(); + intrusive_ptr expCtx(new ExpressionContextForTest()); + _expr = new ExpressionTrunc(expCtx); ExpressionNaryTestOneArg::assertEvaluates(input, output); } }; @@ -917,7 +925,8 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr expression = new ExpressionAdd(); + intrusive_ptr expCtx(new ExpressionContextForTest()); + intrusive_ptr expression = new ExpressionAdd(expCtx); populateOperands(expression); ASSERT_BSONOBJ_EQ(expectedResult(), toBson(expression->evaluate(Document()))); } @@ -932,7 +941,8 @@ protected: class NullDocument { public: void run() { - intrusive_ptr expression = new ExpressionAdd(); + intrusive_ptr expCtx(new ExpressionContextForTest()); + intrusive_ptr expression = new ExpressionAdd(expCtx); expression->addOperand(ExpressionConstant::create(nullptr, Value(2))); ASSERT_BSONOBJ_EQ(BSON("" << 2), toBson(expression->evaluate(Document()))); } @@ -950,7 +960,8 @@ class NoOperands : public ExpectedResultBase { class String { public: void run() { - intrusive_ptr expression = new ExpressionAdd(); + intrusive_ptr expCtx(new ExpressionContextForTest()); + intrusive_ptr expression = new ExpressionAdd(expCtx); expression->addOperand(ExpressionConstant::create(nullptr, Value("a"_sd))); ASSERT_THROWS(expression->evaluate(Document()), UserException); } @@ -960,7 +971,8 @@ public: class Bool { public: void run() { - intrusive_ptr expression = new ExpressionAdd(); + intrusive_ptr expCtx(new ExpressionContextForTest()); + intrusive_ptr expression = new ExpressionAdd(expCtx); expression->addOperand(ExpressionConstant::create(nullptr, Value(true))); ASSERT_THROWS(expression->evaluate(Document()), UserException); } @@ -1193,13 +1205,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(fromBson(BSON("a" << 1))))); @@ -1217,13 +1228,12 @@ class OptimizeBase { public: virtual ~OptimizeBase() {} void run() { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); intrusive_ptr optimized = expression->optimize(); ASSERT_BSONOBJ_EQ(expectedOptimized(), expressionToBson(optimized)); @@ -1515,7 +1525,7 @@ public: class Dependencies { public: void run() { - intrusive_ptr nested = ExpressionFieldPath::create("a.b"); + intrusive_ptr nested = ExpressionFieldPath::create(nullptr, "a.b"); intrusive_ptr expression = ExpressionCoerceToBool::create(nullptr, nested); DepsTracker dependencies; expression->addDependencies(&dependencies); @@ -1531,7 +1541,7 @@ class AddToBsonObj { public: void run() { intrusive_ptr expression = - ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create("foo")); + ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create(nullptr, "foo")); // serialized as $and because CoerceToBool isn't an ExpressionNary assertBinaryEqual(fromjson("{field:{$and:['$foo']}}"), toBsonObj(expression)); @@ -1548,7 +1558,7 @@ class AddToBsonArray { public: void run() { intrusive_ptr expression = - ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create("foo")); + ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create(nullptr, "foo")); // serialized as $and because CoerceToBool isn't an ExpressionNary assertBinaryEqual(BSON_ARRAY(fromjson("{$and:['$foo']}")), toBsonArray(expression)); @@ -1577,9 +1587,8 @@ public: BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expCtx(new ExpressionContext()); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr expCtx(new ExpressionContextForTest()); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); intrusive_ptr optimized = expression->optimize(); ASSERT_BSONOBJ_EQ(constify(expectedOptimized()), expressionToBson(optimized)); } @@ -1610,9 +1619,8 @@ public: BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expCtx(new ExpressionContext()); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr expCtx(new ExpressionContextForTest()); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); // Check expression spec round trip. ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); // Check evaluation result. @@ -1648,11 +1656,12 @@ class ParseError { public: virtual ~ParseError() {} void run() { + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - ASSERT_THROWS(Expression::parseOperand(specElement, vps), UserException); + ASSERT_THROWS(Expression::parseOperand(expCtx, specElement, vps), UserException); } protected: @@ -1855,9 +1864,8 @@ public: BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expCtx(new ExpressionContext()); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr expCtx(new ExpressionContextForTest()); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_VALUE_EQ(expression->evaluate(Document()), Value(true)); } }; @@ -2000,10 +2008,11 @@ public: void run() { BSONObj spec = BSON("IGNORED_FIELD_NAME" << "foo"); + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONElement specElement = spec.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expression = ExpressionConstant::parse(specElement, vps); + intrusive_ptr expression = ExpressionConstant::parse(expCtx, specElement, vps); assertBinaryEqual(BSON("" << "foo"), toBson(expression->evaluate(Document()))); @@ -2140,7 +2149,7 @@ namespace FieldPath { class Invalid { public: void run() { - ASSERT_THROWS(ExpressionFieldPath::create(""), UserException); + ASSERT_THROWS(ExpressionFieldPath::create(nullptr, ""), UserException); } }; @@ -2148,7 +2157,7 @@ public: class Optimize { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a"); // An attempt to optimize returns the Expression itself. ASSERT_EQUALS(expression, expression->optimize()); } @@ -2158,7 +2167,7 @@ public: class Dependencies { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); DepsTracker dependencies; expression->addDependencies(&dependencies); ASSERT_EQUALS(1U, dependencies.fields.size()); @@ -2172,7 +2181,7 @@ public: class Missing { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(Document()))); } }; @@ -2181,7 +2190,7 @@ public: class Present { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a"); assertBinaryEqual(fromjson("{'':123}"), toBson(expression->evaluate(fromBson(BSON("a" << 123))))); } @@ -2191,7 +2200,7 @@ public: class NestedBelowNull { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(fromjson("{a:null}"))))); } @@ -2201,7 +2210,7 @@ public: class NestedBelowUndefined { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(fromjson("{a:undefined}"))))); } @@ -2211,7 +2220,7 @@ public: class NestedBelowMissing { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(fromjson("{z:1}"))))); } @@ -2221,7 +2230,7 @@ public: class NestedBelowInt { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(BSON("a" << 2))))); } }; @@ -2230,7 +2239,7 @@ public: class NestedValue { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(BSON("" << 55), toBson(expression->evaluate(fromBson(BSON("a" << BSON("b" << 55)))))); } @@ -2240,7 +2249,7 @@ public: class NestedBelowEmptyObject { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{}"), toBson(expression->evaluate(fromBson(BSON("a" << BSONObj()))))); } @@ -2250,7 +2259,7 @@ public: class NestedBelowEmptyArray { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(BSON("" << BSONArray()), toBson(expression->evaluate(fromBson(BSON("a" << BSONArray()))))); } @@ -2260,7 +2269,7 @@ public: class NestedBelowArrayWithNull { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[null]}"))))); } @@ -2270,7 +2279,7 @@ public: class NestedBelowArrayWithUndefined { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[undefined]}"))))); } @@ -2280,7 +2289,7 @@ public: class NestedBelowArrayWithInt { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[1]}"))))); } @@ -2290,7 +2299,7 @@ public: class NestedWithinArray { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[9]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[{b:9}]}"))))); } @@ -2300,7 +2309,7 @@ public: class MultipleArrayValues { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b"); assertBinaryEqual(fromjson("{'':[9,20]}"), toBson(expression->evaluate( fromBson(fromjson("{a:[{b:9},null,undefined,{g:4},{b:20},{}]}"))))); @@ -2311,7 +2320,7 @@ public: class ExpandNestedArrays { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b.c"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b.c"); assertBinaryEqual(fromjson("{'':[[1,2],3,[4],[[5]],[6,7]]}"), toBson(expression->evaluate(fromBson(fromjson("{a:[{b:[{c:1},{c:2}]}," "{b:{c:3}}," @@ -2325,7 +2334,7 @@ public: class AddToBsonObj { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b.c"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b.c"); assertBinaryEqual(BSON("foo" << "$a.b.c"), BSON("foo" << expression->serialize(false))); @@ -2336,7 +2345,7 @@ public: class AddToBsonArray { public: void run() { - intrusive_ptr expression = ExpressionFieldPath::create("a.b.c"); + intrusive_ptr expression = ExpressionFieldPath::create(nullptr, "a.b.c"); BSONArrayBuilder bab; bab << expression->serialize(false); assertBinaryEqual(BSON_ARRAY("$a.b.c"), bab.arr()); @@ -2358,16 +2367,19 @@ Document literal(T&& value) { // TEST(ExpressionObjectParse, ShouldAcceptEmptyObject) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSONObj(), vps); + auto object = ExpressionObject::parse(expCtx, BSONObj(), vps); ASSERT_VALUE_EQ(Value(Document{}), object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptLiteralsAsValues) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("a" << 5 << "b" + auto object = ExpressionObject::parse(expCtx, + BSON("a" << 5 << "b" << "string" << "c" << BSONNULL), @@ -2378,25 +2390,29 @@ TEST(ExpressionObjectParse, ShouldAcceptLiteralsAsValues) { } TEST(ExpressionObjectParse, ShouldAccept_idAsFieldName) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("_id" << 5), vps); + auto object = ExpressionObject::parse(expCtx, BSON("_id" << 5), vps); auto expectedResult = Value(Document{{"_id", literal(5)}}); ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptFieldNameContainingDollar) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("a$b" << 5), vps); + auto object = ExpressionObject::parse(expCtx, BSON("a$b" << 5), vps); auto expectedResult = Value(Document{{"a$b", literal(5)}}); ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptNestedObjects) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(fromjson("{a: {b: 1}, c: {d: {e: 1, f: 1}}}"), vps); + auto object = + ExpressionObject::parse(expCtx, fromjson("{a: {b: 1}, c: {d: {e: 1, f: 1}}}"), vps); auto expectedResult = Value(Document{{"a", Document{{"b", literal(1)}}}, {"c", Document{{"d", Document{{"e", literal(1)}, {"f", literal(1)}}}}}}); @@ -2404,18 +2420,20 @@ TEST(ExpressionObjectParse, ShouldAcceptNestedObjects) { } TEST(ExpressionObjectParse, ShouldAcceptArrays) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(fromjson("{a: [1, 2]}"), vps); + auto object = ExpressionObject::parse(expCtx, fromjson("{a: [1, 2]}"), vps); auto expectedResult = Value(Document{{"a", vector{Value(literal(1)), Value(literal(2))}}}); ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ObjectParsing, ShouldAcceptExpressionAsValue) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - auto object = ExpressionObject::parse(BSON("a" << BSON("$and" << BSONArray())), vps); + auto object = ExpressionObject::parse(expCtx, BSON("a" << BSON("$and" << BSONArray())), vps); ASSERT_VALUE_EQ(object->serialize(false), Value(Document{{"a", Document{{"$and", BSONArray()}}}})); } @@ -2425,48 +2443,60 @@ TEST(ObjectParsing, ShouldAcceptExpressionAsValue) { // TEST(ExpressionObjectParse, ShouldRejectDottedFieldNames) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("a.b" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("c" << 3 << "a.b" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a.b" << 1 << "c" << 3), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a.b" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("c" << 3 << "a.b" << 1), vps), + UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a.b" << 1 << "c" << 3), vps), + UserException); } TEST(ExpressionObjectParse, ShouldRejectDuplicateFieldNames) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << 1 << "a" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << 1 << "b" << 2 << "a" << 1), vps), - UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << BSON("c" << 1) << "b" << 2 << "a" << 1), vps), - UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" << 1 << "b" << 2 << "a" << BSON("c" << 1)), vps), + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a" << 1 << "a" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("a" << 1 << "b" << 2 << "a" << 1), vps), UserException); + ASSERT_THROWS( + ExpressionObject::parse(expCtx, BSON("a" << BSON("c" << 1) << "b" << 2 << "a" << 1), vps), + UserException); + ASSERT_THROWS( + ExpressionObject::parse(expCtx, BSON("a" << 1 << "b" << 2 << "a" << BSON("c" << 1)), vps), + UserException); } TEST(ExpressionObjectParse, ShouldRejectInvalidFieldName) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("$a" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON("" << 1), vps), UserException); - ASSERT_THROWS(ExpressionObject::parse(BSON(std::string("a\0b", 3) << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("$a" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON("" << 1), vps), UserException); + ASSERT_THROWS(ExpressionObject::parse(expCtx, BSON(std::string("a\0b", 3) << 1), vps), + UserException); } TEST(ExpressionObjectParse, ShouldRejectInvalidFieldPathAsValue) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse(BSON("a" + ASSERT_THROWS(ExpressionObject::parse(expCtx, + BSON("a" << "$field."), vps), UserException); } TEST(ParseObject, ShouldRejectExpressionAsTheSecondField) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); - ASSERT_THROWS(ExpressionObject::parse( - BSON("a" << BSON("$and" << BSONArray()) << "$or" << BSONArray()), vps), - UserException); + ASSERT_THROWS( + ExpressionObject::parse( + expCtx, BSON("a" << BSON("$and" << BSONArray()) << "$or" << BSONArray()), vps), + UserException); } // @@ -2474,14 +2504,17 @@ TEST(ParseObject, ShouldRejectExpressionAsTheSecondField) { // TEST(ExpressionObjectEvaluate, EmptyObjectShouldEvaluateToEmptyDocument) { - auto object = ExpressionObject::create({}); + intrusive_ptr expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, {}); ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document{{"a", 1}})); ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document{{"_id", "ID"_sd}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateEachField) { - auto object = ExpressionObject::create({{"a", makeConstant(1)}, {"b", makeConstant(5)}}); + intrusive_ptr expCtx(new ExpressionContextForTest()); + auto object = + ExpressionObject::create(expCtx, {{"a", makeConstant(1)}, {"b", makeConstant(5)}}); ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"a", 1}})); ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), @@ -2489,36 +2522,45 @@ TEST(ExpressionObjectEvaluate, ShouldEvaluateEachField) { } TEST(ExpressionObjectEvaluate, OrderOfFieldsInOutputShouldMatchOrderInSpecification) { - auto object = ExpressionObject::create({{"a", ExpressionFieldPath::create("a")}, - {"b", ExpressionFieldPath::create("b")}, - {"c", ExpressionFieldPath::create("c")}}); + intrusive_ptr expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, + {{"a", ExpressionFieldPath::create(expCtx, "a")}, + {"b", ExpressionFieldPath::create(expCtx, "b")}, + {"c", ExpressionFieldPath::create(expCtx, "c")}}); ASSERT_VALUE_EQ( Value(Document{{"a", "A"_sd}, {"b", "B"_sd}, {"c", "C"_sd}}), object->evaluate(Document{{"c", "C"_sd}, {"a", "A"_sd}, {"b", "B"_sd}, {"_id", "ID"_sd}})); } TEST(ExpressionObjectEvaluate, ShouldRemoveFieldsThatHaveMissingValues) { - auto object = ExpressionObject::create( - {{"a", ExpressionFieldPath::create("a.b")}, {"b", ExpressionFieldPath::create("missing")}}); + intrusive_ptr expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, + {{"a", ExpressionFieldPath::create(expCtx, "a.b")}, + {"b", ExpressionFieldPath::create(expCtx, "missing")}}); ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document{{"a", 1}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateFieldsWithinNestedObject) { + intrusive_ptr expCtx(new ExpressionContextForTest()); auto object = ExpressionObject::create( + expCtx, {{"a", ExpressionObject::create( - {{"b", makeConstant(1)}, {"c", ExpressionFieldPath::create("_id")}})}}); + expCtx, + {{"b", makeConstant(1)}, {"c", ExpressionFieldPath::create(expCtx, "_id")}})}}); ASSERT_VALUE_EQ(Value(Document{{"a", Document{{"b", 1}}}}), object->evaluate(Document())); ASSERT_VALUE_EQ(Value(Document{{"a", Document{{"b", 1}, {"c", "ID"_sd}}}}), object->evaluate(Document{{"_id", "ID"_sd}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateToEmptyDocumentIfAllFieldsAreMissing) { - auto object = ExpressionObject::create({{"a", ExpressionFieldPath::create("missing")}}); + intrusive_ptr expCtx(new ExpressionContextForTest()); + auto object = + ExpressionObject::create(expCtx, {{"a", ExpressionFieldPath::create(expCtx, "missing")}}); ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document())); - auto objectWithNestedObject = ExpressionObject::create({{"nested", object}}); + auto objectWithNestedObject = ExpressionObject::create(expCtx, {{"nested", object}}); ASSERT_VALUE_EQ(Value(Document{{"nested", Document{}}}), objectWithNestedObject->evaluate(Document())); } @@ -2528,14 +2570,17 @@ TEST(ExpressionObjectEvaluate, ShouldEvaluateToEmptyDocumentIfAllFieldsAreMissin // TEST(ExpressionObjectDependencies, ConstantValuesShouldNotBeAddedToDependencies) { - auto object = ExpressionObject::create({{"a", makeConstant(5)}}); + intrusive_ptr expCtx(new ExpressionContextForTest()); + auto object = ExpressionObject::create(expCtx, {{"a", makeConstant(5)}}); DepsTracker deps; object->addDependencies(&deps); ASSERT_EQ(deps.fields.size(), 0UL); } TEST(ExpressionObjectDependencies, FieldPathsShouldBeAddedToDependencies) { - auto object = ExpressionObject::create({{"x", ExpressionFieldPath::create("c.d")}}); + intrusive_ptr expCtx(new ExpressionContextForTest()); + auto object = + ExpressionObject::create(expCtx, {{"x", ExpressionFieldPath::create(expCtx, "c.d")}}); DepsTracker deps; object->addDependencies(&deps); ASSERT_EQ(deps.fields.size(), 1UL); @@ -2548,11 +2593,12 @@ TEST(ExpressionObjectDependencies, FieldPathsShouldBeAddedToDependencies) { TEST(ExpressionObjectOptimizations, OptimizingAnObjectShouldOptimizeSubExpressions) { // Build up the object {a: {$add: [1, 2]}}. + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGen; VariablesParseState vps(&idGen); auto addExpression = - ExpressionAdd::parse(BSON("$add" << BSON_ARRAY(1 << 2)).firstElement(), vps); - auto object = ExpressionObject::create({{"a", addExpression}}); + ExpressionAdd::parse(expCtx, BSON("$add" << BSON_ARRAY(1 << 2)).firstElement(), vps); + auto object = ExpressionObject::create(expCtx, {{"a", addExpression}}); ASSERT_EQ(object->getChildExpressions().size(), 1UL); auto optimized = object->optimize(); @@ -2575,11 +2621,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(fromBson(BSON("a" << 1))))); @@ -2597,11 +2644,12 @@ class OptimizeBase { public: virtual ~OptimizeBase() {} void run() { + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); intrusive_ptr optimized = expression->optimize(); ASSERT_BSONOBJ_EQ(expectedOptimized(), expressionToBson(optimized)); @@ -2876,10 +2924,11 @@ namespace Object { * Parses the object given by 'specification', with the options given by 'parseContextOptions'. */ boost::intrusive_ptr parseObject(BSONObj specification) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - return Expression::parseObject(specification, vps); + return Expression::parseObject(expCtx, specification, vps); }; TEST(ParseObject, ShouldAcceptEmptyObject) { @@ -2910,9 +2959,10 @@ using mongo::Expression; * Parses an expression from the given BSON specification. */ boost::intrusive_ptr parseExpression(BSONObj specification) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - return Expression::parseExpression(specification, vps); + return Expression::parseExpression(expCtx, specification, vps); } TEST(ParseExpression, ShouldRecognizeConstExpression) { @@ -3016,10 +3066,11 @@ using mongo::Expression; * case the field name would be the name of the expression. */ intrusive_ptr parseOperand(BSONObj specification) { + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONElement specElement = specification.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - return Expression::parseOperand(specElement, vps); + return Expression::parseOperand(expCtx, specElement, vps); } TEST(ParseOperand, ShouldRecognizeFieldPath) { @@ -3081,7 +3132,7 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); const Document spec = getSpec(); const Value args = spec["input"]; if (!spec["expected"].missing()) { @@ -3092,8 +3143,8 @@ public: const BSONObj obj = BSON(field.first << args); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - const intrusive_ptr expr = Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + const intrusive_ptr expr = + Expression::parseExpression(expCtx, obj, vps); Value result = expr->evaluate(Document()); if (result.getType() == Array) { result = sortSet(result); @@ -3121,8 +3172,7 @@ public: // NOTE: parse and evaluatation failures are treated the // same const intrusive_ptr expr = - Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + Expression::parseExpression(expCtx, obj, vps); expr->evaluate(Document()); }, UserException); @@ -3399,13 +3449,12 @@ private: return BSON("$strcasecmp" << BSON_ARRAY(b() << a())); } void assertResult(int expectedResult, const BSONObj& spec) { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult), toBson(expression->evaluate(Document()))); } @@ -3527,13 +3576,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3635,22 +3683,24 @@ class DropEndingNull : public ExpectedResultBase { namespace SubstrCP { TEST(ExpressionSubstrCPTest, DoesThrowWithBadContinuationByte) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const auto continuationByte = "\x80\x00"_sd; const auto expr = Expression::parseExpression( - BSON("$substrCP" << BSON_ARRAY(continuationByte << 0 << 1)), vps); + expCtx, BSON("$substrCP" << BSON_ARRAY(continuationByte << 0 << 1)), vps); ASSERT_THROWS({ expr->evaluate(Document()); }, UserException); } TEST(ExpressionSubstrCPTest, DoesThrowWithInvalidLeadingByte) { + intrusive_ptr expCtx(new ExpressionContextForTest()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const auto leadingByte = "\xFF\x00"_sd; - const auto expr = - Expression::parseExpression(BSON("$substrCP" << BSON_ARRAY(leadingByte << 0 << 1)), vps); + const auto expr = Expression::parseExpression( + expCtx, BSON("$substrCP" << BSON_ARRAY(leadingByte << 0 << 1)), vps); ASSERT_THROWS({ expr->evaluate(Document()); }, UserException); } @@ -3792,13 +3842,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3851,13 +3900,12 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - intrusive_ptr expression = Expression::parseOperand(specElement, vps); - expression->injectExpressionContext(expCtx); + intrusive_ptr expression = Expression::parseOperand(expCtx, specElement, vps); ASSERT_BSONOBJ_EQ(constify(spec()), expressionToBson(expression)); ASSERT_BSONOBJ_EQ(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3909,7 +3957,7 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { - intrusive_ptr expCtx(new ExpressionContext()); + intrusive_ptr expCtx(new ExpressionContextForTest()); const Document spec = getSpec(); const Value args = spec["input"]; if (!spec["expected"].missing()) { @@ -3920,8 +3968,8 @@ public: const BSONObj obj = BSON(field.first << args); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); - const intrusive_ptr expr = Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + const intrusive_ptr expr = + Expression::parseExpression(expCtx, obj, vps); const Value result = expr->evaluate(Document()); if (ValueComparator().evaluate(result != expected)) { string errMsg = str::stream() @@ -3946,8 +3994,7 @@ public: // NOTE: parse and evaluatation failures are treated the // same const intrusive_ptr expr = - Expression::parseExpression(obj, vps); - expr->injectExpressionContext(expCtx); + Expression::parseExpression(expCtx, obj, vps); expr->evaluate(Document()); }, UserException); diff --git a/src/mongo/db/pipeline/granularity_rounder.cpp b/src/mongo/db/pipeline/granularity_rounder.cpp index 0941c8b36fa..6071c6644db 100644 --- a/src/mongo/db/pipeline/granularity_rounder.cpp +++ b/src/mongo/db/pipeline/granularity_rounder.cpp @@ -50,11 +50,11 @@ void GranularityRounder::registerGranularityRounder(StringData name, Rounder rou } boost::intrusive_ptr GranularityRounder::getGranularityRounder( - StringData granularity) { + const boost::intrusive_ptr& expCtx, StringData granularity) { auto it = rounderMap.find(granularity); uassert(40257, str::stream() << "Unknown rounding granularity '" << granularity << "'", it != rounderMap.end()); - return it->second(); + return it->second(expCtx); } } // namespace mongo diff --git a/src/mongo/db/pipeline/granularity_rounder.h b/src/mongo/db/pipeline/granularity_rounder.h index 862c604dcc0..24a8f53de97 100644 --- a/src/mongo/db/pipeline/granularity_rounder.h +++ b/src/mongo/db/pipeline/granularity_rounder.h @@ -32,6 +32,7 @@ #include "mongo/base/init.h" #include "mongo/db/jsobj.h" +#include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" #include "mongo/stdx/functional.h" #include "mongo/util/intrusive_counter.h" @@ -73,7 +74,8 @@ namespace mongo { */ class GranularityRounder : public RefCountable { public: - using Rounder = stdx::function()>; + using Rounder = stdx::function( + const boost::intrusive_ptr&)>; /** * Registers a GranularityRounder with a parsing function so that when a granularity @@ -89,7 +91,8 @@ public: * Retrieves the GranularityRounder for the granularity given by 'granularity', and raises an * error if there is no such granularity registered. */ - static boost::intrusive_ptr getGranularityRounder(StringData granularity); + static boost::intrusive_ptr getGranularityRounder( + const boost::intrusive_ptr& expCtx, StringData granularity); /** * Rounds up 'value' to the first value greater than 'value' in the granularity series. If @@ -109,6 +112,16 @@ public: * Returns the name of the granularity series that the GranularityRounder is using for rounding. */ virtual std::string getName() = 0; + +protected: + GranularityRounder(const boost::intrusive_ptr& expCtx) : _expCtx(expCtx) {} + + ExpressionContext* getExpCtx() { + return _expCtx.get(); + } + +private: + boost::intrusive_ptr _expCtx; }; /** @@ -124,8 +137,10 @@ public: * 'baseSeries'. This method requires that baseSeries has at least 2 numbers and is in sorted * order. */ - static boost::intrusive_ptr create(const std::vector baseSeries, - std::string name); + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx, + const std::vector baseSeries, + std::string name); Value roundUp(Value value); Value roundDown(Value value); @@ -138,7 +153,9 @@ public: const std::vector getSeries() const; private: - GranularityRounderPreferredNumbers(std::vector baseSeries, std::string name); + GranularityRounderPreferredNumbers(const boost::intrusive_ptr& expCtx, + std::vector baseSeries, + std::string name); // '_baseSeries' is the preferred number series that is used for rounding. A preferred numbers // series is infinite, but we represent it with a finite vector of numbers. When rounding, we @@ -153,13 +170,16 @@ private: */ class GranularityRounderPowersOfTwo final : public GranularityRounder { public: - static boost::intrusive_ptr create(); + static boost::intrusive_ptr create( + const boost::intrusive_ptr& expCtx); Value roundUp(Value value); Value roundDown(Value value); std::string getName(); private: - GranularityRounderPowersOfTwo() = default; + GranularityRounderPowersOfTwo(const boost::intrusive_ptr& expCtx) + : GranularityRounder(expCtx) {} + std::string _name = "POWERSOF2"; }; } // namespace mongo diff --git a/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp b/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp index 56d3ae02267..078936e8145 100644 --- a/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_powers_of_two.cpp @@ -40,8 +40,9 @@ using std::string; REGISTER_GRANULARITY_ROUNDER(POWERSOF2, GranularityRounderPowersOfTwo::create); -intrusive_ptr GranularityRounderPowersOfTwo::create() { - return new GranularityRounderPowersOfTwo(); +intrusive_ptr GranularityRounderPowersOfTwo::create( + const boost::intrusive_ptr& expCtx) { + return new GranularityRounderPowersOfTwo(expCtx); } namespace { @@ -80,7 +81,7 @@ Value GranularityRounderPowersOfTwo::roundUp(Value value) { } Variables vars; - return ExpressionPow::create(Value(2), exp)->evaluate(&vars); + return ExpressionPow::create(getExpCtx(), Value(2), exp)->evaluate(&vars); } Value GranularityRounderPowersOfTwo::roundDown(Value value) { @@ -113,7 +114,7 @@ Value GranularityRounderPowersOfTwo::roundDown(Value value) { } Variables vars; - return ExpressionPow::create(Value(2), exp)->evaluate(&vars); + return ExpressionPow::create(getExpCtx(), Value(2), exp)->evaluate(&vars); } string GranularityRounderPowersOfTwo::getName() { diff --git a/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp b/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp index adb7046c0c1..eefcdb749e5 100644 --- a/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_powers_of_two_test.cpp @@ -31,6 +31,7 @@ #include "mongo/db/pipeline/granularity_rounder.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" @@ -50,7 +51,8 @@ void testEquals(Value actual, Value expected, double delta = DELTA) { } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundUp(Value(0.5)), Value(1)); testEquals(rounder->roundUp(Value(1)), Value(2)); @@ -64,7 +66,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpPowersOfTwoToNextPowerOfTwo } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnDoubleIfExceedsNumberLong) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); long long input = 4611686018427387905; // 2^62 + 1 double output = 9223372036854775808.0; // 2^63 @@ -78,7 +81,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnDoubleIfExceedsNumberLong) { } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfExceedsNumberInt) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); int input = 1073741824; // 2^30 long long output = 2147483648; // 2^31 @@ -92,7 +96,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfExceedsNumberInt } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfRoundedDownDoubleIsSmallEnough) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); double input = 9223372036854775808.0; // 2^63 long long output = 4611686018427387904; // 2^62 @@ -106,7 +111,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberLongIfRoundedDownDoubl } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberIntIfRoundedDownNumberLongIsSmallEnough) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); long long input = 2147483648; // 2^31 int output = 1073741824; // 2^30 @@ -120,7 +126,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberIntIfRoundedDownNumber } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingUpNumberDecimal) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Decimal128 input = Decimal128(0.12); Decimal128 output = Decimal128(0.125); @@ -134,7 +141,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingUpN } TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingDownNumberDecimal) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Decimal128 input = Decimal128(0.13); Decimal128 output = Decimal128(0.125); @@ -148,7 +156,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldReturnNumberDecimalWhenRoundingDow } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpNonPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundUp(Value(3)), Value(4)); testEquals(rounder->roundUp(Value(5)), Value(8)); @@ -168,7 +177,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundUpNonPowersOfTwoToNextPowerOf } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundDown(Value(16)), Value(8)); testEquals(rounder->roundDown(Value(8)), Value(4)); @@ -183,7 +193,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownPowersOfTwoToNextPowerOfT } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownNonPowersOfTwoToNextPowerOfTwo) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundDown(Value(10)), Value(8)); testEquals(rounder->roundDown(Value(9)), Value(8)); @@ -201,14 +212,16 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldRoundDownNonPowersOfTwoToNextPower } TEST(GranularityRounderPowersOfTwoTest, ShouldRoundZeroToZero) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); testEquals(rounder->roundUp(Value(0)), Value(0)); testEquals(rounder->roundDown(Value(0)), Value(0)); } TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNonNumericValues) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); // Make sure that each GranularityRounder fails when rounding a non-numeric value. Value stringValue = Value("test"_sd); @@ -217,7 +230,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNonNumericValues) { } TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNaN) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Value nan = Value(std::nan("NaN")); ASSERT_THROWS_CODE(rounder->roundUp(nan), UserException, 40266); @@ -232,7 +246,8 @@ TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNaN) { } TEST(GranularityRounderPowersOfTwoTest, ShouldFailOnRoundingNegativeNumber) { - auto rounder = GranularityRounder::getGranularityRounder("POWERSOF2"); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), "POWERSOF2"); Value negativeNumber = Value(-1); ASSERT_THROWS_CODE(rounder->roundUp(negativeNumber), UserException, 40267); diff --git a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp index 81a845bed6e..5b62d435a14 100644 --- a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers.cpp @@ -95,56 +95,57 @@ const vector e192Series{ } // namespace // Register the GranularityRounders for the Renard number series. -REGISTER_GRANULARITY_ROUNDER(R5, []() { - return GranularityRounderPreferredNumbers::create(r5Series, "R5"); +REGISTER_GRANULARITY_ROUNDER(R5, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r5Series, "R5"); }); -REGISTER_GRANULARITY_ROUNDER(R10, []() { - return GranularityRounderPreferredNumbers::create(r10Series, "R10"); +REGISTER_GRANULARITY_ROUNDER(R10, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r10Series, "R10"); }); -REGISTER_GRANULARITY_ROUNDER(R20, []() { - return GranularityRounderPreferredNumbers::create(r20Series, "R20"); +REGISTER_GRANULARITY_ROUNDER(R20, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r20Series, "R20"); }); -REGISTER_GRANULARITY_ROUNDER(R40, []() { - return GranularityRounderPreferredNumbers::create(r40Series, "R40"); +REGISTER_GRANULARITY_ROUNDER(R40, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r40Series, "R40"); }); -REGISTER_GRANULARITY_ROUNDER(R80, []() { - return GranularityRounderPreferredNumbers::create(r80Series, "R80"); +REGISTER_GRANULARITY_ROUNDER(R80, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, r80Series, "R80"); }); -REGISTER_GRANULARITY_ROUNDER_GENERAL("1-2-5", 1_2_5, []() { - return GranularityRounderPreferredNumbers::create(series125, "1-2-5"); -}); +REGISTER_GRANULARITY_ROUNDER_GENERAL( + "1-2-5", 1_2_5, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, series125, "1-2-5"); + }); // Register the GranularityRounders for the E series. -REGISTER_GRANULARITY_ROUNDER(E6, []() { - return GranularityRounderPreferredNumbers::create(e6Series, "E6"); +REGISTER_GRANULARITY_ROUNDER(E6, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e6Series, "E6"); }); -REGISTER_GRANULARITY_ROUNDER(E12, []() { - return GranularityRounderPreferredNumbers::create(e12Series, "E12"); +REGISTER_GRANULARITY_ROUNDER(E12, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e12Series, "E12"); }); -REGISTER_GRANULARITY_ROUNDER(E24, []() { - return GranularityRounderPreferredNumbers::create(e24Series, "E24"); +REGISTER_GRANULARITY_ROUNDER(E24, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e24Series, "E24"); }); -REGISTER_GRANULARITY_ROUNDER(E48, []() { - return GranularityRounderPreferredNumbers::create(e48Series, "E48"); +REGISTER_GRANULARITY_ROUNDER(E48, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e48Series, "E48"); }); -REGISTER_GRANULARITY_ROUNDER(E96, []() { - return GranularityRounderPreferredNumbers::create(e96Series, "E96"); +REGISTER_GRANULARITY_ROUNDER(E96, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e96Series, "E96"); }); -REGISTER_GRANULARITY_ROUNDER(E192, []() { - return GranularityRounderPreferredNumbers::create(e192Series, "E192"); +REGISTER_GRANULARITY_ROUNDER(E192, [](const boost::intrusive_ptr& expCtx) { + return GranularityRounderPreferredNumbers::create(expCtx, e192Series, "E192"); }); -GranularityRounderPreferredNumbers::GranularityRounderPreferredNumbers(vector baseSeries, - string name) - : _baseSeries(baseSeries), _name(name) { +GranularityRounderPreferredNumbers::GranularityRounderPreferredNumbers( + const boost::intrusive_ptr& expCtx, vector baseSeries, string name) + : GranularityRounder(expCtx), _baseSeries(baseSeries), _name(name) { invariant(_baseSeries.size() > 1); invariant(std::is_sorted(_baseSeries.begin(), _baseSeries.end())); } intrusive_ptr GranularityRounderPreferredNumbers::create( - vector baseSeries, string name) { - return new GranularityRounderPreferredNumbers(baseSeries, name); + const boost::intrusive_ptr& expCtx, vector baseSeries, string name) { + return new GranularityRounderPreferredNumbers(expCtx, baseSeries, name); } namespace { diff --git a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp index 17d10d07977..89df39621ad 100644 --- a/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp +++ b/src/mongo/db/pipeline/granularity_rounder_preferred_numbers_test.cpp @@ -32,6 +32,7 @@ #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/util/assert_util.h" namespace mongo { @@ -461,7 +462,8 @@ void testSeriesWrappingAroundDecimal(intrusive_ptr rounder) TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpNumberInSeriesToNextNumberInSeries) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingUpInSeries(rounder); testRoundingUpInSeriesDecimal(rounder); @@ -471,7 +473,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpNumberInSeriesToNextNu TEST(GranularityRounderPreferredNumbersTest, ShouldRoundDownNumberInSeriesToPreviousNumberInSeries) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingDownInSeries(rounder); testRoundingDownInSeriesDecimal(rounder); @@ -480,7 +483,8 @@ TEST(GranularityRounderPreferredNumbersTest, TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpValueInBetweenSeriesNumbers) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingUpBetweenSeries(rounder); testRoundingUpBetweenSeriesDecimal(rounder); @@ -489,7 +493,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundUpValueInBetweenSeriesNu TEST(GranularityRounderPreferredNumbersTest, ShouldRoundDownValueInBetweenSeriesNumbers) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testRoundingDownBetweenSeries(rounder); testRoundingDownBetweenSeriesDecimal(rounder); @@ -498,7 +503,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundDownValueInBetweenSeries TEST(GranularityRounderPreferredNumbersTest, SeriesShouldWrapAroundWhenRounding) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); testSeriesWrappingAround(rounder); testSeriesWrappingAroundDecimal(rounder); @@ -507,7 +513,8 @@ TEST(GranularityRounderPreferredNumbersTest, SeriesShouldWrapAroundWhenRounding) TEST(GranularityRounderPreferredNumbersTest, ShouldRoundZeroToZero) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder rounds zero to zero. testEquals(rounder->roundUp(Value(0)), Value(0)); @@ -520,7 +527,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldRoundZeroToZero) { TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNonNumericValues) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder fails when rounding a non-numeric value. Value stringValue = Value("test"_sd); @@ -531,7 +539,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNonNumericValue TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNaN) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder fails when rounding NaN. Value nan = Value(std::nan("NaN")); @@ -549,7 +558,8 @@ TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNaN) { TEST(GranularityRounderPreferredNumbersTest, ShouldFailOnRoundingNegativeNumber) { for (auto&& series : preferredNumberSeries) { - auto rounder = GranularityRounder::getGranularityRounder(series); + auto rounder = + GranularityRounder::getGranularityRounder(new ExpressionContextForTest(), series); // Make sure that each GranularityRounder fails when rounding a negative number. Value negativeNumber = Value(-1); diff --git a/src/mongo/db/pipeline/lookup_set_cache.h b/src/mongo/db/pipeline/lookup_set_cache.h index d482b267ee8..1b7adbd00a6 100644 --- a/src/mongo/db/pipeline/lookup_set_cache.h +++ b/src/mongo/db/pipeline/lookup_set_cache.h @@ -81,9 +81,12 @@ public: * ValueComparator. This requires instantiating the multi_index_container with comparison and * hasher functions obtained from the comparator. */ - LookupSetCache(ValueComparator valueComparator) - : _valueComparator(std::move(valueComparator)), - _container(makeIndexedContainer(_valueComparator)) {} + explicit LookupSetCache(const ValueComparator& comparator) + : _container(boost::make_tuple(IndexedContainer::nth_index<0>::type::ctor_args(), + boost::make_tuple(0, + member(), + comparator.getHasher(), + comparator.getEqualTo()))) {} /** * Insert "value" into the set with key "key". If "key" is already present in the cache, move it @@ -186,29 +189,7 @@ public: return boost::none; } - /** - * Binds the cache to a new comparator that should be used to make all subsequent Value - * comparisons. - * - * TODO SERVER-25535: Remove this method. - */ - void setValueComparator(ValueComparator valueComparator) { - _valueComparator = std::move(valueComparator); - _container = makeIndexedContainer(_valueComparator); - } - private: - IndexedContainer makeIndexedContainer(const ValueComparator& valueComparator) const { - return IndexedContainer( - boost::make_tuple(IndexedContainer::nth_index<0>::type::ctor_args(), - boost::make_tuple(0, - member(), - valueComparator.getHasher(), - valueComparator.getEqualTo()))); - } - - ValueComparator _valueComparator; - IndexedContainer _container; size_t _memoryUsage = 0; diff --git a/src/mongo/db/pipeline/lookup_set_cache_test.cpp b/src/mongo/db/pipeline/lookup_set_cache_test.cpp index ce7503d4c6e..54a8b990ae7 100644 --- a/src/mongo/db/pipeline/lookup_set_cache_test.cpp +++ b/src/mongo/db/pipeline/lookup_set_cache_test.cpp @@ -52,9 +52,10 @@ BSONObj intToObj(int value) { return BSON("n" << value); } +const ValueComparator defaultComparator{nullptr}; + TEST(LookupSetCacheTest, InsertAndRetrieveWorksCorrectly) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(1)); cache.insert(Value(0), intToObj(2)); cache.insert(Value(0), intToObj(3)); @@ -70,8 +71,7 @@ TEST(LookupSetCacheTest, InsertAndRetrieveWorksCorrectly) { } TEST(LookupSetCacheTest, CacheDoesEvictInExpectedOrder) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -90,8 +90,7 @@ TEST(LookupSetCacheTest, CacheDoesEvictInExpectedOrder) { } TEST(LookupSetCacheTest, ReadDoesMoveKeyToFrontOfCache) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -106,8 +105,7 @@ TEST(LookupSetCacheTest, ReadDoesMoveKeyToFrontOfCache) { } TEST(LookupSetCacheTest, InsertDoesPutKeyInMiddle) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -120,8 +118,7 @@ TEST(LookupSetCacheTest, InsertDoesPutKeyInMiddle) { } TEST(LookupSetCacheTest, EvictDoesRespectMemoryUsage) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); cache.insert(Value(0), intToObj(0)); cache.insert(Value(1), intToObj(0)); @@ -134,8 +131,7 @@ TEST(LookupSetCacheTest, EvictDoesRespectMemoryUsage) { } TEST(LookupSetCacheTest, ComplexAccessPatternDoesBehaveCorrectly) { - const StringData::ComparatorInterface* stringComparator = nullptr; - LookupSetCache cache(stringComparator); + LookupSetCache cache(defaultComparator); for (int i = 0; i < 5; i++) { for (int j = 0; j < 5; j++) { @@ -173,7 +169,8 @@ TEST(LookupSetCacheTest, ComplexAccessPatternDoesBehaveCorrectly) { TEST(LookupSetCacheTest, CacheKeysRespectCollation) { CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kToLowerString); - LookupSetCache cache(&collator); + ValueComparator comparator{&collator}; + LookupSetCache cache(comparator); cache.insert(Value("foo"_sd), intToObj(1)); cache.insert(Value("FOO"_sd), intToObj(2)); @@ -196,7 +193,8 @@ TEST(LookupSetCacheTest, CacheKeysRespectCollation) { // foreign collection. TEST(LookupSetCacheTest, CachedValuesDontRespectCollation) { CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kToLowerString); - LookupSetCache cache(&collator); + ValueComparator comparator{&collator}; + LookupSetCache cache(comparator); cache.insert(Value("foo"_sd), BSON("foo" diff --git a/src/mongo/db/pipeline/parsed_add_fields.cpp b/src/mongo/db/pipeline/parsed_add_fields.cpp index 2bca98f21bd..ae78664a8a5 100644 --- a/src/mongo/db/pipeline/parsed_add_fields.cpp +++ b/src/mongo/db/pipeline/parsed_add_fields.cpp @@ -38,7 +38,8 @@ namespace mongo { namespace parsed_aggregation_projection { -std::unique_ptr ParsedAddFields::create(const BSONObj& spec) { +std::unique_ptr ParsedAddFields::create( + const boost::intrusive_ptr& expCtx, const BSONObj& spec) { // Verify that we don't have conflicting field paths, etc. Status status = ProjectionSpecValidator::validate(spec); if (!status.isOK()) { @@ -48,17 +49,19 @@ std::unique_ptr ParsedAddFields::create(const BSONObj& spec) { std::unique_ptr parsedAddFields = stdx::make_unique(); // Actually parse the specification. - parsedAddFields->parse(spec); + parsedAddFields->parse(expCtx, spec); return parsedAddFields; } -void ParsedAddFields::parse(const BSONObj& spec, const VariablesParseState& variablesParseState) { +void ParsedAddFields::parse(const boost::intrusive_ptr& expCtx, + const BSONObj& spec, + const VariablesParseState& variablesParseState) { for (auto elem : spec) { auto fieldName = elem.fieldNameStringData(); if (elem.type() == BSONType::Object) { // This is either an expression, or a nested specification. - if (parseObjectAsExpression(fieldName, elem.Obj(), variablesParseState)) { + if (parseObjectAsExpression(expCtx, fieldName, elem.Obj(), variablesParseState)) { // It was an expression. } else { // The field name might be a dotted path. If so, we need to keep adding children @@ -72,12 +75,12 @@ void ParsedAddFields::parse(const BSONObj& spec, const VariablesParseState& vari // It is illegal to construct an empty FieldPath, so the above loop ends one // iteration too soon. Add the last path here. child = child->addOrGetChild(remainingPath.fullPath()); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); } } else { // This is a literal or regular value. _root->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } @@ -96,7 +99,8 @@ Document ParsedAddFields::applyProjection(Document inputDoc, Variables* vars) co return output.freeze(); } -bool ParsedAddFields::parseObjectAsExpression(StringData pathToObject, +bool ParsedAddFields::parseObjectAsExpression(const boost::intrusive_ptr& expCtx, + StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState) { if (objSpec.firstElementFieldName()[0] == '$') { @@ -104,13 +108,14 @@ bool ParsedAddFields::parseObjectAsExpression(StringData pathToObject, // field. invariant(objSpec.nFields() == 1); _root->addComputedField(pathToObject, - Expression::parseExpression(objSpec, variablesParseState)); + Expression::parseExpression(expCtx, objSpec, variablesParseState)); return true; } return false; } -void ParsedAddFields::parseSubObject(const BSONObj& subObj, +void ParsedAddFields::parseSubObject(const boost::intrusive_ptr& expCtx, + const BSONObj& subObj, const VariablesParseState& variablesParseState, InclusionNode* node) { for (auto&& elem : subObj) { @@ -123,17 +128,18 @@ void ParsedAddFields::parseSubObject(const BSONObj& subObj, // This is either an expression, or a nested specification. auto fieldName = elem.fieldNameStringData().toString(); if (!parseObjectAsExpression( + expCtx, FieldPath::getFullyQualifiedPath(node->getPath(), fieldName), elem.Obj(), variablesParseState)) { // It was a nested subobject auto child = node->addOrGetChild(fieldName); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); } } else { // This is a literal or regular value. node->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } diff --git a/src/mongo/db/pipeline/parsed_add_fields.h b/src/mongo/db/pipeline/parsed_add_fields.h index 58b8445e727..982fe6a3017 100644 --- a/src/mongo/db/pipeline/parsed_add_fields.h +++ b/src/mongo/db/pipeline/parsed_add_fields.h @@ -58,7 +58,8 @@ public: * Verifies that there are no conflicting paths in the specification. * Overrides the ParsedAggregationProjection's create method. */ - static std::unique_ptr create(const BSONObj& spec); + static std::unique_ptr create( + const boost::intrusive_ptr& expCtx, const BSONObj& spec); ProjectionType getType() const final { return ProjectionType::kComputed; @@ -67,10 +68,10 @@ public: /** * Parses the addFields specification given by 'spec', populating internal data structures. */ - void parse(const BSONObj& spec) final { + void parse(const boost::intrusive_ptr& expCtx, const BSONObj& spec) final { VariablesIdGenerator idGenerator; VariablesParseState variablesParseState(&idGenerator); - parse(spec, variablesParseState); + parse(expCtx, spec, variablesParseState); _variables = stdx::make_unique(idGenerator.getIdCount()); } @@ -87,10 +88,6 @@ public: _root->optimize(); } - void injectExpressionContext(const boost::intrusive_ptr& expCtx) final { - _root->injectExpressionContext(expCtx); - } - DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const final { _root->addDependencies(deps); return DocumentSource::SEE_NEXT; @@ -124,7 +121,9 @@ private: /** * Parses 'spec' to determine which fields to add. */ - void parse(const BSONObj& spec, const VariablesParseState& variablesParseState); + void parse(const boost::intrusive_ptr& expCtx, + const BSONObj& spec, + const VariablesParseState& variablesParseState); /** * Attempts to parse 'objSpec' as an expression like {$add: [...]}. Adds a computed field to @@ -134,7 +133,8 @@ private: * Throws an error if it was determined to be an expression specification, but failed to parse * as a valid expression. */ - bool parseObjectAsExpression(StringData pathToObject, + bool parseObjectAsExpression(const boost::intrusive_ptr& expCtx, + StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState); @@ -142,7 +142,8 @@ private: * Traverses 'subObj' and parses each field. Adds any computed fields at this level * to 'node'. */ - void parseSubObject(const BSONObj& subObj, + void parseSubObject(const boost::intrusive_ptr& expCtx, + const BSONObj& subObj, const VariablesParseState& variablesParseState, InclusionNode* node); diff --git a/src/mongo/db/pipeline/parsed_add_fields_test.cpp b/src/mongo/db/pipeline/parsed_add_fields_test.cpp index 557e2450dce..d7b68afca1d 100644 --- a/src/mongo/db/pipeline/parsed_add_fields_test.cpp +++ b/src/mongo/db/pipeline/parsed_add_fields_test.cpp @@ -38,6 +38,7 @@ #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" @@ -47,69 +48,78 @@ namespace { using std::vector; // These ParsedAddFields spec tests are a subset of the ParsedAggregationProjection creation tests. -// ParsedAddField should behave the same way, but does not use the same creation, so we include +// ParsedAddFields should behave the same way, but does not use the same creation, so we include // an abbreviation of the same tests here. // Verify that ParsedAddFields rejects specifications with conflicting field paths. TEST(ParsedAddFieldsSpec, ThrowsOnCreationWithConflictingFieldPaths) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // These specs contain the same exact path. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << 1 << "a" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("b" << 1 << "b" << 2))), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("_id" << 3 << "_id" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << 1 << "a" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSON("b" << 1 << "b" << 2))), + UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("_id" << 3 << "_id" << true)), + UserException); // These specs contain overlapping paths. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << 1 << "a.b" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a.b.c" << 1 << "a" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("_id" << true << "_id.x" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << 1 << "a.b" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a.b.c" << 1 << "a" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("_id" << true << "_id.x" << true)), + UserException); } // Verify that ParsedAddFields rejects specifications that contain invalid field paths. TEST(ParsedAddFieldsSpec, ThrowsOnCreationWithInvalidFieldPath) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Dotted subfields are not allowed. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("b.c" << true))), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSON("b.c" << true))), UserException); // The user cannot start a field with $. - ASSERT_THROWS(ParsedAddFields::create(BSON("$dollar" << 0)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("c.$d" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("$dollar" << 0)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("c.$d" << true)), UserException); // Empty field names should throw an error. - ASSERT_THROWS(ParsedAddFields::create(BSON("" << 2)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("" << true))), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("" << BSON("a" << true))), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a." << true)), UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON(".a" << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("" << 2)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSON("" << true))), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("" << BSON("a" << true))), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a." << true)), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON(".a" << true)), UserException); } // Verify that ParsedAddFields rejects specifications that contain empty objects or invalid // expressions. TEST(ParsedAddFieldsSpec, ThrowsOnCreationWithInvalidObjectsOrExpressions) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Invalid expressions should be rejected. - ASSERT_THROWS( - ParsedAddFields::create(BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), - UserException); - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSON("$gt" << BSON("bad" + ASSERT_THROWS(ParsedAddFields::create( + expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), + UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, + BSON("a" << BSON("$gt" << BSON("bad" << "arguments")))), UserException); ASSERT_THROWS(ParsedAddFields::create( - BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), + expCtx, BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), UserException); // Empty specifications are not allowed. - ASSERT_THROWS(ParsedAddFields::create(BSONObj()), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSONObj()), UserException); // Empty nested objects are not allowed. - ASSERT_THROWS(ParsedAddFields::create(BSON("a" << BSONObj())), UserException); + ASSERT_THROWS(ParsedAddFields::create(expCtx, BSON("a" << BSONObj())), UserException); } TEST(ParsedAddFields, DoesNotErrorOnTwoNestedFields) { - ParsedAddFields::create(BSON("a.b" << true << "a.c" << true)); - ParsedAddFields::create(BSON("a.b" << true << "a" << BSON("c" << true))); + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + ParsedAddFields::create(expCtx, BSON("a.b" << true << "a.c" << true)); + ParsedAddFields::create(expCtx, BSON("a.b" << true << "a" << BSON("c" << true))); } // Verify that replaced fields are not included as dependencies. TEST(ParsedAddFieldsDeps, RemovesReplaceFieldsFromDependencies) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << true)); + addition.parse(expCtx, BSON("a" << true)); DepsTracker deps; addition.addDependencies(&deps); @@ -121,8 +131,9 @@ TEST(ParsedAddFieldsDeps, RemovesReplaceFieldsFromDependencies) { // Verify that adding nested fields keeps the top-level field as a dependency. TEST(ParsedAddFieldsDeps, IncludesTopLevelFieldInDependenciesWhenAddingNestedFields) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("x.y" << true)); + addition.parse(expCtx, BSON("x.y" << true)); DepsTracker deps; addition.addDependencies(&deps); @@ -135,8 +146,10 @@ TEST(ParsedAddFieldsDeps, IncludesTopLevelFieldInDependenciesWhenAddingNestedFie // Verify that fields that an expression depends on are added to the dependencies. TEST(ParsedAddFieldsDeps, AddsDependenciesForComputedFields) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("x.y" + addition.parse(expCtx, + BSON("x.y" << "$z" << "a" << "$b")); @@ -155,8 +168,9 @@ TEST(ParsedAddFieldsDeps, AddsDependenciesForComputedFields) { // Verify that the serialization produces the correct output: converting numbers and literals to // their corresponding $const form. TEST(ParsedAddFieldsSerialize, SerializesToCorrectForm) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); + addition.parse(expCtx, fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); auto expectedSerialization = Document( fromjson("{a: {$add: [\"$a\", {$const: 2}]}, b: {d: {$const: 3}}, x: {y: {$const: 4}}}")); @@ -168,8 +182,9 @@ TEST(ParsedAddFieldsSerialize, SerializesToCorrectForm) { // Verify that serialize treats the _id field as any other field: including when explicity included. TEST(ParsedAddFieldsSerialize, AddsIdToSerializeWhenExplicitlyIncluded) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("_id" << false)); + addition.parse(expCtx, BSON("_id" << false)); // Adds explicit "_id" setting field, serializes expressions. auto expectedSerialization = Document(fromjson("{_id: {$const: false}}")); @@ -184,8 +199,9 @@ TEST(ParsedAddFieldsSerialize, AddsIdToSerializeWhenExplicitlyIncluded) { // yet they derive from the same parent class. If the parent class were to change, this test would // fail. TEST(ParsedAddFieldsSerialize, OmitsIdFromSerializeWhenNotIncluded) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << true)); + addition.parse(expCtx, BSON("a" << true)); // Does not implicitly include "_id" field. auto expectedSerialization = Document(fromjson("{a: {$const: true}}")); @@ -197,8 +213,9 @@ TEST(ParsedAddFieldsSerialize, OmitsIdFromSerializeWhenNotIncluded) { // Verify that the $addFields stage optimizes expressions into simpler forms when possible. TEST(ParsedAddFieldsOptimize, OptimizesTopLevelExpressions) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); + addition.parse(expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); addition.optimize(); auto expectedSerialization = Document{{"a", Document{{"$const", 3}}}}; @@ -209,8 +226,9 @@ TEST(ParsedAddFieldsOptimize, OptimizesTopLevelExpressions) { // Verify that the $addFields stage optimizes expressions even when they are nested. TEST(ParsedAddFieldsOptimize, ShouldOptimizeNestedExpressions) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); + addition.parse(expCtx, BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); addition.optimize(); auto expectedSerialization = Document{{"a", Document{{"b", Document{{"$const", 3}}}}}}; @@ -225,8 +243,9 @@ TEST(ParsedAddFieldsOptimize, ShouldOptimizeNestedExpressions) { // Verify that a new field is added to the end of the document. TEST(ParsedAddFieldsExecutionTest, AddsNewFieldToEndOfDocument) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("c" << 3)); + addition.parse(expCtx, BSON("c" << 3)); // There are no fields in the document. auto result = addition.applyProjection(Document{}); @@ -241,8 +260,9 @@ TEST(ParsedAddFieldsExecutionTest, AddsNewFieldToEndOfDocument) { // Verify that an existing field is replaced and stays in the same order in the document. TEST(ParsedAddFieldsExecutionTest, ReplacesFieldThatAlreadyExistsInDocument) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("c" << 3)); + addition.parse(expCtx, BSON("c" << 3)); // Specified field is the only field in the document, and is replaced. auto result = addition.applyProjection(Document{{"c", 1}}); @@ -257,8 +277,10 @@ TEST(ParsedAddFieldsExecutionTest, ReplacesFieldThatAlreadyExistsInDocument) { // Verify that replacing multiple fields preserves the original field order in the document. TEST(ParsedAddFieldsExecutionTest, ReplacesMultipleFieldsWhilePreservingInputFieldOrder) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("second" + addition.parse(expCtx, + BSON("second" << "SECOND" << "first" << "FIRST")); @@ -269,8 +291,10 @@ TEST(ParsedAddFieldsExecutionTest, ReplacesMultipleFieldsWhilePreservingInputFie // Verify that adding multiple fields adds the fields in the order specified. TEST(ParsedAddFieldsExecutionTest, AddsNewFieldsAfterExistingFieldsInOrderSpecified) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("firstComputed" + addition.parse(expCtx, + BSON("firstComputed" << "FIRST" << "secondComputed" << "SECOND")); @@ -286,8 +310,10 @@ TEST(ParsedAddFieldsExecutionTest, AddsNewFieldsAfterExistingFieldsInOrderSpecif // Verify that both adding and replacing fields at the same time follows the same rules as doing // each independently. TEST(ParsedAddFieldsExecutionTest, ReplacesAndAddsNewFieldsWithSameOrderingRulesAsSeparately) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("firstComputed" + addition.parse(expCtx, + BSON("firstComputed" << "FIRST" << "second" << "SECOND")); @@ -300,8 +326,10 @@ TEST(ParsedAddFieldsExecutionTest, ReplacesAndAddsNewFieldsWithSameOrderingRules // Verify that _id is included just like a regular field, in whatever order it appears in the // input document, when adding new fields. TEST(ParsedAddFieldsExecutionTest, IdFieldIsKeptInOrderItAppearsInInputDocument) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("newField" + addition.parse(expCtx, + BSON("newField" << "computedVal")); auto result = addition.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"a", 1}, {"newField", "computedVal"_sd}}; @@ -314,8 +342,10 @@ TEST(ParsedAddFieldsExecutionTest, IdFieldIsKeptInOrderItAppearsInInputDocument) // Verify that replacing or adding _id works just like any other field. TEST(ParsedAddFieldsExecutionTest, ShouldReplaceIdWithComputedId) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("_id" + addition.parse(expCtx, + BSON("_id" << "newId")); auto result = addition.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}}); auto expectedResult = Document{{"_id", "newId"_sd}, {"a", 1}}; @@ -336,8 +366,9 @@ TEST(ParsedAddFieldsExecutionTest, ShouldReplaceIdWithComputedId) { // Verify that adding a dotted field keeps the other fields in the subdocument. TEST(ParsedAddFieldsExecutionTest, KeepsExistingSubFieldsWhenAddingSimpleDottedFieldToSubDoc) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b" << true)); + addition.parse(expCtx, BSON("a.b" << true)); // More than one field in sub document. auto result = addition.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); @@ -362,8 +393,9 @@ TEST(ParsedAddFieldsExecutionTest, KeepsExistingSubFieldsWhenAddingSimpleDottedF // Verify that creating a dotted field creates the subdocument structure necessary. TEST(ParsedAddFieldsExecutionTest, CreatesSubDocIfDottedAddedFieldDoesNotExist) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("sub.target" << true)); + addition.parse(expCtx, BSON("sub.target" << true)); // Should add the path if it doesn't exist. auto result = addition.applyProjection(Document{}); @@ -379,8 +411,9 @@ TEST(ParsedAddFieldsExecutionTest, CreatesSubDocIfDottedAddedFieldDoesNotExist) // Verify that adding a dotted value to an array field sets the field in every element of the array. // SERVER-25200: make this agree with $set. TEST(ParsedAddFieldsExecutionTest, AppliesDottedAdditionToEachElementInArray) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b" << true)); + addition.parse(expCtx, BSON("a.b" << true)); vector nestedValues = {Value(1), Value(Document{}), @@ -404,8 +437,10 @@ TEST(ParsedAddFieldsExecutionTest, AppliesDottedAdditionToEachElementInArray) { // Verify that creation of the subdocument structure works for many layers of nesting. TEST(ParsedAddFieldsExecutionTest, CreatesNestedSubDocumentsAllTheWayToAddedField) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a.b.c.d" + addition.parse(expCtx, + BSON("a.b.c.d" << "computedVal")); // Should add the path if it doesn't exist. @@ -421,8 +456,10 @@ TEST(ParsedAddFieldsExecutionTest, CreatesNestedSubDocumentsAllTheWayToAddedFiel // Verify that _id is not special: we can add subfields to it as well. TEST(ParsedAddFieldsExecutionTest, AddsSubFieldsOfId) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("_id.X" << true << "_id.Z" + addition.parse(expCtx, + BSON("_id.X" << true << "_id.Z" << "NEW")); auto result = addition.applyProjection(Document{{"_id", Document{{"X", 1}, {"Y", 2}}}}); auto expectedResult = Document{{"_id", Document{{"X", true}, {"Y", 2}, {"Z", "NEW"_sd}}}}; @@ -432,10 +469,12 @@ TEST(ParsedAddFieldsExecutionTest, AddsSubFieldsOfId) { // Verify that both ways of specifying nested fields -- both dotted notation and nesting -- // can be used together in the same specification. TEST(ParsedAddFieldsExecutionTest, ShouldAllowMixedNestedAndDottedFields) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; // Include all of "a.b", "a.c", "a.d", and "a.e". // Add new computed fields "a.W", "a.X", "a.Y", and "a.Z". - addition.parse(BSON("a.b" << true << "a.c" << true << "a.W" + addition.parse(expCtx, + BSON("a.b" << true << "a.c" << true << "a.W" << "W" << "a.X" << "X" @@ -462,8 +501,10 @@ TEST(ParsedAddFieldsExecutionTest, ShouldAllowMixedNestedAndDottedFields) { // Verify that adding nested fields preserves the addition order in the spec. TEST(ParsedAddFieldsExecutionTest, AddsNestedAddedFieldsInOrderSpecified) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("b.d" + addition.parse(expCtx, + BSON("b.d" << "FIRST" << "b.c" << "SECOND")); @@ -478,8 +519,9 @@ TEST(ParsedAddFieldsExecutionTest, AddsNestedAddedFieldsInOrderSpecified) { // Verify that the metadata is kept from the original input document. TEST(ParsedAddFieldsExecutionTest, AlwaysKeepsMetadataFromOriginalDoc) { + boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ParsedAddFields addition; - addition.parse(BSON("a" << true)); + addition.parse(expCtx, BSON("a" << true)); MutableDocument inputDocBuilder(Document{{"a", 1}}); inputDocBuilder.setRandMetaField(1.0); diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection.cpp b/src/mongo/db/pipeline/parsed_aggregation_projection.cpp index a73a9589ed8..69b196faab3 100644 --- a/src/mongo/db/pipeline/parsed_aggregation_projection.cpp +++ b/src/mongo/db/pipeline/parsed_aggregation_projection.cpp @@ -259,7 +259,7 @@ private: } // namespace std::unique_ptr ParsedAggregationProjection::create( - const BSONObj& spec) { + const boost::intrusive_ptr& expCtx, const BSONObj& spec) { // Check that the specification was valid. Status returned is unspecific because validate() // is used by the $addFields stage as well as $project. // If there was an error, uassert with a $project-specific message. @@ -282,7 +282,7 @@ std::unique_ptr ParsedAggregationProjection::create : static_cast(new ParsedExclusionProjection())); // Actually parse the specification. - parsedProject->parse(spec); + parsedProject->parse(expCtx, spec); return parsedProject; } diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection.h b/src/mongo/db/pipeline/parsed_aggregation_projection.h index a637365565e..8d92e4ab768 100644 --- a/src/mongo/db/pipeline/parsed_aggregation_projection.h +++ b/src/mongo/db/pipeline/parsed_aggregation_projection.h @@ -41,7 +41,7 @@ namespace mongo { class BSONObj; class Document; -struct ExpressionContext; +class ExpressionContext; namespace parsed_aggregation_projection { @@ -117,7 +117,8 @@ public: * * Throws a UserException if 'spec' is an invalid projection specification. */ - static std::unique_ptr create(const BSONObj& spec); + static std::unique_ptr create( + const boost::intrusive_ptr& expCtx, const BSONObj& spec); virtual ~ParsedAggregationProjection() = default; @@ -132,18 +133,14 @@ public: * inclusions and exclusions. 'variablesParseState' is used by any contained expressions to * track which variables are defined so that they can later be referenced at execution time. */ - virtual void parse(const BSONObj& spec) = 0; + virtual void parse(const boost::intrusive_ptr& expCtx, + const BSONObj& spec) = 0; /** * Optimize any expressions contained within this projection. */ virtual void optimize() {} - /** - * Inject the ExpressionContext into any expressions contained within this projection. - */ - virtual void injectExpressionContext(const boost::intrusive_ptr& expCtx) {} - /** * Add any dependencies needed by this projection or any sub-expressions to 'deps'. */ diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp b/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp index 5b55071a6bf..71dd1c2378f 100644 --- a/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_aggregation_projection_test.cpp @@ -36,6 +36,7 @@ #include "mongo/bson/bsonobjbuilder.h" #include "mongo/bson/json.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" @@ -53,262 +54,296 @@ BSONObj wrapInLiteral(const T& arg) { // TEST(ParsedAggregationProjectionErrors, ShouldRejectDuplicateFieldNames) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Include/exclude the same field twice. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "a" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << true << "a" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "a" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << false << "a" << false)), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << false << "b" << false))), UserException); - ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << BSON("b" << false << "b" << false))), - UserException); // Mix of include/exclude and adding a field. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1) << "a" << true)), - UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << false << "a" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1) << "a" << true)), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << false << "a" << wrapInLiteral(0))), UserException); // Adding the same field twice. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << wrapInLiteral(1) << "a" << wrapInLiteral(0))), + expCtx, BSON("a" << wrapInLiteral(1) << "a" << wrapInLiteral(0))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectDuplicateIds) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Include/exclude _id twice. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "_id" << true)), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << false << "_id" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "_id" << true)), UserException); - - // Mix of including/excluding and adding _id. ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << wrapInLiteral(1) << "_id" << true)), - UserException); - ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << false << "_id" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("_id" << false << "_id" << false)), UserException); + // Mix of including/excluding and adding _id. + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << wrapInLiteral(1) << "_id" << true)), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << false << "_id" << wrapInLiteral(0))), + UserException); + // Adding _id twice. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("_id" << wrapInLiteral(1) << "_id" << wrapInLiteral(0))), + expCtx, BSON("_id" << wrapInLiteral(1) << "_id" << wrapInLiteral(0))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectFieldsWithSharedPrefix) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Include/exclude Fields with a shared prefix. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "a.b" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << true << "a.b" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a.b" << false << "a" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a.b" << false << "a" << false)), UserException); // Mix of include/exclude and adding a shared prefix. ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1) << "a.b" << true)), - UserException); - ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a.b" << false << "a" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1) << "a.b" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a.b" << false << "a" << wrapInLiteral(0))), + UserException); // Adding a shared prefix twice. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << wrapInLiteral(1) << "a.b" << wrapInLiteral(0))), + expCtx, BSON("a" << wrapInLiteral(1) << "a.b" << wrapInLiteral(0))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a.b.c.d" << wrapInLiteral(1) << "a.b.c" << wrapInLiteral(0))), + expCtx, BSON("a.b.c.d" << wrapInLiteral(1) << "a.b.c" << wrapInLiteral(0))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectMixOfIdAndSubFieldsOfId) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Include/exclude _id twice. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "_id.x" << true)), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id.x" << false << "_id" << false)), - UserException); - - // Mix of including/excluding and adding _id. ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << wrapInLiteral(1) << "_id.x" << true)), + ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "_id.x" << true)), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id.x" << false << "_id" << wrapInLiteral(0))), + ParsedAggregationProjection::create(expCtx, BSON("_id.x" << false << "_id" << false)), UserException); - // Adding _id twice. + // Mix of including/excluding and adding _id. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("_id" << wrapInLiteral(1) << "_id.x" << wrapInLiteral(0))), + expCtx, BSON("_id" << wrapInLiteral(1) << "_id.x" << true)), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("_id.b.c.d" << wrapInLiteral(1) << "_id.b.c" << wrapInLiteral(0))), + expCtx, BSON("_id.x" << false << "_id" << wrapInLiteral(0))), UserException); + + // Adding _id twice. + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << wrapInLiteral(1) << "_id.x" << wrapInLiteral(0))), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create( + expCtx, BSON("_id.b.c.d" << wrapInLiteral(1) << "_id.b.c" << wrapInLiteral(0))), + UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectMixOfInclusionAndExclusion) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Simple mix. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "b" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << true << "b" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "b" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << false << "b" << true)), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << BSON("b" << false << "c" << true))), + ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b" << false << "c" << true))), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("_id" << BSON("b" << false << "c" << true))), + UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("_id" << BSON("b" << false << "c" << true))), + ParsedAggregationProjection::create(expCtx, BSON("_id.b" << false << "a.c" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id.b" << false << "a.c" << true)), - UserException); // Mix while also adding a field. ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << true << "b" << wrapInLiteral(1) << "c" << false)), + expCtx, BSON("a" << true << "b" << wrapInLiteral(1) << "c" << false)), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << false << "b" << wrapInLiteral(1) << "c" << true)), + expCtx, BSON("a" << false << "b" << wrapInLiteral(1) << "c" << true)), UserException); // Mixing "_id" inclusion with exclusion. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "a" << false)), - UserException); - - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "_id" << true)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "a" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id" << true << "a.b.c" << false)), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << false << "_id" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("_id.x" << true << "a.b.c" << false)), - UserException); -} - -TEST(ParsedAggregationProjectionType, ShouldRejectMixOfExclusionAndComputedFields) { ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << false << "b" << wrapInLiteral(1))), + ParsedAggregationProjection::create(expCtx, BSON("_id" << true << "a.b.c" << false)), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1) << "b" << false)), + ParsedAggregationProjection::create(expCtx, BSON("_id.x" << true << "a.b.c" << false)), UserException); +} +TEST(ParsedAggregationProjectionType, ShouldRejectMixOfExclusionAndComputedFields) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a.b" << false << "a.c" << wrapInLiteral(1))), + ParsedAggregationProjection::create(expCtx, BSON("a" << false << "b" << wrapInLiteral(1))), UserException); ASSERT_THROWS( - ParsedAggregationProjection::create(BSON("a.b" << wrapInLiteral(1) << "a.c" << false)), + ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1) << "b" << false)), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << false << "c" << wrapInLiteral(1)))), + expCtx, BSON("a.b" << false << "a.c" << wrapInLiteral(1))), + UserException); + + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a.b" << wrapInLiteral(1) << "a.c" << false)), + UserException); + + ASSERT_THROWS(ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << false << "c" << wrapInLiteral(1)))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << wrapInLiteral(1) << "c" << false))), + expCtx, BSON("a" << BSON("b" << wrapInLiteral(1) << "c" << false))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectDottedFieldInSubDocument) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("b.c" << true))), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("b.c" << wrapInLiteral(1)))), + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b.c" << true))), UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b.c" << wrapInLiteral(1)))), + UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectFieldNamesStartingWithADollar) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$dollar" << 0)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$dollar" << 1)), UserException); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$dollar" << 0)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$dollar" << 1)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b.$dollar" << 0)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b.$dollar" << 1)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b.$dollar" << 0)), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b.$dollar" << 1)), + UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b" << BSON("$dollar" << 0))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b" << BSON("$dollar" << 0))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("b" << BSON("$dollar" << 1))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("b" << BSON("$dollar" << 1))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$add" << 0)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$add" << 1)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$add" << 0)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$add" << 1)), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectTopLevelExpressions) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$add" << BSON_ARRAY(4 << 2))), + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$add" << BSON_ARRAY(4 << 2))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectExpressionWithMultipleFieldNames) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), + expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(4 << 2) << "b" << 1))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << 1 << "$add" << BSON_ARRAY(4 << 2)))), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << BSON("c" << 1 << "$add" << BSON_ARRAY(4 << 2))))), - UserException); - ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << BSON("b" << BSON("$add" << BSON_ARRAY(4 << 2) << "c" << 1)))), + expCtx, BSON("a" << BSON("b" << 1 << "$add" << BSON_ARRAY(4 << 2)))), UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << BSON("c" << 1 << "$add" << BSON_ARRAY(4 << 2))))), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create( + expCtx, BSON("a" << BSON("b" << BSON("$add" << BSON_ARRAY(4 << 2) << "c" << 1)))), + UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectEmptyProjection) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSONObj()), UserException); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSONObj()), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldRejectEmptyNestedObject) { - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << false << "b" << BSONObj())), + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << true << "b" << BSONObj())), + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << false << "b" << BSONObj())), + UserException); + ASSERT_THROWS( + ParsedAggregationProjection::create(expCtx, BSON("a" << true << "b" << BSONObj())), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a.b" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a.b" << BSONObj())), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("b" << BSONObj()))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("b" << BSONObj()))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldErrorOnInvalidExpression) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), + expCtx, BSON("a" << false << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), UserException); ASSERT_THROWS(ParsedAggregationProjection::create( - BSON("a" << true << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), + expCtx, BSON("a" << true << "b" << BSON("$unknown" << BSON_ARRAY(4 << 2)))), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldErrorOnInvalidFieldPath) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Empty field names. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << wrapInLiteral(2))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << false)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << wrapInLiteral(2))), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("" << true))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("" << true))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a" << BSON("" << false))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a" << BSON("" << false))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << BSON("a" << true))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << BSON("a" << true))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("" << BSON("a" << false))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("" << BSON("a" << false))), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a." << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("a." << false)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a." << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("a." << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON(".a" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON(".a" << false)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON(".a" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON(".a" << false)), UserException); // Not testing field names with null bytes, since that is invalid BSON, and won't make it to the // $project stage without a previous error. // Field names starting with '$'. - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("$x" << wrapInLiteral(2))), + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("$x" << wrapInLiteral(2))), + UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("c.$d" << true)), UserException); + ASSERT_THROWS(ParsedAggregationProjection::create(expCtx, BSON("c.$d" << false)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("c.$d" << true)), UserException); - ASSERT_THROWS(ParsedAggregationProjection::create(BSON("c.$d" << false)), UserException); } TEST(ParsedAggregationProjectionErrors, ShouldNotErrorOnTwoNestedFields) { - ParsedAggregationProjection::create(BSON("a.b" << true << "a.c" << true)); - ParsedAggregationProjection::create(BSON("a.b" << true << "a" << BSON("c" << true))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + ParsedAggregationProjection::create(expCtx, BSON("a.b" << true << "a.c" << true)); + ParsedAggregationProjection::create(expCtx, BSON("a.b" << true << "a" << BSON("c" << true))); } // @@ -316,102 +351,112 @@ TEST(ParsedAggregationProjectionErrors, ShouldNotErrorOnTwoNestedFields) { // TEST(ParsedAggregationProjectionType, ShouldDefaultToInclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("_id" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldDetectExclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("a" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << false)); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id.x" << false)); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id.x" << false)); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << BSON("x" << false))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << BSON("x" << false))); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("x" << BSON("_id" << false))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("x" << BSON("_id" << false))); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << false)); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << false)); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); } TEST(ParsedAggregationProjectionType, ShouldDetectInclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("a" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << false << "a" << true)); + parsedProject = + ParsedAggregationProjection::create(expCtx, BSON("_id" << false << "a" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << false << "a.b.c" << true)); + parsedProject = + ParsedAggregationProjection::create(expCtx, BSON("_id" << false << "a.b.c" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id.x" << true)); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id.x" << true)); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id" << BSON("x" << true))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id" << BSON("x" << true))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("x" << BSON("_id" << true))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("x" << BSON("_id" << true))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldTreatOnlyComputedFieldsAsAnInclusionProjection) { - auto parsedProject = ParsedAggregationProjection::create(BSON("a" << wrapInLiteral(1))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + auto parsedProject = ParsedAggregationProjection::create(expCtx, BSON("a" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = - ParsedAggregationProjection::create(BSON("_id" << false << "a" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create( + expCtx, BSON("_id" << false << "a" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = - ParsedAggregationProjection::create(BSON("_id" << false << "a.b.c" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create( + expCtx, BSON("_id" << false << "a.b.c" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = ParsedAggregationProjection::create(BSON("_id.x" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create(expCtx, BSON("_id.x" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); parsedProject = - ParsedAggregationProjection::create(BSON("_id" << BSON("x" << wrapInLiteral(1)))); + ParsedAggregationProjection::create(expCtx, BSON("_id" << BSON("x" << wrapInLiteral(1)))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); parsedProject = - ParsedAggregationProjection::create(BSON("x" << BSON("_id" << wrapInLiteral(1)))); + ParsedAggregationProjection::create(expCtx, BSON("x" << BSON("_id" << wrapInLiteral(1)))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldAllowMixOfInclusionAndComputedFields) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); auto parsedProject = - ParsedAggregationProjection::create(BSON("a" << true << "b" << wrapInLiteral(1))); + ParsedAggregationProjection::create(expCtx, BSON("a" << true << "b" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); - parsedProject = - ParsedAggregationProjection::create(BSON("a.b" << true << "a.c" << wrapInLiteral(1))); + parsedProject = ParsedAggregationProjection::create( + expCtx, BSON("a.b" << true << "a.c" << wrapInLiteral(1))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); parsedProject = ParsedAggregationProjection::create( - BSON("a" << BSON("b" << true << "c" << wrapInLiteral(1)))); + expCtx, BSON("a" << BSON("b" << true << "c" << wrapInLiteral(1)))); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } TEST(ParsedAggregationProjectionType, ShouldCoerceNumericsToBools) { + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); std::vector zeros = {Value(0), Value(0LL), Value(0.0), Value(Decimal128(0))}; for (auto&& zero : zeros) { - auto parsedProject = ParsedAggregationProjection::create(Document{{"a", zero}}.toBson()); + auto parsedProject = + ParsedAggregationProjection::create(expCtx, Document{{"a", zero}}.toBson()); ASSERT(parsedProject->getType() == ProjectionType::kExclusion); } std::vector nonZeroes = { Value(1), Value(-1), Value(3), Value(1LL), Value(1.0), Value(Decimal128(1))}; for (auto&& nonZero : nonZeroes) { - auto parsedProject = ParsedAggregationProjection::create(Document{{"a", nonZero}}.toBson()); + auto parsedProject = + ParsedAggregationProjection::create(expCtx, Document{{"a", nonZero}}.toBson()); ASSERT(parsedProject->getType() == ProjectionType::kInclusion); } } diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection.cpp b/src/mongo/db/pipeline/parsed_exclusion_projection.cpp index 8226d146009..0f800f9a112 100644 --- a/src/mongo/db/pipeline/parsed_exclusion_projection.cpp +++ b/src/mongo/db/pipeline/parsed_exclusion_projection.cpp @@ -143,7 +143,10 @@ Document ParsedExclusionProjection::applyProjection(Document inputDoc) const { return _root->applyProjection(inputDoc); } -void ParsedExclusionProjection::parse(const BSONObj& spec, ExclusionNode* node, size_t depth) { +void ParsedExclusionProjection::parse(const boost::intrusive_ptr& expCtx, + const BSONObj& spec, + ExclusionNode* node, + size_t depth) { for (auto elem : spec) { const auto fieldName = elem.fieldNameStringData().toString(); @@ -188,7 +191,7 @@ void ParsedExclusionProjection::parse(const BSONObj& spec, ExclusionNode* node, child = child->addOrGetChild(fullPath.fullPath()); } - parse(elem.Obj(), child, depth + 1); + parse(expCtx, elem.Obj(), child, depth + 1); break; } default: { MONGO_UNREACHABLE; } diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection.h b/src/mongo/db/pipeline/parsed_exclusion_projection.h index ea7b25ac33f..d0988d2d2cb 100644 --- a/src/mongo/db/pipeline/parsed_exclusion_projection.h +++ b/src/mongo/db/pipeline/parsed_exclusion_projection.h @@ -108,8 +108,8 @@ public: /** * Parses the projection specification given by 'spec', populating internal data structures. */ - void parse(const BSONObj& spec) final { - parse(spec, _root.get(), 0); + void parse(const boost::intrusive_ptr& expCtx, const BSONObj& spec) final { + parse(expCtx, spec, _root.get(), 0); } /** @@ -134,7 +134,10 @@ private: * Traverses 'spec' and parses each field. Adds any excluded fields at this level to 'node', * and recurses on any sub-objects. */ - void parse(const BSONObj& spec, ExclusionNode* node, size_t depth); + void parse(const boost::intrusive_ptr& expCtx, + const BSONObj& spec, + ExclusionNode* node, + size_t depth); // The ExclusionNode tree does most of the execution work once constructed. diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp b/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp index 62a23b2c729..76b458260d0 100644 --- a/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp @@ -40,6 +40,7 @@ #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/death_test.h" #include "mongo/unittest/unittest.h" @@ -57,28 +58,32 @@ DEATH_TEST(ExclusionProjection, ShouldRejectComputedField, "Invariant failure fieldName[0] != '$'") { ParsedExclusionProjection exclusion; + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Top-level expression. - exclusion.parse(BSON("a" << false << "b" << BSON("$literal" << 1))); + exclusion.parse(expCtx, BSON("a" << false << "b" << BSON("$literal" << 1))); } DEATH_TEST(ExclusionProjection, ShouldFailWhenGivenIncludedField, "Invariant failure !elem.trueValue()") { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << true)); } DEATH_TEST(ExclusionProjection, ShouldFailWhenGivenIncludedId, "Invariant failure !elem.trueValue()") { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << true << "a" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id" << true << "a" << false)); } TEST(ExclusionProjection, ShouldSerializeToEquivalentProjection) { ParsedExclusionProjection exclusion; + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); exclusion.parse( - fromjson("{a: 0, b: {c: NumberLong(0), d: 0.0}, 'x.y': false, _id: NumberInt(0)}")); + expCtx, fromjson("{a: 0, b: {c: NumberLong(0), d: 0.0}, 'x.y': false, _id: NumberInt(0)}")); // Converts numbers to bools, converts dotted paths to nested documents. Note order of excluded // fields is subject to change. @@ -107,7 +112,9 @@ TEST(ExclusionProjection, ShouldNotAddAnyDependencies) { // later. If there are no later stages, then we will finish the dependency computation // cycle without full knowledge of which fields are needed, and thus include all the fields. ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << false << "a" << false << "b.c" << false << "x.y.z" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, + BSON("_id" << false << "a" << false << "b.c" << false << "x.y.z" << false)); DepsTracker deps; exclusion.addDependencies(&deps); @@ -119,7 +126,8 @@ TEST(ExclusionProjection, ShouldNotAddAnyDependencies) { TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModified) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << false << "a" << false << "b.c" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id" << false << "a" << false << "b.c" << false)); auto modifiedPaths = exclusion.getModifiedPaths(); ASSERT(modifiedPaths.type == DocumentSource::GetModPathsReturn::Type::kFiniteSet); @@ -131,7 +139,8 @@ TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModified) { TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModifiedWhenSpecifiedAsNestedObj) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << BSON("b" << false << "c" << BSON("d" << false)))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << BSON("b" << false << "c" << BSON("d" << false)))); auto modifiedPaths = exclusion.getModifiedPaths(); ASSERT(modifiedPaths.type == DocumentSource::GetModPathsReturn::Type::kFiniteSet); @@ -146,7 +155,8 @@ TEST(ExclusionProjection, ShouldReportExcludedFieldsAsModifiedWhenSpecifiedAsNes TEST(ExclusionProjectionExecutionTest, ShouldExcludeTopLevelField) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false)); // More than one field in document. auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}}); @@ -171,7 +181,9 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeTopLevelField) { TEST(ExclusionProjectionExecutionTest, ShouldCoerceNumericsToBools) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << Value(0) << "b" << Value(0LL) << "c" << Value(0.0) << "d" + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, + BSON("a" << Value(0) << "b" << Value(0LL) << "c" << Value(0.0) << "d" << Value(Decimal128(0)))); auto result = @@ -182,7 +194,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldCoerceNumericsToBools) { TEST(ExclusionProjectionExecutionTest, ShouldPreserveOrderOfExistingFields) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("second" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("second" << false)); auto result = exclusion.applyProjection(Document{{"first", 0}, {"second", 1}, {"third", 2}}); auto expectedResult = Document{{"first", 0}, {"third", 2}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -190,7 +203,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldPreserveOrderOfExistingFields) { TEST(ExclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"b", 2}, {"_id", "ID"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -198,7 +212,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false << "_id" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false << "_id" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"b", 2}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -206,7 +221,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdAndKeepAllOtherFields) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"a", 1}, {"b", 2}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -218,7 +234,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdAndKeepAllOtherFields) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeSubFieldsOfId) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("_id.x" << false << "_id" << BSON("y" << false))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("_id.x" << false << "_id" << BSON("y" << false))); auto result = exclusion.applyProjection( Document{{"_id", Document{{"x", 1}, {"y", 2}, {"z", 3}}}, {"a", 1}}); auto expectedResult = Document{{"_id", Document{{"z", 3}}}, {"a", 1}}; @@ -227,7 +244,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeSubFieldsOfId) { TEST(ExclusionProjectionExecutionTest, ShouldExcludeSimpleDottedFieldFromSubDoc) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a.b" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a.b" << false)); // More than one field in sub document. auto result = exclusion.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); @@ -252,7 +270,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeSimpleDottedFieldFromSubDoc) TEST(ExclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedExcludedFieldDoesNotExist) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("sub.target" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("sub.target" << false)); // Should not add the path if it doesn't exist. auto result = exclusion.applyProjection(Document{}); @@ -267,7 +286,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedExcludedFiel TEST(ExclusionProjectionExecutionTest, ShouldApplyDottedExclusionToEachElementInArray) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a.b" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a.b" << false)); std::vector nestedValues = { Value(1), @@ -290,8 +310,10 @@ TEST(ExclusionProjectionExecutionTest, ShouldApplyDottedExclusionToEachElementIn TEST(ExclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { ParsedExclusionProjection exclusion; + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Exclude all of "a.b", "a.c", "a.d", and "a.e". exclusion.parse( + expCtx, BSON("a.b" << false << "a.c" << false << "a" << BSON("d" << false << "e" << false))); auto result = exclusion.applyProjection( Document{{"a", Document{{"b", 1}, {"c", 2}, {"d", 3}, {"e", 4}, {"f", 5}}}}); @@ -301,7 +323,8 @@ TEST(ExclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { TEST(ExclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc) { ParsedExclusionProjection exclusion; - exclusion.parse(BSON("a" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + exclusion.parse(expCtx, BSON("a" << false)); MutableDocument inputDocBuilder(Document{{"_id", "ID"_sd}, {"a", 1}}); inputDocBuilder.setRandMetaField(1.0); diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection.cpp b/src/mongo/db/pipeline/parsed_inclusion_projection.cpp index abac372be7e..511d1b37fff 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection.cpp +++ b/src/mongo/db/pipeline/parsed_inclusion_projection.cpp @@ -54,16 +54,6 @@ void InclusionNode::optimize() { } } -void InclusionNode::injectExpressionContext(const boost::intrusive_ptr& expCtx) { - for (auto&& expressionIt : _expressions) { - expressionIt.second->injectExpressionContext(expCtx); - } - - for (auto&& childPair : _children) { - childPair.second->injectExpressionContext(expCtx); - } -} - void InclusionNode::serialize(MutableDocument* output, bool explain) const { // Always put "_id" first if it was included (implicitly or explicitly). if (_inclusions.find("_id") != _inclusions.end()) { @@ -252,7 +242,8 @@ void InclusionNode::addPreservedPaths(std::set* preservedPaths) con // ParsedInclusionProjection // -void ParsedInclusionProjection::parse(const BSONObj& spec, +void ParsedInclusionProjection::parse(const boost::intrusive_ptr& expCtx, + const BSONObj& spec, const VariablesParseState& variablesParseState) { // It is illegal to specify a projection with no output fields. bool atLeastOneFieldInOutput = false; @@ -289,7 +280,7 @@ void ParsedInclusionProjection::parse(const BSONObj& spec, } case BSONType::Object: { // This is either an expression, or a nested specification. - if (parseObjectAsExpression(fieldName, elem.Obj(), variablesParseState)) { + if (parseObjectAsExpression(expCtx, fieldName, elem.Obj(), variablesParseState)) { // It was an expression. break; } @@ -306,13 +297,14 @@ void ParsedInclusionProjection::parse(const BSONObj& spec, // iteration too soon. Add the last path here. child = child->addOrGetChild(remainingPath.fullPath()); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); break; } default: { // This is a literal value. - _root->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + _root->addComputedField( + FieldPath(elem.fieldName()), + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } @@ -343,6 +335,7 @@ Document ParsedInclusionProjection::applyProjection(Document inputDoc, Variables } bool ParsedInclusionProjection::parseObjectAsExpression( + const boost::intrusive_ptr& expCtx, StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState) { @@ -351,15 +344,17 @@ bool ParsedInclusionProjection::parseObjectAsExpression( // field. invariant(objSpec.nFields() == 1); _root->addComputedField(pathToObject, - Expression::parseExpression(objSpec, variablesParseState)); + Expression::parseExpression(expCtx, objSpec, variablesParseState)); return true; } return false; } -void ParsedInclusionProjection::parseSubObject(const BSONObj& subObj, - const VariablesParseState& variablesParseState, - InclusionNode* node) { +void ParsedInclusionProjection::parseSubObject( + const boost::intrusive_ptr& expCtx, + const BSONObj& subObj, + const VariablesParseState& variablesParseState, + InclusionNode* node) { for (auto elem : subObj) { invariant(elem.fieldName()[0] != '$'); // Dotted paths in a sub-object have already been disallowed in @@ -381,19 +376,20 @@ void ParsedInclusionProjection::parseSubObject(const BSONObj& subObj, // This is either an expression, or a nested specification. auto fieldName = elem.fieldNameStringData().toString(); if (parseObjectAsExpression( + expCtx, FieldPath::getFullyQualifiedPath(node->getPath(), fieldName), elem.Obj(), variablesParseState)) { break; } auto child = node->addOrGetChild(fieldName); - parseSubObject(elem.Obj(), variablesParseState, child); + parseSubObject(expCtx, elem.Obj(), variablesParseState, child); break; } default: { // This is a literal value. node->addComputedField(FieldPath(elem.fieldName()), - Expression::parseOperand(elem, variablesParseState)); + Expression::parseOperand(expCtx, elem, variablesParseState)); } } } diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection.h b/src/mongo/db/pipeline/parsed_inclusion_projection.h index 8cd04514b53..2f812af300d 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection.h +++ b/src/mongo/db/pipeline/parsed_inclusion_projection.h @@ -119,8 +119,6 @@ public: return _pathToNode; } - void injectExpressionContext(const boost::intrusive_ptr& expCtx); - /** * Recursively add all paths that are preserved by this inclusion projection. */ @@ -184,10 +182,10 @@ public: /** * Parses the projection specification given by 'spec', populating internal data structures. */ - void parse(const BSONObj& spec) final { + void parse(const boost::intrusive_ptr& expCtx, const BSONObj& spec) final { VariablesIdGenerator idGenerator; VariablesParseState variablesParseState(&idGenerator); - parse(spec, variablesParseState); + parse(expCtx, spec, variablesParseState); _variables = stdx::make_unique(idGenerator.getIdCount()); } @@ -210,10 +208,6 @@ public: _root->optimize(); } - void injectExpressionContext(const boost::intrusive_ptr& expCtx) final { - _root->injectExpressionContext(expCtx); - } - DocumentSource::GetDepsReturn addDependencies(DepsTracker* deps) const final { _root->addDependencies(deps); return DocumentSource::EXHAUSTIVE_FIELDS; @@ -246,7 +240,9 @@ private: * Parses 'spec' to determine which fields to include, which are computed, and whether to * include '_id' or not. */ - void parse(const BSONObj& spec, const VariablesParseState& variablesParseState); + void parse(const boost::intrusive_ptr& expCtx, + const BSONObj& spec, + const VariablesParseState& variablesParseState); /** * Attempts to parse 'objSpec' as an expression like {$add: [...]}. Adds a computed field to @@ -256,7 +252,8 @@ private: * Throws an error if it was determined to be an expression specification, but failed to parse * as a valid expression. */ - bool parseObjectAsExpression(StringData pathToObject, + bool parseObjectAsExpression(const boost::intrusive_ptr& expCtx, + StringData pathToObject, const BSONObj& objSpec, const VariablesParseState& variablesParseState); @@ -264,7 +261,8 @@ private: * Traverses 'subObj' and parses each field. Adds any included or computed fields at this level * to 'node'. */ - void parseSubObject(const BSONObj& subObj, + void parseSubObject(const boost::intrusive_ptr& expCtx, + const BSONObj& subObj, const VariablesParseState& variablesParseState, InclusionNode* node); diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp b/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp index 611f5367e9a..c2610418a40 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp @@ -38,6 +38,7 @@ #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" @@ -53,19 +54,23 @@ BSONObj wrapInLiteral(const T& arg) { TEST(InclusionProjection, ShouldThrowWhenParsingInvalidExpression) { ParsedInclusionProjection inclusion; - ASSERT_THROWS(inclusion.parse(BSON("a" << BSON("$gt" << BSON("bad" + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(inclusion.parse(expCtx, + BSON("a" << BSON("$gt" << BSON("bad" << "arguments")))), UserException); } TEST(InclusionProjection, ShouldRejectProjectionWithNoOutputFields) { ParsedInclusionProjection inclusion; - ASSERT_THROWS(inclusion.parse(BSON("_id" << false)), UserException); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + ASSERT_THROWS(inclusion.parse(expCtx, BSON("_id" << false)), UserException); } TEST(InclusionProjection, ShouldAddIncludedFieldsToDependencies) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << false << "a" << true << "x.y" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id" << false << "a" << true << "x.y" << true)); DepsTracker deps; inclusion.addDependencies(&deps); @@ -78,7 +83,8 @@ TEST(InclusionProjection, ShouldAddIncludedFieldsToDependencies) { TEST(InclusionProjection, ShouldAddIdToDependenciesIfNotSpecified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); DepsTracker deps; inclusion.addDependencies(&deps); @@ -90,7 +96,9 @@ TEST(InclusionProjection, ShouldAddIdToDependenciesIfNotSpecified) { TEST(InclusionProjection, ShouldAddDependenciesOfComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("a" << "$a" << "x" << "$z")); @@ -106,7 +114,9 @@ TEST(InclusionProjection, ShouldAddDependenciesOfComputedFields) { TEST(InclusionProjection, ShouldAddPathToDependenciesForNestedComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("x.y" + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("x.y" << "$z")); DepsTracker deps; @@ -123,7 +133,8 @@ TEST(InclusionProjection, ShouldAddPathToDependenciesForNestedComputedFields) { TEST(InclusionProjection, ShouldSerializeToEquivalentProjection) { ParsedInclusionProjection inclusion; - inclusion.parse(fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, fromjson("{a: {$add: ['$a', 2]}, b: {d: 3}, 'x.y': {$literal: 4}}")); // Adds implicit "_id" inclusion, converts numbers to bools, serializes expressions. auto expectedSerialization = Document(fromjson( @@ -136,7 +147,8 @@ TEST(InclusionProjection, ShouldSerializeToEquivalentProjection) { TEST(InclusionProjection, ShouldSerializeExplicitExclusionOfId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << false << "a" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id" << false << "a" << true)); // Adds implicit "_id" inclusion, converts numbers to bools, serializes expressions. auto expectedSerialization = Document{{"_id", false}, {"a", true}}; @@ -149,7 +161,8 @@ TEST(InclusionProjection, ShouldSerializeExplicitExclusionOfId) { TEST(InclusionProjection, ShouldOptimizeTopLevelExpressions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << BSON("$add" << BSON_ARRAY(1 << 2)))); inclusion.optimize(); @@ -162,7 +175,8 @@ TEST(InclusionProjection, ShouldOptimizeTopLevelExpressions) { TEST(InclusionProjection, ShouldOptimizeNestedExpressions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << BSON("$add" << BSON_ARRAY(1 << 2)))); inclusion.optimize(); @@ -176,10 +190,13 @@ TEST(InclusionProjection, ShouldOptimizeNestedExpressions) { TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON( - "a" << wrapInLiteral("computedVal") << "b.c" << wrapInLiteral("computedVal") << "d" << true - << "e.f" - << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse( + expCtx, + BSON("a" << wrapInLiteral("computedVal") << "b.c" << wrapInLiteral("computedVal") << "d" + << true + << "e.f" + << true)); auto modifiedPaths = inclusion.getModifiedPaths(); ASSERT(modifiedPaths.type == DocumentSource::GetModPathsReturn::Type::kAllExcept); @@ -195,7 +212,9 @@ TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModified) { TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModifiedWithIdExclusion) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << false << "a" << wrapInLiteral("computedVal") << "b.c" + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("_id" << false << "a" << wrapInLiteral("computedVal") << "b.c" << wrapInLiteral("computedVal") << "d" << true @@ -222,7 +241,8 @@ TEST(InclusionProjection, ShouldReportThatAllExceptIncludedFieldsAreModifiedWith TEST(InclusionProjectionExecutionTest, ShouldIncludeTopLevelField) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); // More than one field in document. auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}}); @@ -247,7 +267,8 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeTopLevelField) { TEST(InclusionProjectionExecutionTest, ShouldAddComputedTopLevelField) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("newField" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{{"newField", "computedVal"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -260,7 +281,8 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedTopLevelField) { TEST(InclusionProjectionExecutionTest, ShouldApplyBothInclusionsAndComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true << "newField" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true << "newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{{"a", 1}}); auto expectedResult = Document{{"a", 1}, {"newField", "computedVal"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -268,7 +290,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyBothInclusionsAndComputedField TEST(InclusionProjectionExecutionTest, ShouldIncludeFieldsInOrderOfInputDoc) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("first" << true << "second" << true << "third" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("first" << true << "second" << true << "third" << true)); auto inputDoc = Document{{"second", 1}, {"first", 0}, {"third", 2}}; auto result = inclusion.applyProjection(inputDoc); ASSERT_DOCUMENT_EQ(result, inputDoc); @@ -276,7 +299,9 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeFieldsInOrderOfInputDoc) { TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsInOrderSpecified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("firstComputed" << wrapInLiteral("FIRST") << "secondComputed" + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("firstComputed" << wrapInLiteral("FIRST") << "secondComputed" << wrapInLiteral("SECOND"))); auto result = inclusion.applyProjection(Document{{"first", 0}, {"second", 1}, {"third", 2}}); auto expectedResult = Document{{"firstComputed", "FIRST"_sd}, {"secondComputed", "SECOND"_sd}}; @@ -285,7 +310,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsInOrderSpecified TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); auto result = inclusion.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}, {"b", 2}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"a", 1}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -298,7 +324,8 @@ TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeIdWithComputedFields) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("newField" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"newField", "computedVal"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -306,7 +333,8 @@ TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeIdWithComputedFiel TEST(InclusionProjectionExecutionTest, ShouldIncludeIdIfExplicitlyIncluded) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true << "_id" << true << "b" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true << "_id" << true << "b" << true)); auto result = inclusion.applyProjection(Document{{"_id", "ID"_sd}, {"a", 1}, {"b", 2}, {"c", 3}}); auto expectedResult = Document{{"_id", "ID"_sd}, {"a", 1}, {"b", 2}}; @@ -315,7 +343,8 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeIdIfExplicitlyIncluded) { TEST(InclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true << "_id" << false)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true << "_id" << false)); auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"a", 1}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -323,7 +352,8 @@ TEST(InclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { TEST(InclusionProjectionExecutionTest, ShouldReplaceIdWithComputedId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id" << wrapInLiteral("newId"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id" << wrapInLiteral("newId"))); auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"_sd}}); auto expectedResult = Document{{"_id", "newId"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -335,7 +365,8 @@ TEST(InclusionProjectionExecutionTest, ShouldReplaceIdWithComputedId) { TEST(InclusionProjectionExecutionTest, ShouldIncludeSimpleDottedFieldFromSubDoc) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << true)); // More than one field in sub document. auto result = inclusion.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); @@ -360,7 +391,8 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeSimpleDottedFieldFromSubDoc) TEST(InclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedIncludedFieldDoesNotExist) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("sub.target" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("sub.target" << true)); // Should not add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); @@ -375,7 +407,8 @@ TEST(InclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedIncludedFiel TEST(InclusionProjectionExecutionTest, ShouldApplyDottedInclusionToEachElementInArray) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << true)); vector nestedValues = {Value(1), Value(Document{}), @@ -399,7 +432,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyDottedInclusionToEachElementIn TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToSubDocument) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("sub.target" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("sub.target" << wrapInLiteral("computedVal"))); // Other fields exist in sub document, one of which is the specified field. auto result = inclusion.applyProjection(Document{{"sub", Document{{"target", 1}, {"c", 2}}}}); @@ -419,7 +453,8 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToSubDocument TEST(InclusionProjectionExecutionTest, ShouldCreateSubDocIfDottedComputedFieldDoesntExist) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("sub.target" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("sub.target" << wrapInLiteral("computedVal"))); // Should add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); @@ -433,7 +468,8 @@ TEST(InclusionProjectionExecutionTest, ShouldCreateSubDocIfDottedComputedFieldDo TEST(InclusionProjectionExecutionTest, ShouldCreateNestedSubDocumentsAllTheWayToComputedField) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b.c.d" << wrapInLiteral("computedVal"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b.c.d" << wrapInLiteral("computedVal"))); // Should add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); @@ -448,7 +484,8 @@ TEST(InclusionProjectionExecutionTest, ShouldCreateNestedSubDocumentsAllTheWayTo TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToEachElementInArray) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.b" << wrapInLiteral("COMPUTED"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.b" << wrapInLiteral("COMPUTED"))); vector nestedValues = {Value(1), Value(Document{}), @@ -471,7 +508,8 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToEachElement TEST(InclusionProjectionExecutionTest, ShouldApplyInclusionsAndAdditionsToEachElementInArray) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a.inc" << true << "a.comp" << wrapInLiteral("COMPUTED"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a.inc" << true << "a.comp" << wrapInLiteral("COMPUTED"))); vector nestedValues = {Value(1), Value(Document{}), @@ -498,7 +536,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyInclusionsAndAdditionsToEachEl TEST(InclusionProjectionExecutionTest, ShouldAddOrIncludeSubFieldsOfId) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("_id.X" << true << "_id.Z" << wrapInLiteral("NEW"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("_id.X" << true << "_id.Z" << wrapInLiteral("NEW"))); auto result = inclusion.applyProjection(Document{{"_id", Document{{"X", 1}, {"Y", 2}}}}); auto expectedResult = Document{{"_id", Document{{"X", 1}, {"Z", "NEW"_sd}}}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -506,13 +545,16 @@ TEST(InclusionProjectionExecutionTest, ShouldAddOrIncludeSubFieldsOfId) { TEST(InclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { ParsedInclusionProjection inclusion; + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); // Include all of "a.b", "a.c", "a.d", and "a.e". // Add new computed fields "a.W", "a.X", "a.Y", and "a.Z". - inclusion.parse(BSON( - "a.b" << true << "a.c" << true << "a.W" << wrapInLiteral("W") << "a.X" << wrapInLiteral("X") - << "a" - << BSON("d" << true << "e" << true << "Y" << wrapInLiteral("Y") << "Z" - << wrapInLiteral("Z")))); + inclusion.parse( + expCtx, + BSON("a.b" << true << "a.c" << true << "a.W" << wrapInLiteral("W") << "a.X" + << wrapInLiteral("X") + << "a" + << BSON("d" << true << "e" << true << "Y" << wrapInLiteral("Y") << "Z" + << wrapInLiteral("Z")))); auto result = inclusion.applyProjection(Document{ {"a", Document{{"b", "b"_sd}, {"c", "c"_sd}, {"d", "d"_sd}, {"e", "e"_sd}, {"f", "f"_sd}}}}); @@ -530,7 +572,9 @@ TEST(InclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { TEST(InclusionProjectionExecutionTest, ShouldApplyNestedComputedFieldsInOrderSpecified) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << wrapInLiteral("FIRST") << "b.c" << wrapInLiteral("SECOND"))); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, + BSON("a" << wrapInLiteral("FIRST") << "b.c" << wrapInLiteral("SECOND"))); auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{{"a", "FIRST"_sd}, {"b", Document{{"c", "SECOND"_sd}}}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -538,7 +582,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyNestedComputedFieldsInOrderSpe TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("b.c" << wrapInLiteral("NEW") << "a" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("b.c" << wrapInLiteral("NEW") << "a" << true)); auto result = inclusion.applyProjection(Document{{"a", 1}}); auto expectedResult = Document{{"a", 1}, {"b", Document{{"c", "NEW"_sd}}}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -557,7 +602,8 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusio TEST(InclusionProjectionExecutionTest, ComputedFieldReplacingExistingShouldAppearAfterInclusions) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("b" << wrapInLiteral("NEW") << "a" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("b" << wrapInLiteral("NEW") << "a" << true)); auto result = inclusion.applyProjection(Document{{"b", 1}, {"a", 1}}); auto expectedResult = Document{{"a", 1}, {"b", "NEW"_sd}}; ASSERT_DOCUMENT_EQ(result, expectedResult); @@ -572,7 +618,8 @@ TEST(InclusionProjectionExecutionTest, ComputedFieldReplacingExistingShouldAppea TEST(InclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc) { ParsedInclusionProjection inclusion; - inclusion.parse(BSON("a" << true)); + const boost::intrusive_ptr expCtx(new ExpressionContextForTest()); + inclusion.parse(expCtx, BSON("a" << true)); MutableDocument inputDocBuilder(Document{{"a", 1}}); inputDocBuilder.setRandMetaField(1.0); diff --git a/src/mongo/db/pipeline/pipeline.cpp b/src/mongo/db/pipeline/pipeline.cpp index b296d3f85df..8bb745fba13 100644 --- a/src/mongo/db/pipeline/pipeline.cpp +++ b/src/mongo/db/pipeline/pipeline.cpp @@ -165,13 +165,6 @@ void Pipeline::reattachToOperationContext(OperationContext* opCtx) { } } -void Pipeline::injectExpressionContext(const intrusive_ptr& expCtx) { - pCtx = expCtx; - for (auto&& stage : _sources) { - stage->injectExpressionContext(pCtx); - } -} - intrusive_ptr Pipeline::splitForSharded() { // Create and initialize the shard spec we'll return. We start with an empty pipeline on the // shards and all work being done in the merger. Optimizations can move operations between diff --git a/src/mongo/db/pipeline/pipeline.h b/src/mongo/db/pipeline/pipeline.h index aff72fa40bc..fe4ca1de424 100644 --- a/src/mongo/db/pipeline/pipeline.h +++ b/src/mongo/db/pipeline/pipeline.h @@ -45,7 +45,7 @@ class BSONObjBuilder; class CollatorInterface; class DocumentSource; class OperationContext; -struct ExpressionContext; +class ExpressionContext; /** * A Pipeline object represents a list of DocumentSources and is responsible for optimizing the @@ -128,12 +128,6 @@ public: */ void optimizePipeline(); - /** - * Propagates a reference to the ExpressionContext to all of the pipeline's contained stages and - * expressions. - */ - void injectExpressionContext(const boost::intrusive_ptr& expCtx); - /** * Returns any other collections involved in the pipeline in addition to the collection the * aggregation is run on. diff --git a/src/mongo/db/pipeline/pipeline_d.cpp b/src/mongo/db/pipeline/pipeline_d.cpp index bb681edd3f1..b2a7bee2fa8 100644 --- a/src/mongo/db/pipeline/pipeline_d.cpp +++ b/src/mongo/db/pipeline/pipeline_d.cpp @@ -188,7 +188,6 @@ public: return pipeline.getStatus(); } - pipeline.getValue()->injectExpressionContext(expCtx); pipeline.getValue()->optimizePipeline(); AutoGetCollectionForRead autoColl(expCtx->opCtx, expCtx->ns); @@ -294,7 +293,8 @@ StatusWith> attemptToGetExecutor( // that the user omitted. // // If pipeline has a null collator (representing the "simple" collation), we simply set the - // collation option to the original user BSON. + // collation option to the original user BSON, which is either the empty object (unspecified), + // or the specification for the "simple" collation. qr->setCollation(pExpCtx->getCollator() ? pExpCtx->getCollator()->getSpec().toBSON() : pExpCtx->collation); diff --git a/src/mongo/db/pipeline/pipeline_d.h b/src/mongo/db/pipeline/pipeline_d.h index 35d93a4cfd2..5be4893cf3b 100644 --- a/src/mongo/db/pipeline/pipeline_d.h +++ b/src/mongo/db/pipeline/pipeline_d.h @@ -38,7 +38,7 @@ namespace mongo { class Collection; class DocumentSourceCursor; class DocumentSourceSort; -struct ExpressionContext; +class ExpressionContext; class OperationContext; class Pipeline; class PlanExecutor; diff --git a/src/mongo/db/pipeline/pipeline_test.cpp b/src/mongo/db/pipeline/pipeline_test.cpp index 41cd6e567d3..dba575017ce 100644 --- a/src/mongo/db/pipeline/pipeline_test.cpp +++ b/src/mongo/db/pipeline/pipeline_test.cpp @@ -39,7 +39,7 @@ #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/document_source_mock.h" #include "mongo/db/pipeline/document_value_test_util.h" -#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/expression_context_for_test.h" #include "mongo/db/pipeline/field_path.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/query/collation/collator_interface_mock.h" @@ -86,7 +86,14 @@ public: rawPipeline.push_back(stageElem.embeddedObject()); } AggregationRequest request(NamespaceString("a.collection"), rawPipeline); - intrusive_ptr ctx = new ExpressionContext(&_opCtx, request); + intrusive_ptr ctx = + new ExpressionContextForTest(&_opCtx, request); + + // For $graphLookup and $lookup, we have to populate the resolvedNamespaces so that the + // operations will be able to have a resolved view definition. + NamespaceString lookupCollNs("a", "lookupColl"); + ctx->setResolvedNamespace(lookupCollNs, {lookupCollNs, std::vector{}}); + auto outputPipe = uassertStatusOK(Pipeline::parse(request.getPipeline(), ctx)); outputPipe->optimizePipeline(); @@ -240,17 +247,17 @@ class MoveMatchBeforeSort : public Base { class LookupShouldCoalesceWithUnwindOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same'}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same'}}" "]"; @@ -259,17 +266,17 @@ class LookupShouldCoalesceWithUnwindOnAs : public Base { class LookupShouldCoalesceWithUnwindOnAsWithPreserveEmpty : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', preserveNullAndEmptyArrays: true}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: true}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', preserveNullAndEmptyArrays: true}}" "]"; @@ -278,18 +285,18 @@ class LookupShouldCoalesceWithUnwindOnAsWithPreserveEmpty : public Base { class LookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', includeArrayIndex: 'index'}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false, includeArrayIndex: " "'index'}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', includeArrayIndex: 'index'}}" "]"; @@ -298,13 +305,13 @@ class LookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base { class LookupShouldNotCoalesceWithUnwindNotOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; } string outputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; @@ -313,51 +320,59 @@ class LookupShouldNotCoalesceWithUnwindNotOnAs : public Base { class LookupShouldSwapWithMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'independent': 0}}]"; } string outputPipeJson() { return "[{$match: {independent: 0}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}]"; + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}]"; } }; class LookupShouldSplitMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'independent': 0, asField: {$eq: 3}}}]"; } string outputPipeJson() { return "[{$match: {independent: {$eq: 0}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {asField: {$eq: 3}}}]"; } }; class LookupShouldNotAbsorbMatchOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'asField.subfield': 0}}]"; } string outputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$match: {'asField.subfield': 0}}]"; } }; class LookupShouldAbsorbUnwindMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " "{$unwind: '$asField'}, " "{$match: {'asField.subfield': {$eq: 1}}}]"; } string outputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z', " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: 'z', " " unwinding: {preserveNullAndEmptyArrays: false}, " " matching: {subfield: {$eq: 1}}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " "{$unwind: {path: '$asField'}}, " "{$match: {'asField.subfield': {$eq: 1}}}]"; } @@ -365,14 +380,15 @@ class LookupShouldAbsorbUnwindMatch : public Base { class LookupShouldAbsorbUnwindAndSplitAndAbsorbMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: '$asField'}, " " {$match: {'asField.subfield': {$eq: 1}, independentField: {$gt: 2}}}]"; } string outputPipeJson() { return "[{$match: {independentField: {$gt: 2}}}, " " {$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'asField', " " localField: 'y', " " foreignField: 'z', " @@ -386,7 +402,8 @@ class LookupShouldAbsorbUnwindAndSplitAndAbsorbMatch : public Base { } string serializedPipeJson() { return "[{$match: {independentField: {$gt: 2}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField'}}, " " {$match: {'asField.subfield': {$eq: 1}}}]"; } @@ -397,19 +414,21 @@ class LookupShouldNotSplitIndependentAndDependentOrClauses : public Base { // the $lookup, and if any child of the $or is independent of the 'asField', then the $match // cannot be absorbed by the $lookup. string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: '$asField'}, " " {$match: {$or: [{'independent': {$gt: 4}}, " " {'asField.dependent': {$elemMatch: {a: {$eq: 1}}}}]}}]"; } string outputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z', " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: 'z', " " unwinding: {preserveNullAndEmptyArrays: false}}}, " " {$match: {$or: [{'independent': {$gt: 4}}, " " {'asField.dependent': {$elemMatch: {a: {$eq: 1}}}}]}}]"; } string serializedPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField'}}, " " {$match: {$or: [{'independent': {$gt: 4}}, " " {'asField.dependent': {$elemMatch: {a: {$eq: 1}}}}]}}]"; @@ -418,14 +437,15 @@ class LookupShouldNotSplitIndependentAndDependentOrClauses : public Base { class LookupWithMatchOnArrayIndexFieldShouldNotCoalesce : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', includeArrayIndex: 'index'}}, " " {$match: {index: 0, 'asField.value': {$gt: 0}, independent: 1}}]"; } string outputPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " " {$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'asField', " " localField: 'y', " " foreignField: 'z', " @@ -438,7 +458,8 @@ class LookupWithMatchOnArrayIndexFieldShouldNotCoalesce : public Base { } string serializedPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', includeArrayIndex: 'index'}}, " " {$match: {$and: [{index: {$eq: 0}}, {'asField.value': {$gt: 0}}]}}]"; } @@ -446,14 +467,15 @@ class LookupWithMatchOnArrayIndexFieldShouldNotCoalesce : public Base { class LookupWithUnwindPreservingNullAndEmptyArraysShouldNotCoalesce : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', preserveNullAndEmptyArrays: true}}, " " {$match: {'asField.value': {$gt: 0}, independent: 1}}]"; } string outputPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " " {$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'asField', " " localField: 'y', " " foreignField: 'z', " @@ -465,7 +487,8 @@ class LookupWithUnwindPreservingNullAndEmptyArraysShouldNotCoalesce : public Bas } string serializedPipeJson() { return "[{$match: {independent: {$eq: 1}}}, " - " {$lookup: {from: 'foo', as: 'asField', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'asField', localField: 'y', foreignField: " + "'z'}}, " " {$unwind: {path: '$asField', preserveNullAndEmptyArrays: true}}, " " {$match: {'asField.value': {$gt: 0}}}]"; } @@ -473,13 +496,13 @@ class LookupWithUnwindPreservingNullAndEmptyArraysShouldNotCoalesce : public Bas class LookupDoesNotAbsorbElemMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$unwind: '$x'}, " " {$match: {x: {$elemMatch: {a: 1}}}}]"; } string outputPipeJson() { return "[{$lookup: { " - " from: 'foo', " + " from: 'lookupColl', " " as: 'x', " " localField: 'y', " " foreignField: 'z', " @@ -491,7 +514,7 @@ class LookupDoesNotAbsorbElemMatch : public Base { " {$match: {x: {$elemMatch: {a: 1}}}}]"; } string serializedPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$unwind: {path: '$x'}}, " " {$match: {x: {$elemMatch: {a: 1}}}}]"; } @@ -499,35 +522,35 @@ class LookupDoesNotAbsorbElemMatch : public Base { class LookupDoesSwapWithMatchOnLocalField : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$match: {y: {$eq: 3}}}]"; } string outputPipeJson() { return "[{$match: {y: {$eq: 3}}}, " - " {$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}]"; + " {$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}]"; } }; class LookupDoesSwapWithMatchOnFieldWithSameNameAsForeignField : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$match: {z: {$eq: 3}}}]"; } string outputPipeJson() { return "[{$match: {z: {$eq: 3}}}, " - " {$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}]"; + " {$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}]"; } }; class LookupDoesNotAbsorbUnwindOnSubfieldOfAsButStillMovesMatch : public Base { string inputPipeJson() { - return "[{$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + return "[{$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$unwind: {path: '$x.subfield'}}, " " {$match: {'independent': 2, 'x.dependent': 2}}]"; } string outputPipeJson() { return "[{$match: {'independent': {$eq: 2}}}, " - " {$lookup: {from: 'foo', as: 'x', localField: 'y', foreignField: 'z'}}, " + " {$lookup: {from: 'lookupColl', as: 'x', localField: 'y', foreignField: 'z'}}, " " {$match: {'x.dependent': {$eq: 2}}}, " " {$unwind: {path: '$x.subfield'}}]"; } @@ -639,58 +662,61 @@ class UnwindBeforeDoubleMatchShouldRepeatedlyOptimize : public Base { class GraphLookupShouldCoalesceWithUnwindOnAs : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: '$out'}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d', unwinding: {preserveNullAndEmptyArrays: " - "false}}}]"; + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d', " + " unwinding: {preserveNullAndEmptyArrays: false}}}]"; } string serializedPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out'}}]"; } }; class GraphLookupShouldCoalesceWithUnwindOnAsWithPreserveEmpty : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out', preserveNullAndEmptyArrays: true}}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d', unwinding: {preserveNullAndEmptyArrays: true}}}]"; + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d', " + " unwinding: {preserveNullAndEmptyArrays: true}}}]"; } string serializedPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out', preserveNullAndEmptyArrays: true}}]"; } }; class GraphLookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$out', includeArrayIndex: 'index'}}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d', unwinding: {preserveNullAndEmptyArrays: false, " - " includeArrayIndex: 'index'}}}]"; + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d', " + " unwinding: {preserveNullAndEmptyArrays: false, " + " includeArrayIndex: 'index'}}}]"; } string serializedPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', " " startWith: '$d'}}, " " {$unwind: {path: '$out', includeArrayIndex: 'index'}}]"; } @@ -698,14 +724,14 @@ class GraphLookupShouldCoalesceWithUnwindOnAsWithIncludeArrayIndex : public Base class GraphLookupShouldNotCoalesceWithUnwindNotOnAs : public Base { string inputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: '$nottherightthing'}]"; } string outputPipeJson() final { - return "[{$graphLookup: {from: 'a', as: 'out', connectToField: 'b', connectFromField: 'c', " - " startWith: '$d'}}, " + return "[{$graphLookup: {from: 'lookupColl', as: 'out', connectToField: 'b', " + " connectFromField: 'c', startWith: '$d'}}, " " {$unwind: {path: '$nottherightthing'}}]"; } }; @@ -713,7 +739,7 @@ class GraphLookupShouldNotCoalesceWithUnwindNotOnAs : public Base { class GraphLookupShouldSwapWithMatch : public Base { string inputPipeJson() { return "[{$graphLookup: {" - " from: 'coll2'," + " from: 'lookupColl'," " as: 'results'," " connectToField: 'to'," " connectFromField: 'from'," @@ -725,7 +751,7 @@ class GraphLookupShouldSwapWithMatch : public Base { string outputPipeJson() { return "[{$match: {independent: 'x'}}," " {$graphLookup: {" - " from: 'coll2'," + " from: 'lookupColl'," " as: 'results'," " connectToField: 'to'," " connectFromField: 'from'," @@ -863,7 +889,14 @@ public: rawPipeline.push_back(stageElem.embeddedObject()); } AggregationRequest request(NamespaceString("a.collection"), rawPipeline); - intrusive_ptr ctx = new ExpressionContext(&_opCtx, request); + intrusive_ptr ctx = + new ExpressionContextForTest(&_opCtx, request); + + // For $graphLookup and $lookup, we have to populate the resolvedNamespaces so that the + // operations will be able to have a resolved view definition. + NamespaceString lookupCollNs("a", "lookupColl"); + ctx->setResolvedNamespace(lookupCollNs, {lookupCollNs, std::vector{}}); + mergePipe = uassertStatusOK(Pipeline::parse(request.getPipeline(), ctx)); mergePipe->optimizePipeline(); @@ -1057,7 +1090,7 @@ namespace coalesceLookUpAndUnwind { class ShouldCoalesceUnwindOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same'}}" "]"; @@ -1066,14 +1099,14 @@ class ShouldCoalesceUnwindOnAs : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false}}}]"; } }; class ShouldCoalesceUnwindOnAsWithPreserveEmpty : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', preserveNullAndEmptyArrays: true}}" "]"; @@ -1082,14 +1115,14 @@ class ShouldCoalesceUnwindOnAsWithPreserveEmpty : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: true}}}]"; } }; class ShouldCoalesceUnwindOnAsWithIncludeArrayIndex : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$same', includeArrayIndex: 'index'}}" "]"; @@ -1098,7 +1131,7 @@ class ShouldCoalesceUnwindOnAsWithIncludeArrayIndex : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right', unwinding: {preserveNullAndEmptyArrays: false, includeArrayIndex: " "'index'}}}]"; } @@ -1106,7 +1139,7 @@ class ShouldCoalesceUnwindOnAsWithIncludeArrayIndex : public Base { class ShouldNotCoalesceUnwindNotOnAs : public Base { string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; @@ -1115,7 +1148,7 @@ class ShouldNotCoalesceUnwindNotOnAs : public Base { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}" ",{$unwind: {path: '$from'}}" "]"; @@ -1170,14 +1203,14 @@ class LookUp : public needsPrimaryShardMergerBase { return true; } string inputPipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}]"; } string shardPipeJson() { return "[]"; } string mergePipeJson() { - return "[{$lookup: {from : 'coll2', as : 'same', localField: 'left', foreignField: " + return "[{$lookup: {from : 'lookupColl', as : 'same', localField: 'left', foreignField: " "'right'}}]"; } }; @@ -1192,7 +1225,7 @@ TEST(PipelineInitialSource, GeoNearInitialQuery) { OperationContextNoop _opCtx; const std::vector rawPipeline = { fromjson("{$geoNear: {distanceField: 'd', near: [0, 0], query: {a: 1}}}")}; - intrusive_ptr ctx = new ExpressionContext( + intrusive_ptr ctx = new ExpressionContextForTest( &_opCtx, AggregationRequest(NamespaceString("a.collection"), rawPipeline)); auto pipe = uassertStatusOK(Pipeline::parse(rawPipeline, ctx)); ASSERT_BSONOBJ_EQ(pipe->getInitialQuery(), BSON("a" << 1)); @@ -1201,28 +1234,13 @@ TEST(PipelineInitialSource, GeoNearInitialQuery) { TEST(PipelineInitialSource, MatchInitialQuery) { OperationContextNoop _opCtx; const std::vector rawPipeline = {fromjson("{$match: {'a': 4}}")}; - intrusive_ptr ctx = new ExpressionContext( + intrusive_ptr ctx = new ExpressionContextForTest( &_opCtx, AggregationRequest(NamespaceString("a.collection"), rawPipeline)); auto pipe = uassertStatusOK(Pipeline::parse(rawPipeline, ctx)); ASSERT_BSONOBJ_EQ(pipe->getInitialQuery(), BSON("a" << 4)); } -TEST(PipelineInitialSource, ParseCollation) { - QueryTestServiceContext serviceContext; - auto opCtx = serviceContext.makeOperationContext(); - - const BSONObj inputBson = - fromjson("{pipeline: [{$match: {a: 'abc'}}], collation: {locale: 'reverse'}}"); - auto request = AggregationRequest::parseFromBSON(NamespaceString("a.collection"), inputBson); - ASSERT_OK(request.getStatus()); - - intrusive_ptr ctx = new ExpressionContext(opCtx.get(), request.getValue()); - ASSERT(ctx->getCollator()); - CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); - ASSERT_TRUE(CollatorInterface::collatorsMatch(ctx->getCollator(), &collator)); -} - namespace Dependencies { using PipelineDependenciesTest = AggregationContextFixture; -- cgit v1.2.1