diff options
Diffstat (limited to 'src/mongo/db/pipeline')
55 files changed, 1937 insertions, 403 deletions
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index 28d8aa0361f..f1a36c2bedd 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -37,7 +37,9 @@ env.Library( target='document_value', source=[ 'document.cpp', + 'document_comparator.cpp', 'value.cpp', + 'value_comparator.cpp', ], LIBDEPS=[ 'field_path', @@ -47,11 +49,29 @@ env.Library( ] ) +env.Library( + target='document_value_test_util', + source=[ + 'document_value_test_util.cpp', + ], + LIBDEPS=[ + '$BUILD_DIR/mongo/unittest/unittest', + 'document_value', + ], +) + env.CppUnitTest( target='document_value_test', - source='document_value_test.cpp', + source=[ + 'document_comparator_test.cpp', + 'document_value_test.cpp', + 'document_value_test_util_self_test.cpp', + 'value_comparator_test.cpp', + ], LIBDEPS=[ + '$BUILD_DIR/mongo/db/query/collation/collator_interface_mock', 'document_value', + 'document_value_test_util', ], ) @@ -75,7 +95,8 @@ env.CppUnitTest( target='aggregation_request_test', source='aggregation_request_test.cpp', LIBDEPS=[ - 'aggregation_request', + 'aggregation_request', + 'document_value_test_util', ], ) @@ -97,6 +118,7 @@ env.CppUnitTest( source='document_source_test.cpp', LIBDEPS=[ 'document_source', + 'document_value_test_util', '$BUILD_DIR/mongo/db/service_context', '$BUILD_DIR/mongo/util/clock_source_mock', '$BUILD_DIR/mongo/executor/thread_pool_task_executor', @@ -124,6 +146,7 @@ env.Library( LIBDEPS=[ 'dependencies', 'document_value', + 'expression_context', '$BUILD_DIR/mongo/util/summation', ] ) @@ -182,7 +205,6 @@ docSourceEnv.Library( 'dependencies', 'document_value', 'expression', - 'expression_context', 'parsed_aggregation_projection', '$BUILD_DIR/mongo/client/clientdriver', '$BUILD_DIR/mongo/db/bson/dotted_path_support', @@ -238,10 +260,11 @@ env.CppUnitTest( target='document_source_facet_test', source='document_source_facet_test.cpp', LIBDEPS=[ - 'document_source_facet', '$BUILD_DIR/mongo/db/auth/authorization_manager_mock_init', '$BUILD_DIR/mongo/db/query/query_test_service_context', '$BUILD_DIR/mongo/db/service_context_noop_init', + 'document_source_facet', + 'document_value_test_util', ], ) @@ -249,9 +272,10 @@ env.CppUnitTest( target='tee_buffer_test', source='tee_buffer_test.cpp', LIBDEPS=[ - 'document_source_facet', '$BUILD_DIR/mongo/db/auth/authorization_manager_mock_init', '$BUILD_DIR/mongo/db/service_context_noop_init', + 'document_source_facet', + 'document_value_test_util', ], ) @@ -260,6 +284,7 @@ env.CppUnitTest( source='expression_test.cpp', LIBDEPS=[ 'accumulator', + 'document_value_test_util', 'expression', ], ) @@ -269,6 +294,7 @@ env.CppUnitTest( source='accumulator_test.cpp', LIBDEPS=[ 'accumulator', + 'document_value_test_util', ], ) @@ -276,12 +302,13 @@ env.CppUnitTest( target='pipeline_test', source='pipeline_test.cpp', LIBDEPS=[ - 'pipeline', '$BUILD_DIR/mongo/db/auth/authorization_manager_mock_init', '$BUILD_DIR/mongo/db/query/collation/collator_interface_mock', '$BUILD_DIR/mongo/db/query/query_test_service_context', '$BUILD_DIR/mongo/db/service_context', '$BUILD_DIR/mongo/db/service_context_noop_init', + 'document_value_test_util', + 'pipeline', ], ) @@ -314,6 +341,7 @@ env.CppUnitTest( target='parsed_exclusion_projection_test', source='parsed_exclusion_projection_test.cpp', LIBDEPS=[ + 'document_value_test_util', 'parsed_aggregation_projection', ], ) @@ -330,6 +358,7 @@ env.CppUnitTest( target='parsed_inclusion_projection_test', source='parsed_inclusion_projection_test.cpp', LIBDEPS=[ + 'document_value_test_util', 'parsed_aggregation_projection', ], ) diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h index eb09459f942..81b84bf1982 100644 --- a/src/mongo/db/pipeline/accumulator.h +++ b/src/mongo/db/pipeline/accumulator.h @@ -31,12 +31,15 @@ #include "mongo/platform/basic.h" #include <boost/intrusive_ptr.hpp> +#include <boost/optional.hpp> #include <unordered_set> #include <vector> #include "mongo/base/init.h" #include "mongo/bson/bsontypes.h" +#include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/value.h" +#include "mongo/db/pipeline/value_comparator.h" #include "mongo/stdx/functional.h" #include "mongo/util/summation.h" @@ -108,12 +111,35 @@ public: return false; } + /** + * Injects the ExpressionContext so that it may be used during evaluation of the Accumulator. + * Construction of accumulators is done at parse time, but the ExpressionContext isn't finalized + * until later, at which point it is injected using this method. + */ + void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { + _expCtx = expCtx; + doInjectExpressionContext(); + } + protected: /// Update subclass's internal state based on input virtual void processInternal(const Value& input, bool merging) = 0; + /** + * Accumulators which need to update their internal state when attaching to a new + * ExpressionContext should override this method. + */ + virtual void doInjectExpressionContext() {} + + const boost::intrusive_ptr<ExpressionContext>& getExpressionContext() const { + return _expCtx; + } + /// subclasses are expected to update this as necessary int _memUsageBytes = 0; + +private: + boost::intrusive_ptr<ExpressionContext> _expCtx; }; @@ -136,9 +162,13 @@ public: return true; } + void doInjectExpressionContext() final; + private: - typedef std::unordered_set<Value, Value::Hash> SetType; - SetType set; + // We use boost::optional to defer initialization until the ExpressionContext containing the + // correct comparator is injected, since this set must use the comparator's definition of + // equality. + boost::optional<ValueUnorderedSet> _set; }; diff --git a/src/mongo/db/pipeline/accumulator_add_to_set.cpp b/src/mongo/db/pipeline/accumulator_add_to_set.cpp index 313624bfdf0..7c780fcc7ee 100644 --- a/src/mongo/db/pipeline/accumulator_add_to_set.cpp +++ b/src/mongo/db/pipeline/accumulator_add_to_set.cpp @@ -46,7 +46,7 @@ const char* AccumulatorAddToSet::getOpName() const { void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { if (!merging) { if (!input.missing()) { - bool inserted = set.insert(input).second; + bool inserted = _set->insert(input).second; if (inserted) { _memUsageBytes += input.getApproximateSize(); } @@ -60,7 +60,7 @@ void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { const vector<Value>& array = input.getArray(); for (size_t i = 0; i < array.size(); i++) { - bool inserted = set.insert(array[i]).second; + bool inserted = _set->insert(array[i]).second; if (inserted) { _memUsageBytes += array[i].getApproximateSize(); } @@ -69,7 +69,7 @@ void AccumulatorAddToSet::processInternal(const Value& input, bool merging) { } Value AccumulatorAddToSet::getValue(bool toBeMerged) const { - return Value(vector<Value>(set.begin(), set.end())); + return Value(vector<Value>(_set->begin(), _set->end())); } AccumulatorAddToSet::AccumulatorAddToSet() { @@ -77,11 +77,16 @@ AccumulatorAddToSet::AccumulatorAddToSet() { } void AccumulatorAddToSet::reset() { - SetType().swap(set); + _set = getExpressionContext()->getValueComparator().makeUnorderedValueSet(); _memUsageBytes = sizeof(*this); } intrusive_ptr<Accumulator> AccumulatorAddToSet::create() { return new AccumulatorAddToSet(); } + +void AccumulatorAddToSet::doInjectExpressionContext() { + _set = getExpressionContext()->getValueComparator().makeUnorderedValueSet(); } + +} // namespace mongo diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp index e354d71d8f9..8fad2b0095d 100644 --- a/src/mongo/db/pipeline/accumulator_test.cpp +++ b/src/mongo/db/pipeline/accumulator_test.cpp @@ -30,6 +30,7 @@ #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/expression_context.h" #include "mongo/dbtests/dbtests.h" @@ -57,7 +58,7 @@ static void assertExpectedResults( accum->process(val, false); } Value result = accum->getValue(false); - ASSERT_EQUALS(op.second, result); + ASSERT_VALUE_EQ(op.second, result); ASSERT_EQUALS(op.second.getType(), result.getType()); } @@ -70,7 +71,7 @@ static void assertExpectedResults( } accum->process(shard->getValue(true), true); Value result = accum->getValue(false); - ASSERT_EQUALS(op.second, result); + ASSERT_VALUE_EQ(op.second, result); ASSERT_EQUALS(op.second.getType(), result.getType()); } @@ -83,7 +84,7 @@ static void assertExpectedResults( accum->process(shard->getValue(true), true); } Value result = accum->getValue(false); - ASSERT_EQUALS(op.second, result); + ASSERT_VALUE_EQ(op.second, result); ASSERT_EQUALS(op.second.getType(), result.getType()); } } catch (...) { diff --git a/src/mongo/db/pipeline/aggregation_request_test.cpp b/src/mongo/db/pipeline/aggregation_request_test.cpp index e7dc1dd0c4e..b9fe685961f 100644 --- a/src/mongo/db/pipeline/aggregation_request_test.cpp +++ b/src/mongo/db/pipeline/aggregation_request_test.cpp @@ -36,6 +36,7 @@ #include "mongo/db/catalog/document_validation.h" #include "mongo/db/namespace_string.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" @@ -75,7 +76,7 @@ TEST(AggregationRequestTest, ShouldOnlySerializeRequiredFieldsIfNoOptionalFields auto expectedSerialization = Document{{AggregationRequest::kCommandName, nss.coll()}, {AggregationRequest::kPipelineName, Value(std::vector<Value>{})}}; - ASSERT_EQ(request.serializeToCommandObj(), expectedSerialization); + ASSERT_DOCUMENT_EQ(request.serializeToCommandObj(), expectedSerialization); } TEST(AggregationRequestTest, ShouldNotSerializeOptionalValuesIfEquivalentToDefault) { @@ -90,7 +91,7 @@ TEST(AggregationRequestTest, ShouldNotSerializeOptionalValuesIfEquivalentToDefau auto expectedSerialization = Document{{AggregationRequest::kCommandName, nss.coll()}, {AggregationRequest::kPipelineName, Value(std::vector<Value>{})}}; - ASSERT_EQ(request.serializeToCommandObj(), expectedSerialization); + ASSERT_DOCUMENT_EQ(request.serializeToCommandObj(), expectedSerialization); } TEST(AggregationRequestTest, ShouldSerializeOptionalValuesIfSet) { @@ -112,7 +113,7 @@ TEST(AggregationRequestTest, ShouldSerializeOptionalValuesIfSet) { {AggregationRequest::kFromRouterName, true}, {bypassDocumentValidationCommandOption(), true}, {AggregationRequest::kCollationName, collationObj}}; - ASSERT_EQ(request.serializeToCommandObj(), expectedSerialization); + ASSERT_DOCUMENT_EQ(request.serializeToCommandObj(), expectedSerialization); } TEST(AggregationRequestTest, ShouldSetBatchSizeToDefaultOnEmptyCursorObject) { diff --git a/src/mongo/db/pipeline/document.cpp b/src/mongo/db/pipeline/document.cpp index 7c1d04f1a9f..df2eb8facdf 100644 --- a/src/mongo/db/pipeline/document.cpp +++ b/src/mongo/db/pipeline/document.cpp @@ -381,7 +381,9 @@ void Document::hash_combine(size_t& seed) const { } } -int Document::compare(const Document& rL, const Document& rR) { +int Document::compare(const Document& rL, + const Document& rR, + const StringData::ComparatorInterface* stringComparator) { DocumentStorageIterator lIt = rL.storage().iterator(); DocumentStorageIterator rIt = rR.storage().iterator(); @@ -410,7 +412,7 @@ int Document::compare(const Document& rL, const Document& rR) { if (nameCmp) return nameCmp; // field names are unequal - const int valueCmp = Value::compare(lField.val, rField.val); + const int valueCmp = Value::compare(lField.val, rField.val, stringComparator); if (valueCmp) return valueCmp; // fields are unequal diff --git a/src/mongo/db/pipeline/document.h b/src/mongo/db/pipeline/document.h index 840c81315f0..0a0220964bd 100644 --- a/src/mongo/db/pipeline/document.h +++ b/src/mongo/db/pipeline/document.h @@ -33,6 +33,7 @@ #include <boost/functional/hash.hpp> #include <boost/intrusive_ptr.hpp> +#include "mongo/base/string_data.h" #include "mongo/bson/util/builder.h" namespace mongo { @@ -66,6 +67,28 @@ class Position; */ class Document { public: + /** + * Operator overloads for relops return a DeferredComparison which can subsequently be evaluated + * by a DocumentComparator. + */ + struct DeferredComparison { + enum class Type { + kLT, + kLTE, + kEQ, + kGT, + kGTE, + kNE, + }; + + DeferredComparison(Type type, const Document& lhs, const Document& rhs) + : type(type), lhs(lhs), rhs(rhs) {} + + Type type; + const Document& lhs; + const Document& rhs; + }; + /// Empty Document (does no allocation) Document() {} @@ -130,20 +153,29 @@ public: */ size_t getApproximateSize() const; - /** Compare two documents. + /** + * Compare two documents. Most callers should prefer using DocumentComparator instead. See + * document_comparator.h for details. * * BSON document field order is significant, so this just goes through * the fields in order. The comparison is done in roughly the same way * as strings are compared, but comparing one field at a time instead * of one character at a time. * + * Pass a non-null StringData::ComparatorInterface if special string comparison semantics are + * required. If the comparator is null, then a simple binary compare is used for strings. This + * comparator is only used for string *values*; field names are always compared using simple + * binary compare. + * * Note: This does not consider metadata when comparing documents. * * @returns an integer less than zero, zero, or an integer greater than * zero, depending on whether lhs < rhs, lhs == rhs, or lhs > rhs * Warning: may return values other than -1, 0, or 1 */ - static int compare(const Document& lhs, const Document& rhs); + static int compare(const Document& lhs, + const Document& rhs, + const StringData::ComparatorInterface* stringComparator = nullptr); std::string toString() const; @@ -246,25 +278,38 @@ private: boost::intrusive_ptr<const DocumentStorage> _storage; }; -inline bool operator==(const Document& l, const Document& r) { - return Document::compare(l, r) == 0; +// +// Comparison API. +// +// Document instances can be compared either using Document::compare() or via operator overloads. +// Most callers should prefer operator overloads. Note that the operator overloads return a +// DeferredComparison, which must be subsequently evaluated by a DocumentComparator. See +// document_comparator.h for details. +// + +inline Document::DeferredComparison operator==(const Document& lhs, const Document& rhs) { + return Document::DeferredComparison(Document::DeferredComparison::Type::kEQ, lhs, rhs); } -inline bool operator!=(const Document& l, const Document& r) { - return Document::compare(l, r) != 0; -} -inline bool operator<(const Document& l, const Document& r) { - return Document::compare(l, r) < 0; + +inline Document::DeferredComparison operator!=(const Document& lhs, const Document& rhs) { + return Document::DeferredComparison(Document::DeferredComparison::Type::kNE, lhs, rhs); } -inline bool operator<=(const Document& l, const Document& r) { - return Document::compare(l, r) <= 0; + +inline Document::DeferredComparison operator<(const Document& lhs, const Document& rhs) { + return Document::DeferredComparison(Document::DeferredComparison::Type::kLT, lhs, rhs); } -inline bool operator>(const Document& l, const Document& r) { - return Document::compare(l, r) > 0; + +inline Document::DeferredComparison operator<=(const Document& lhs, const Document& rhs) { + return Document::DeferredComparison(Document::DeferredComparison::Type::kLTE, lhs, rhs); } -inline bool operator>=(const Document& l, const Document& r) { - return Document::compare(l, r) >= 0; + +inline Document::DeferredComparison operator>(const Document& lhs, const Document& rhs) { + return Document::DeferredComparison(Document::DeferredComparison::Type::kGT, lhs, rhs); } +inline Document::DeferredComparison operator>=(const Document& lhs, const Document& rhs) { + return Document::DeferredComparison(Document::DeferredComparison::Type::kGTE, lhs, rhs); +} /** This class is returned by MutableDocument to allow you to modify its values. * You are not allowed to hold variables of this type (enforced by the type system). diff --git a/src/mongo/db/pipeline/document_comparator.cpp b/src/mongo/db/pipeline/document_comparator.cpp new file mode 100644 index 00000000000..a11eb092122 --- /dev/null +++ b/src/mongo/db/pipeline/document_comparator.cpp @@ -0,0 +1,57 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/pipeline/document_comparator.h" + +#include "mongo/util/assert_util.h" + +namespace mongo { + +bool DocumentComparator::evaluate(Document::DeferredComparison deferredComparison) const { + int cmp = Document::compare(deferredComparison.lhs, deferredComparison.rhs, _stringComparator); + switch (deferredComparison.type) { + case Document::DeferredComparison::Type::kLT: + return cmp < 0; + case Document::DeferredComparison::Type::kLTE: + return cmp <= 0; + case Document::DeferredComparison::Type::kEQ: + return cmp == 0; + case Document::DeferredComparison::Type::kGTE: + return cmp >= 0; + case Document::DeferredComparison::Type::kGT: + return cmp > 0; + case Document::DeferredComparison::Type::kNE: + return cmp != 0; + } + + MONGO_UNREACHABLE; +} + +} // namespace mongo diff --git a/src/mongo/db/pipeline/document_comparator.h b/src/mongo/db/pipeline/document_comparator.h new file mode 100644 index 00000000000..91c1e08da58 --- /dev/null +++ b/src/mongo/db/pipeline/document_comparator.h @@ -0,0 +1,59 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/base/string_data.h" +#include "mongo/db/pipeline/document.h" + +namespace mongo { + +class DocumentComparator { +public: + /** + * Constructs a document comparator with simple comparison semantics. + */ + DocumentComparator() = default; + + /** + * Constructs a document comparator with special string comparison semantics. + */ + DocumentComparator(const StringData::ComparatorInterface* stringComparator) + : _stringComparator(stringComparator) {} + + /** + * Evaluates a deferred comparison object that was generated by invoking one of the comparison + * operators on the Document class. + */ + bool evaluate(Document::DeferredComparison deferredComparison) const; + +private: + const StringData::ComparatorInterface* _stringComparator = nullptr; +}; + +} // namespace mongo diff --git a/src/mongo/db/pipeline/document_comparator_test.cpp b/src/mongo/db/pipeline/document_comparator_test.cpp new file mode 100644 index 00000000000..40229e982c1 --- /dev/null +++ b/src/mongo/db/pipeline/document_comparator_test.cpp @@ -0,0 +1,169 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/pipeline/document_comparator.h" + +#include "mongo/db/query/collation/collator_interface_mock.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { +namespace { + +TEST(DocumentComparatorTest, EqualToEvaluatesCorrectly) { + const Document doc1{{"foo", "bar"}}; + const Document doc2{{"foo", "bar"}}; + const Document doc3{{"foo", "baz"}}; + ASSERT_TRUE(DocumentComparator().evaluate(doc1 == doc2)); + ASSERT_FALSE(DocumentComparator().evaluate(doc1 == doc3)); +} + +TEST(DocumentComparatorTest, EqualToEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual); + const Document doc1{{"foo", "abc"}}; + const Document doc2{{"foo", "def"}}; + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 == doc2)); +} + +TEST(DocumentComparatorTest, NotEqualEvaluatesCorrectly) { + const Document doc1{{"foo", "bar"}}; + const Document doc2{{"foo", "bar"}}; + const Document doc3{{"foo", "baz"}}; + ASSERT_FALSE(DocumentComparator().evaluate(doc1 != doc2)); + ASSERT_TRUE(DocumentComparator().evaluate(doc1 != doc3)); +} + +TEST(DocumentComparatorTest, NotEqualEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual); + const Document doc1{{"foo", "abc"}}; + const Document doc2{{"foo", "def"}}; + ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc1 != doc2)); +} + +TEST(DocumentComparatorTest, LessThanEvaluatesCorrectly) { + const Document doc1{{"foo", "a"}}; + const Document doc2{{"foo", "b"}}; + ASSERT_TRUE(DocumentComparator().evaluate(doc1 < doc2)); + ASSERT_FALSE(DocumentComparator().evaluate(doc2 < doc1)); +} + +TEST(DocumentComparatorTest, LessThanEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); + const Document doc1{{"foo", "za"}}; + const Document doc2{{"foo", "yb"}}; + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 < doc2)); + ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc2 < doc1)); +} + +TEST(DocumentComparatorTest, LessThanOrEqualEvaluatesCorrectly) { + const Document doc1{{"foo", "a"}}; + const Document doc2{{"foo", "a"}}; + const Document doc3{{"foo", "b"}}; + ASSERT_TRUE(DocumentComparator().evaluate(doc1 <= doc2)); + ASSERT_TRUE(DocumentComparator().evaluate(doc2 <= doc1)); + ASSERT_TRUE(DocumentComparator().evaluate(doc1 <= doc3)); + ASSERT_FALSE(DocumentComparator().evaluate(doc3 <= doc1)); +} + +TEST(DocumentComparatorTest, LessThanOrEqualEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); + const Document doc1{{"foo", "za"}}; + const Document doc2{{"foo", "za"}}; + const Document doc3{{"foo", "yb"}}; + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 <= doc2)); + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc2 <= doc1)); + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 <= doc3)); + ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc3 <= doc1)); +} + +TEST(DocumentComparatorTest, GreaterThanEvaluatesCorrectly) { + const Document doc1{{"foo", "b"}}; + const Document doc2{{"foo", "a"}}; + ASSERT_TRUE(DocumentComparator().evaluate(doc1 > doc2)); + ASSERT_FALSE(DocumentComparator().evaluate(doc2 > doc1)); +} + +TEST(DocumentComparatorTest, GreaterThanEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); + const Document doc1{{"foo", "yb"}}; + const Document doc2{{"foo", "za"}}; + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 > doc2)); + ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc2 > doc1)); +} + +TEST(DocumentComparatorTest, GreaterThanOrEqualEvaluatesCorrectly) { + const Document doc1{{"foo", "b"}}; + const Document doc2{{"foo", "b"}}; + const Document doc3{{"foo", "a"}}; + ASSERT_TRUE(DocumentComparator().evaluate(doc1 >= doc2)); + ASSERT_TRUE(DocumentComparator().evaluate(doc2 >= doc1)); + ASSERT_TRUE(DocumentComparator().evaluate(doc1 >= doc3)); + ASSERT_FALSE(DocumentComparator().evaluate(doc3 >= doc1)); +} + +TEST(DocumentComparatorTest, GreaterThanOrEqualEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); + const Document doc1{{"foo", "yb"}}; + const Document doc2{{"foo", "yb"}}; + const Document doc3{{"foo", "za"}}; + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 >= doc2)); + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc2 >= doc1)); + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 >= doc3)); + ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc3 >= doc1)); +} + +TEST(DocumentComparatorTest, EqualToEvaluatesCorrectlyWithNumbers) { + const Document doc1{{"foo", 88}}; + const Document doc2{{"foo", 88}}; + const Document doc3{{"foo", 99}}; + ASSERT_TRUE(DocumentComparator().evaluate(doc1 == doc2)); + ASSERT_FALSE(DocumentComparator().evaluate(doc1 == doc3)); +} + +TEST(DocumentComparatorTest, NestedObjectEqualityRespectsCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual); + const Document doc1{{"foo", Document{{"foo", "abc"}}}}; + const Document doc2{{"foo", Document{{"foo", "def"}}}}; + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 == doc2)); + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc2 == doc1)); +} + +TEST(DocumentComparatorTest, NestedArrayEqualityRespectsCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual); + const Document doc1{{"foo", std::vector<Value>{Value("a"), Value("b")}}}; + const Document doc2{{"foo", std::vector<Value>{Value("c"), Value("d")}}}; + const Document doc3{{"foo", std::vector<Value>{Value("c"), Value("d"), Value("e")}}}; + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 == doc2)); + ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc2 == doc1)); + ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc1 == doc3)); + ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc3 == doc1)); +} + +} // namespace +} // namespace mongo diff --git a/src/mongo/db/pipeline/document_source.h b/src/mongo/db/pipeline/document_source.h index 0ffded07edb..f6c8aa91545 100644 --- a/src/mongo/db/pipeline/document_source.h +++ b/src/mongo/db/pipeline/document_source.h @@ -30,6 +30,7 @@ #include "mongo/platform/basic.h" +#include <boost/optional.hpp> #include <deque> #include <list> #include <string> @@ -52,6 +53,7 @@ #include "mongo/db/pipeline/parsed_aggregation_projection.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/pipeline/value.h" +#include "mongo/db/pipeline/value_comparator.h" #include "mongo/db/query/plan_summary_stats.h" #include "mongo/db/sorter/sorter.h" #include "mongo/stdx/functional.h" @@ -230,6 +232,18 @@ public: virtual void reattachToOperationContext(OperationContext* opCtx) {} /** + * Injects a new ExpressionContext into this DocumentSource and propagates the ExpressionContext + * to all child expressions, accumulators, etc. + * + * Stages which require work to propagate the ExpressionContext to their private execution + * machinery should override doInjectExpressionContext(). + */ + void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { + pExpCtx = expCtx; + doInjectExpressionContext(); + } + + /** * Create a DocumentSource pipeline stage from 'stageObj'. */ static std::vector<boost::intrusive_ptr<DocumentSource>> parse( @@ -264,6 +278,15 @@ protected: */ explicit DocumentSource(const boost::intrusive_ptr<ExpressionContext>& pExpCtx); + /** + * DocumentSources which need to update their internal state when attaching to a new + * ExpressionContext should override this method. + * + * Any stage subclassing from DocumentSource should override this method if it contains + * expressions or accumulators which need to attach to the newly injected ExpressionContext. + */ + virtual void doInjectExpressionContext() {} + /* Most DocumentSources have an underlying source they get their data from. This is a convenience for them. @@ -491,6 +514,9 @@ public: const PlanSummaryStats& getPlanSummaryStats() const; +protected: + void doInjectExpressionContext() final; + private: DocumentSourceCursor(const std::string& ns, const std::shared_ptr<PlanExecutor>& exec, @@ -524,7 +550,7 @@ private: class DocumentSourceGroup final : public DocumentSource, public SplittableDocumentSource { public: using Accumulators = std::vector<boost::intrusive_ptr<Accumulator>>; - using GroupsMap = std::unordered_map<Value, Accumulators, Value::Hash>; + using GroupsMap = ValueUnorderedMap<Accumulators>; // Virtuals from DocumentSource. boost::intrusive_ptr<DocumentSource> optimize() final; @@ -573,6 +599,9 @@ public: boost::intrusive_ptr<DocumentSource> getShardSource() final; boost::intrusive_ptr<DocumentSource> getMergeSource() final; +protected: + void doInjectExpressionContext() final; + private: explicit DocumentSourceGroup(const boost::intrusive_ptr<ExpressionContext>& pExpCtx); @@ -647,7 +676,10 @@ private: Value _currentId; Accumulators _currentAccumulators; - GroupsMap groups; + // We use boost::optional to defer initialization until the ExpressionContext containing the + // correct comparator is injected, since the groups must be built using the comparator's + // definition of equality. + boost::optional<GroupsMap> _groups; bool _spilled; @@ -789,6 +821,8 @@ public: const std::string& path, boost::intrusive_ptr<ExpressionContext> expCtx); + void doInjectExpressionContext(); + private: DocumentSourceMatch(const BSONObj& query, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); @@ -914,11 +948,17 @@ public: return this; } + void doInjectExpressionContext() override { + isExpCtxInjected = true; + } + // Return documents from front of queue. std::deque<Document> queue; + bool isDisposed = false; bool isDetachedFromOpCtx = false; bool isOptimized = false; + bool isExpCtxInjected = false; BSONObjSet sorts; }; @@ -1001,6 +1041,8 @@ public: */ boost::intrusive_ptr<DocumentSource> optimize() final; + void doInjectExpressionContext() final; + /** * Parse the projection from the user-supplied BSON. */ @@ -1028,6 +1070,8 @@ public: Pipeline::SourceContainer::iterator optimizeAt(Pipeline::SourceContainer::iterator itr, Pipeline::SourceContainer* container) final; + void doInjectExpressionContext() final; + static boost::intrusive_ptr<DocumentSource> createFromBson( BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx); @@ -1063,6 +1107,8 @@ public: return _size; } + void doInjectExpressionContext() final; + static boost::intrusive_ptr<DocumentSource> createFromBson( BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx); @@ -1086,6 +1132,8 @@ public: Value serialize(bool explain = false) const final; GetDepsReturn getDependencies(DepsTracker* deps) const final; + void doInjectExpressionContext() final; + static boost::intrusive_ptr<DocumentSourceSampleFromRandomCursor> create( const boost::intrusive_ptr<ExpressionContext>& expCtx, long long size, @@ -1111,8 +1159,9 @@ private: std::string _idField; // Keeps track of the documents that have been returned, since a random cursor is allowed to - // return duplicates. - ValueSet _seenDocs; + // return duplicates. We use boost::optional to defer initialization until the ExpressionContext + // containing the correct comparator is injected. + boost::optional<ValueUnorderedSet> _seenDocs; // The approximate number of documents in the collection (includes orphans). const long long _nDocsInColl; @@ -1188,6 +1237,7 @@ private: long long count; }; +// TODO SERVER-23349: Make aggregation sort respect the collation. class DocumentSourceSort final : public DocumentSource, public SplittableDocumentSource { public: // virtuals from DocumentSource @@ -1473,6 +1523,7 @@ private: std::unique_ptr<Unwinder> _unwinder; }; +// TODO SERVER-23349: Make geoNear agg stage respect the collation. class DocumentSourceGeoNear : public DocumentSourceNeedsMongod, public SplittableDocumentSource { public: static const long long kDefaultLimit; @@ -1543,6 +1594,8 @@ private: /** * Queries separate collection for equality matches with documents in the pipeline collection. * Adds matching documents to a new array field in the input document. + * + * TODO SERVER-23349: Make $lookup respect the collation. */ class DocumentSourceLookUp final : public DocumentSourceNeedsMongod, public SplittableDocumentSource { @@ -1620,6 +1673,7 @@ private: boost::optional<Document> _input; }; +// TODO SERVER-23349: Make $graphLookup respect the collation. class DocumentSourceGraphLookUp final : public DocumentSourceNeedsMongod { public: boost::optional<Document> getNext() final; @@ -1647,6 +1701,8 @@ public: collections->push_back(_from); } + void doInjectExpressionContext() final; + static boost::intrusive_ptr<DocumentSource> createFromBson( BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx); @@ -1698,7 +1754,7 @@ private: * Updates '_cache' with 'result' appropriately, given that 'result' was retrieved when querying * for 'queried'. */ - void addToCache(const BSONObj& result, const unordered_set<Value, Value::Hash>& queried); + void addToCache(const BSONObj& result, const ValueUnorderedSet& queried); /** * Assert that '_visited' and '_frontier' have not exceeded the maximum meory usage, and then @@ -1731,11 +1787,15 @@ private: size_t _frontierUsageBytes = 0; // Only used during the breadth-first search, tracks the set of values on the current frontier. - std::unordered_set<Value, Value::Hash> _frontier; + // We use boost::optional to defer initialization until the ExpressionContext containing the + // correct comparator is injected. + boost::optional<ValueUnorderedSet> _frontier; // Tracks nodes that have been discovered for a given input. Keys are the '_id' value of the - // document from the foreign collection, value is the document itself. - std::unordered_map<Value, BSONObj, Value::Hash> _visited; + // document from the foreign collection, value is the document itself. We use boost::optional + // to defer initialization until the ExpressionContext containing the correct comparator is + // injected. + boost::optional<ValueUnorderedMap<BSONObj>> _visited; // Caches query results to avoid repeating any work. This structure is maintained across calls // to getNext(). diff --git a/src/mongo/db/pipeline/document_source_bucket.cpp b/src/mongo/db/pipeline/document_source_bucket.cpp index 6018402113b..d895b532156 100644 --- a/src/mongo/db/pipeline/document_source_bucket.cpp +++ b/src/mongo/db/pipeline/document_source_bucket.cpp @@ -121,6 +121,8 @@ vector<intrusive_ptr<DocumentSource>> DocumentSourceBucket::createFromBson( << typeName(upper.getType()) << ".", lowerCanonicalType == upperCanonicalType); + // TODO SERVER-25038: This check must be deferred so that it respects the final + // collator, which is not necessarily the same as the collator at parse time. uassert(40194, str::stream() << "The 'boundaries' option to $bucket must be sorted, but elements " @@ -132,7 +134,7 @@ vector<intrusive_ptr<DocumentSource>> DocumentSourceBucket::createFromBson( << " is not less than " << upper.toString() << ").", - lower < upper); + pExpCtx->getValueComparator().evaluate(lower < upper)); } } else if ("default" == argName) { // If there is a default, make sure that it parses to a constant expression then add @@ -173,11 +175,15 @@ vector<intrusive_ptr<DocumentSource>> DocumentSourceBucket::createFromBson( Value upperValue = boundaryValues.back(); if (canonicalizeBSONType(defaultValue.getType()) == canonicalizeBSONType(lowerValue.getType())) { - // If the default has the same canonical type as the bucket's boundaries, then make sure - // the default is less than the lowest boundary or greater than or equal to the highest + // If the default has the same canonical type as the bucket's boundaries, then make sure the + // default is less than the lowest boundary or greater than or equal to the highest // boundary. - const bool hasValidDefault = - defaultValue < lowerValue || Value::compare(defaultValue, upperValue) >= 0; + // + // TODO SERVER-25038: This check must be deferred so that it respects the final collator, + // which is not necessarily the same as the collator at parse time. + const auto& valueCmp = pExpCtx->getValueComparator(); + const bool hasValidDefault = valueCmp.evaluate(defaultValue < lowerValue) || + valueCmp.evaluate(defaultValue >= upperValue); uassert(40199, "The $bucket 'default' field must be less than the lowest boundary or greater than " "or equal to the highest boundary.", diff --git a/src/mongo/db/pipeline/document_source_cursor.cpp b/src/mongo/db/pipeline/document_source_cursor.cpp index c9a0d0673e3..aef76129244 100644 --- a/src/mongo/db/pipeline/document_source_cursor.cpp +++ b/src/mongo/db/pipeline/document_source_cursor.cpp @@ -221,6 +221,12 @@ Value DocumentSourceCursor::serialize(bool explain) const { return Value(DOC(getSourceName() << out.freezeToValue())); } +void DocumentSourceCursor::doInjectExpressionContext() { + if (_limit) { + _limit->injectExpressionContext(pExpCtx); + } +} + DocumentSourceCursor::DocumentSourceCursor(const string& ns, const std::shared_ptr<PlanExecutor>& exec, const intrusive_ptr<ExpressionContext>& pCtx) @@ -239,7 +245,9 @@ intrusive_ptr<DocumentSourceCursor> DocumentSourceCursor::create( const string& ns, const std::shared_ptr<PlanExecutor>& exec, const intrusive_ptr<ExpressionContext>& pExpCtx) { - return new DocumentSourceCursor(ns, exec, pExpCtx); + intrusive_ptr<DocumentSourceCursor> source(new DocumentSourceCursor(ns, exec, pExpCtx)); + source->injectExpressionContext(pExpCtx); + return source; } void DocumentSourceCursor::setProjection(const BSONObj& projection, diff --git a/src/mongo/db/pipeline/document_source_facet.cpp b/src/mongo/db/pipeline/document_source_facet.cpp index 45c24a1480b..0c3e2981d79 100644 --- a/src/mongo/db/pipeline/document_source_facet.cpp +++ b/src/mongo/db/pipeline/document_source_facet.cpp @@ -122,6 +122,12 @@ intrusive_ptr<DocumentSource> DocumentSourceFacet::optimize() { return this; } +void DocumentSourceFacet::doInjectExpressionContext() { + for (auto&& facet : _facetPipelines) { + facet.second->injectExpressionContext(pExpCtx); + } +} + void DocumentSourceFacet::doInjectMongodInterface(std::shared_ptr<MongodInterface> mongod) { for (auto&& facet : _facetPipelines) { for (auto&& stage : facet.second->getSources()) { diff --git a/src/mongo/db/pipeline/document_source_facet.h b/src/mongo/db/pipeline/document_source_facet.h index 583c09ea4b5..a6ac631c327 100644 --- a/src/mongo/db/pipeline/document_source_facet.h +++ b/src/mongo/db/pipeline/document_source_facet.h @@ -77,6 +77,11 @@ public: boost::intrusive_ptr<DocumentSource> optimize() final; /** + * Injects the expression context into inner pipelines. + */ + void doInjectExpressionContext() final; + + /** * Takes a union of all sub-pipelines, and adds them to 'deps'. */ GetDepsReturn getDependencies(DepsTracker* deps) const final; diff --git a/src/mongo/db/pipeline/document_source_facet_test.cpp b/src/mongo/db/pipeline/document_source_facet_test.cpp index 7deb8127949..519b8d0a222 100644 --- a/src/mongo/db/pipeline/document_source_facet_test.cpp +++ b/src/mongo/db/pipeline/document_source_facet_test.cpp @@ -38,6 +38,7 @@ #include "mongo/bson/json.h" #include "mongo/db/pipeline/aggregation_context_fixture.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/util/assert_util.h" namespace mongo { @@ -194,7 +195,7 @@ TEST_F(DocumentSourceFacetTest, SingleFacetShouldReceiveAllDocuments) { auto output = facetStage->getNext(); ASSERT_TRUE(output); - ASSERT_EQ(*output, Document(fromjson("{results: [{_id: 0}, {_id: 1}]}"))); + ASSERT_DOCUMENT_EQ(*output, Document(fromjson("{results: [{_id: 0}, {_id: 1}]}"))); // Should be exhausted now. ASSERT_FALSE(facetStage->getNext()); @@ -224,8 +225,8 @@ TEST_F(DocumentSourceFacetTest, MultipleFacetsShouldSeeTheSameDocuments) { std::vector<Value> expectedOutputs(inputs.begin(), inputs.end()); ASSERT_TRUE(output); ASSERT_EQ((*output).size(), 2UL); - ASSERT_EQ((*output)["first"], Value(expectedOutputs)); - ASSERT_EQ((*output)["second"], Value(expectedOutputs)); + ASSERT_VALUE_EQ((*output)["first"], Value(expectedOutputs)); + ASSERT_VALUE_EQ((*output)["second"], Value(expectedOutputs)); // Should be exhausted now. ASSERT_FALSE(facetStage->getNext()); @@ -248,7 +249,7 @@ TEST_F(DocumentSourceFacetTest, ShouldBeAbleToEvaluateMultipleStagesWithinOneSub auto output = facetStage->getNext(); ASSERT_TRUE(output); - ASSERT_EQ(*output, Document(fromjson("{subPipe: [{_id: 0}, {_id: 1}]}"))); + ASSERT_DOCUMENT_EQ(*output, Document(fromjson("{subPipe: [{_id: 0}, {_id: 1}]}"))); } // @@ -287,10 +288,10 @@ TEST_F(DocumentSourceFacetTest, ShouldBeAbleToReParseSerializedStage) { // Should have two fields: "skippedOne" and "skippedTwo". auto serializedStage = serialization[0].getDocument()["$facet"].getDocument(); ASSERT_EQ(serializedStage.size(), 2UL); - ASSERT_EQ(serializedStage["skippedOne"], - Value(std::vector<Value>{Value(Document{{"$skip", 1}})})); - ASSERT_EQ(serializedStage["skippedTwo"], - Value(std::vector<Value>{Value(Document{{"$skip", 2}})})); + ASSERT_VALUE_EQ(serializedStage["skippedOne"], + Value(std::vector<Value>{Value(Document{{"$skip", 1}})})); + ASSERT_VALUE_EQ(serializedStage["skippedTwo"], + Value(std::vector<Value>{Value(Document{{"$skip", 2}})})); auto serializedBson = serialization[0].getDocument().toBson(); auto roundTripped = DocumentSourceFacet::createFromBson(serializedBson.firstElement(), ctx); @@ -300,7 +301,7 @@ TEST_F(DocumentSourceFacetTest, ShouldBeAbleToReParseSerializedStage) { roundTripped->serializeToArray(newSerialization); ASSERT_EQ(newSerialization.size(), 1UL); - ASSERT_EQ(newSerialization[0], serialization[0]); + ASSERT_VALUE_EQ(newSerialization[0], serialization[0]); } TEST_F(DocumentSourceFacetTest, ShouldOptimizeInnerPipelines) { diff --git a/src/mongo/db/pipeline/document_source_geo_near.cpp b/src/mongo/db/pipeline/document_source_geo_near.cpp index 5f4e6ae0552..0107b41087c 100644 --- a/src/mongo/db/pipeline/document_source_geo_near.cpp +++ b/src/mongo/db/pipeline/document_source_geo_near.cpp @@ -165,7 +165,9 @@ void DocumentSourceGeoNear::runCommand() { intrusive_ptr<DocumentSourceGeoNear> DocumentSourceGeoNear::create( const intrusive_ptr<ExpressionContext>& pCtx) { - return new DocumentSourceGeoNear(pCtx); + intrusive_ptr<DocumentSourceGeoNear> source(new DocumentSourceGeoNear(pCtx)); + source->injectExpressionContext(pCtx); + return source; } intrusive_ptr<DocumentSource> DocumentSourceGeoNear::createFromBson( diff --git a/src/mongo/db/pipeline/document_source_graph_lookup.cpp b/src/mongo/db/pipeline/document_source_graph_lookup.cpp index 0c5ce1d4c21..01887184512 100644 --- a/src/mongo/db/pipeline/document_source_graph_lookup.cpp +++ b/src/mongo/db/pipeline/document_source_graph_lookup.cpp @@ -73,11 +73,11 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNext() { performSearch(); std::vector<Value> results; - while (!_visited.empty()) { + while (!_visited->empty()) { // Remove elements one at a time to avoid consuming more memory. - auto it = _visited.begin(); + auto it = _visited->begin(); results.push_back(Value(it->second)); - _visited.erase(it); + _visited->erase(it); } MutableDocument output(*_input); @@ -85,7 +85,7 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNext() { _visitedUsageBytes = 0; - invariant(_visited.empty()); + invariant(_visited->empty()); return output.freeze(); } @@ -96,7 +96,7 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNextUnwound() { // If the unwind is not preserving empty arrays, we might have to process multiple inputs before // we get one that will produce an output. while (true) { - if (_visited.empty()) { + if (_visited->empty()) { // No results are left for the current input, so we should move on to the next one and // perform a new search. if (!(_input = pSource->getNext())) { @@ -110,7 +110,7 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNextUnwound() { } MutableDocument unwound(*_input); - if (_visited.empty()) { + if (_visited->empty()) { if ((*_unwind)->preserveNullAndEmptyArrays()) { // Since "preserveNullAndEmptyArrays" was specified, output a document even though // we had no result. @@ -124,13 +124,13 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNextUnwound() { continue; } } else { - auto it = _visited.begin(); + auto it = _visited->begin(); unwound.setNestedField(_as, Value(it->second)); if (indexPath) { unwound.setNestedField(*indexPath, Value(_outputIndex)); ++_outputIndex; } - _visited.erase(it); + _visited->erase(it); } return unwound.freeze(); @@ -139,8 +139,8 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNextUnwound() { void DocumentSourceGraphLookUp::dispose() { _cache.clear(); - _frontier.clear(); - _visited.clear(); + _frontier->clear(); + _visited->clear(); pSource->dispose(); } @@ -154,8 +154,8 @@ void DocumentSourceGraphLookUp::doBreadthFirstSearch() { BSONObjSet cached; auto query = constructQuery(&cached); - std::unordered_set<Value, Value::Hash> queried; - _frontier.swap(queried); + ValueUnorderedSet queried = pExpCtx->getValueComparator().makeUnorderedValueSet(); + _frontier->swap(queried); _frontierUsageBytes = 0; // Process cached values, populating '_frontier' for the next iteration of search. @@ -186,7 +186,7 @@ void DocumentSourceGraphLookUp::doBreadthFirstSearch() { } while (shouldPerformAnotherQuery && depth < std::numeric_limits<long long>::max() && (!_maxDepth || depth <= *_maxDepth)); - _frontier.clear(); + _frontier->clear(); _frontierUsageBytes = 0; } @@ -204,7 +204,7 @@ BSONObj addDepthFieldToObject(const std::string& field, long long depth, BSONObj bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long long depth) { Value _id = Value(result.getField("_id")); - if (_visited.find(_id) != _visited.end()) { + if (_visited->find(_id) != _visited->end()) { // We've already seen this object, don't repeat any work. return false; } @@ -215,7 +215,7 @@ bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long lon _depthField ? addDepthFieldToObject(_depthField->fullPath(), depth, result) : result; // Add the object to our '_visited' list. - _visited[_id] = fullObject; + (*_visited)[_id] = fullObject; // Update the size of '_visited' appropriately. _visitedUsageBytes += _id.getApproximateSize(); @@ -231,12 +231,12 @@ bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long lon Value recurseOn = Value(elem); if (recurseOn.isArray()) { for (auto&& subElem : recurseOn.getArray()) { - _frontier.insert(subElem); + _frontier->insert(subElem); _frontierUsageBytes += subElem.getApproximateSize(); } } else if (!recurseOn.missing()) { // Don't recurse on a missing value. - _frontier.insert(recurseOn); + _frontier->insert(recurseOn); _frontierUsageBytes += recurseOn.getApproximateSize(); } } @@ -246,7 +246,7 @@ bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long lon } void DocumentSourceGraphLookUp::addToCache(const BSONObj& result, - const unordered_set<Value, Value::Hash>& queried) { + const ValueUnorderedSet& queried) { BSONElementSet cacheByValues; dps::extractAllElementsAlongPath(result, _connectToField.fullPath(), cacheByValues); @@ -271,13 +271,13 @@ void DocumentSourceGraphLookUp::addToCache(const BSONObj& result, boost::optional<BSONObj> DocumentSourceGraphLookUp::constructQuery(BSONObjSet* cached) { // Add any cached values to 'cached' and remove them from '_frontier'. - for (auto it = _frontier.begin(); it != _frontier.end();) { + for (auto it = _frontier->begin(); it != _frontier->end();) { if (auto entry = _cache[*it]) { for (auto&& obj : *entry) { cached->insert(obj); } size_t valueSize = it->getApproximateSize(); - it = _frontier.erase(it); + it = _frontier->erase(it); // If the cached value increased in size while in the cache, we don't want to underflow // '_frontierUsageBytes'. @@ -302,7 +302,7 @@ boost::optional<BSONObj> DocumentSourceGraphLookUp::constructQuery(BSONObjSet* c BSONObjBuilder subObj(connectToObj.subobjStart(_connectToField.fullPath())); { BSONArrayBuilder in(subObj.subarrayStart("$in")); - for (auto&& value : _frontier) { + for (auto&& value : *_frontier) { in << value; } } @@ -310,7 +310,7 @@ boost::optional<BSONObj> DocumentSourceGraphLookUp::constructQuery(BSONObjSet* c } } - return _frontier.empty() ? boost::none : boost::optional<BSONObj>(query.obj()); + return _frontier->empty() ? boost::none : boost::optional<BSONObj>(query.obj()); } void DocumentSourceGraphLookUp::performSearch() { @@ -324,11 +324,11 @@ void DocumentSourceGraphLookUp::performSearch() { // If _startWith evaluates to an array, treat each value as a separate starting point. if (startingValue.isArray()) { for (auto value : startingValue.getArray()) { - _frontier.insert(value); + _frontier->insert(value); _frontierUsageBytes += value.getApproximateSize(); } } else { - _frontier.insert(startingValue); + _frontier->insert(startingValue); _frontierUsageBytes += startingValue.getApproximateSize(); } @@ -410,6 +410,11 @@ void DocumentSourceGraphLookUp::serializeToArray(std::vector<Value>& array, bool } } +void DocumentSourceGraphLookUp::doInjectExpressionContext() { + _frontier = pExpCtx->getValueComparator().makeUnorderedValueSet(); + _visited = pExpCtx->getValueComparator().makeUnorderedValueMap<BSONObj>(); +} + DocumentSourceGraphLookUp::DocumentSourceGraphLookUp( NamespaceString from, std::string as, diff --git a/src/mongo/db/pipeline/document_source_group.cpp b/src/mongo/db/pipeline/document_source_group.cpp index a356986671e..ed7ab95e575 100644 --- a/src/mongo/db/pipeline/document_source_group.cpp +++ b/src/mongo/db/pipeline/document_source_group.cpp @@ -75,7 +75,7 @@ boost::optional<Document> DocumentSourceGroup::getNextSpilled() { _currentId = _firstPartOfNextGroup.first; const size_t numAccumulators = vpAccumulatorFactory.size(); - while (_currentId == _firstPartOfNextGroup.first) { + while (pExpCtx->getValueComparator().evaluate(_currentId == _firstPartOfNextGroup.first)) { // Inside of this loop, _firstPartOfNextGroup is the current data being processed. // At loop exit, it is the first value to be processed in the next group. switch (numAccumulators) { // mirrors switch in spill() @@ -104,12 +104,12 @@ boost::optional<Document> DocumentSourceGroup::getNextSpilled() { boost::optional<Document> DocumentSourceGroup::getNextStandard() { // Not spilled, and not streaming. - if (groups.empty()) + if (_groups->empty()) return boost::none; Document out = makeDocument(groupsIterator->first, groupsIterator->second, pExpCtx->inShard); - if (++groupsIterator == groups.end()) + if (++groupsIterator == _groups->end()) dispose(); return out; @@ -146,7 +146,7 @@ boost::optional<Document> DocumentSourceGroup::getNextStreaming() { // Compute the id. If it does not match _currentId, we will exit the loop, leaving // _firstDocOfNextGroup set for the next time getNext() is called. id = computeId(_variables.get()); - } while (_currentId == id); + } while (pExpCtx->getValueComparator().evaluate(_currentId == id)); Document out = makeDocument(_currentId, _currentAccumulators, pExpCtx->inShard); _currentId = std::move(id); @@ -156,11 +156,11 @@ boost::optional<Document> DocumentSourceGroup::getNextStreaming() { void DocumentSourceGroup::dispose() { // Free our resources. - GroupsMap().swap(groups); + GroupsMap().swap(*_groups); _sorterIterator.reset(); // Make us look done. - groupsIterator = groups.end(); + groupsIterator = _groups->end(); _firstDocOfNextGroup = boost::none; @@ -183,6 +183,23 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::optimize() { return this; } +void DocumentSourceGroup::doInjectExpressionContext() { + // Groups map must respect new comparator. + _groups = pExpCtx->getValueComparator().makeUnorderedValueMap<Accumulators>(); + + for (auto&& idExpr : _idExpressions) { + idExpr->injectExpressionContext(pExpCtx); + } + + for (auto&& expr : vpExpression) { + expr->injectExpressionContext(pExpCtx); + } + + for (auto&& accum : _currentAccumulators) { + accum->injectExpressionContext(pExpCtx); + } +} + Value DocumentSourceGroup::serialize(bool explain) const { MutableDocument insides; @@ -237,7 +254,9 @@ DocumentSource::GetDepsReturn DocumentSourceGroup::getDependencies(DepsTracker* intrusive_ptr<DocumentSourceGroup> DocumentSourceGroup::create( const intrusive_ptr<ExpressionContext>& pExpCtx) { - return new DocumentSourceGroup(pExpCtx); + intrusive_ptr<DocumentSourceGroup> source(new DocumentSourceGroup(pExpCtx)); + source->injectExpressionContext(pExpCtx); + return source; } DocumentSourceGroup::DocumentSourceGroup(const intrusive_ptr<ExpressionContext>& pExpCtx) @@ -435,6 +454,7 @@ void DocumentSourceGroup::initialize() { _currentAccumulators.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { _currentAccumulators.push_back(vpAccumulatorFactory[i]()); + _currentAccumulators.back()->injectExpressionContext(pExpCtx); } // We only need to load the first document. @@ -477,9 +497,9 @@ void DocumentSourceGroup::initialize() { Look for the _id value in the map; if it's not there, add a new entry with a blank accumulator. */ - const size_t oldSize = groups.size(); - vector<intrusive_ptr<Accumulator>>& group = groups[id]; - const bool inserted = groups.size() != oldSize; + const size_t oldSize = _groups->size(); + vector<intrusive_ptr<Accumulator>>& group = (*_groups)[id]; + const bool inserted = _groups->size() != oldSize; if (inserted) { memoryUsageBytes += id.getApproximateSize(); @@ -488,6 +508,7 @@ void DocumentSourceGroup::initialize() { group.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { group.push_back(vpAccumulatorFactory[i]()); + group.back()->injectExpressionContext(pExpCtx); } } else { for (size_t i = 0; i < numAccumulators; i++) { @@ -524,12 +545,12 @@ void DocumentSourceGroup::initialize() { // These blocks do any final steps necessary to prepare to output results. if (!sortedFiles.empty()) { _spilled = true; - if (!groups.empty()) { + if (!_groups->empty()) { sortedFiles.push_back(spill()); } // We won't be using groups again so free its memory. - GroupsMap().swap(groups); + GroupsMap().swap(*_groups); _sorterIterator.reset( Sorter<Value, Value>::Iterator::merge(sortedFiles, SortOptions(), SorterComparator())); @@ -538,20 +559,21 @@ void DocumentSourceGroup::initialize() { _currentAccumulators.reserve(numAccumulators); for (size_t i = 0; i < numAccumulators; i++) { _currentAccumulators.push_back(vpAccumulatorFactory[i]()); + _currentAccumulators.back()->injectExpressionContext(pExpCtx); } verify(_sorterIterator->more()); // we put data in, we should get something out. _firstPartOfNextGroup = _sorterIterator->next(); } else { // start the group iterator - groupsIterator = groups.begin(); + groupsIterator = _groups->begin(); } } shared_ptr<Sorter<Value, Value>::Iterator> DocumentSourceGroup::spill() { vector<const GroupsMap::value_type*> ptrs; // using pointers to speed sorting - ptrs.reserve(groups.size()); - for (GroupsMap::const_iterator it = groups.begin(), end = groups.end(); it != end; ++it) { + ptrs.reserve(_groups->size()); + for (GroupsMap::const_iterator it = _groups->begin(), end = _groups->end(); it != end; ++it) { ptrs.push_back(&*it); } @@ -583,7 +605,7 @@ shared_ptr<Sorter<Value, Value>::Iterator> DocumentSourceGroup::spill() { break; } - groups.clear(); + _groups->clear(); return shared_ptr<Sorter<Value, Value>::Iterator>(writer.done()); } @@ -761,7 +783,7 @@ void DocumentSourceGroup::parseIdExpression(BSONElement groupField, _idExpressions.push_back(ExpressionFieldPath::parse(groupField.str(), vps)); } else { // constant id - single group - _idExpressions.push_back(ExpressionConstant::create(Value(groupField))); + _idExpressions.push_back(ExpressionConstant::create(pExpCtx, Value(groupField))); } } diff --git a/src/mongo/db/pipeline/document_source_limit.cpp b/src/mongo/db/pipeline/document_source_limit.cpp index 401afc945ff..915a28ce8b6 100644 --- a/src/mongo/db/pipeline/document_source_limit.cpp +++ b/src/mongo/db/pipeline/document_source_limit.cpp @@ -81,7 +81,9 @@ Value DocumentSourceLimit::serialize(bool explain) const { intrusive_ptr<DocumentSourceLimit> DocumentSourceLimit::create( const intrusive_ptr<ExpressionContext>& pExpCtx, long long limit) { uassert(15958, "the limit must be positive", limit > 0); - return new DocumentSourceLimit(pExpCtx, limit); + intrusive_ptr<DocumentSourceLimit> source(new DocumentSourceLimit(pExpCtx, limit)); + source->injectExpressionContext(pExpCtx); + return source; } intrusive_ptr<DocumentSource> DocumentSourceLimit::createFromBson( diff --git a/src/mongo/db/pipeline/document_source_match.cpp b/src/mongo/db/pipeline/document_source_match.cpp index 6067acb3597..8035c237eae 100644 --- a/src/mongo/db/pipeline/document_source_match.cpp +++ b/src/mongo/db/pipeline/document_source_match.cpp @@ -357,9 +357,8 @@ bool DocumentSourceMatch::isTextQuery(const BSONObj& query) { void DocumentSourceMatch::joinMatchWith(intrusive_ptr<DocumentSourceMatch> other) { _predicate = BSON("$and" << BSON_ARRAY(_predicate << other->getQuery())); - // TODO SERVER-23349: Pass the appropriate CollatorInterface* instead of nullptr. StatusWithMatchExpression status = uassertStatusOK( - MatchExpressionParser::parse(_predicate, ExtensionsCallbackNoop(), nullptr)); + MatchExpressionParser::parse(_predicate, ExtensionsCallbackNoop(), pExpCtx->getCollator())); _expression = std::move(status.getValue()); } @@ -486,14 +485,18 @@ void DocumentSourceMatch::addDependencies(DepsTracker* deps) const { }); } +void DocumentSourceMatch::doInjectExpressionContext() { + _expression->setCollator(pExpCtx->getCollator()); +} + DocumentSourceMatch::DocumentSourceMatch(const BSONObj& query, const intrusive_ptr<ExpressionContext>& pExpCtx) : DocumentSource(pExpCtx), _predicate(query.getOwned()), _isTextQuery(isTextQuery(query)) { - // TODO SERVER-23349: Pass the appropriate CollatorInterface* instead of nullptr. StatusWithMatchExpression status = uassertStatusOK( - MatchExpressionParser::parse(_predicate, ExtensionsCallbackNoop(), nullptr)); + MatchExpressionParser::parse(_predicate, ExtensionsCallbackNoop(), pExpCtx->getCollator())); _expression = std::move(status.getValue()); getDependencies(&_dependencies); } + } // namespace mongo diff --git a/src/mongo/db/pipeline/document_source_merge_cursors.cpp b/src/mongo/db/pipeline/document_source_merge_cursors.cpp index 01f11cb0c9f..8e7e3babbcd 100644 --- a/src/mongo/db/pipeline/document_source_merge_cursors.cpp +++ b/src/mongo/db/pipeline/document_source_merge_cursors.cpp @@ -52,7 +52,10 @@ const char* DocumentSourceMergeCursors::getSourceName() const { intrusive_ptr<DocumentSource> DocumentSourceMergeCursors::create( std::vector<CursorDescriptor> cursorDescriptors, const intrusive_ptr<ExpressionContext>& pExpCtx) { - return new DocumentSourceMergeCursors(std::move(cursorDescriptors), pExpCtx); + intrusive_ptr<DocumentSourceMergeCursors> source( + new DocumentSourceMergeCursors(std::move(cursorDescriptors), pExpCtx)); + source->injectExpressionContext(pExpCtx); + return source; } intrusive_ptr<DocumentSource> DocumentSourceMergeCursors::createFromBson( diff --git a/src/mongo/db/pipeline/document_source_project.cpp b/src/mongo/db/pipeline/document_source_project.cpp index f6b70b840bd..5ad0876a7b8 100644 --- a/src/mongo/db/pipeline/document_source_project.cpp +++ b/src/mongo/db/pipeline/document_source_project.cpp @@ -116,4 +116,9 @@ DocumentSource::GetDepsReturn DocumentSourceProject::getDependencies(DepsTracker return SEE_NEXT; } } + +void DocumentSourceProject::doInjectExpressionContext() { + _parsedProject->injectExpressionContext(pExpCtx); } + +} // namespace mongo diff --git a/src/mongo/db/pipeline/document_source_redact.cpp b/src/mongo/db/pipeline/document_source_redact.cpp index 38f1e9fa8b8..3155037aa20 100644 --- a/src/mongo/db/pipeline/document_source_redact.cpp +++ b/src/mongo/db/pipeline/document_source_redact.cpp @@ -127,11 +127,12 @@ Value DocumentSourceRedact::redactValue(const Value& in) { boost::optional<Document> DocumentSourceRedact::redactObject() { const Value expressionResult = _expression->evaluate(_variables.get()); - if (expressionResult == keepVal) { + ValueComparator simpleValueCmp; + if (simpleValueCmp.evaluate(expressionResult == keepVal)) { return _variables->getDocument(_currentId); - } else if (expressionResult == pruneVal) { + } else if (simpleValueCmp.evaluate(expressionResult == pruneVal)) { return boost::optional<Document>(); - } else if (expressionResult == descendVal) { + } else if (simpleValueCmp.evaluate(expressionResult == descendVal)) { const Document in = _variables->getDocument(_currentId); MutableDocument out; out.copyMetaDataFrom(in); @@ -160,6 +161,10 @@ intrusive_ptr<DocumentSource> DocumentSourceRedact::optimize() { return this; } +void DocumentSourceRedact::doInjectExpressionContext() { + _expression->injectExpressionContext(pExpCtx); +} + Value DocumentSourceRedact::serialize(bool explain) const { return Value(DOC(getSourceName() << _expression.get()->serialize(explain))); } diff --git a/src/mongo/db/pipeline/document_source_sample.cpp b/src/mongo/db/pipeline/document_source_sample.cpp index 6a7f96f5ed5..121b6e04d36 100644 --- a/src/mongo/db/pipeline/document_source_sample.cpp +++ b/src/mongo/db/pipeline/document_source_sample.cpp @@ -73,6 +73,10 @@ Value DocumentSourceSample::serialize(bool explain) const { return Value(DOC(getSourceName() << DOC("size" << _size))); } +void DocumentSourceSample::doInjectExpressionContext() { + _sortStage->injectExpressionContext(pExpCtx); +} + namespace { const BSONObj randSortSpec = BSON("$rand" << BSON("$meta" << "randVal")); diff --git a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp index dafae4ed111..5a80a530734 100644 --- a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp +++ b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp @@ -76,7 +76,7 @@ double smallestFromSampleOfUniform(PseudoRandom* prng, size_t N) { boost::optional<Document> DocumentSourceSampleFromRandomCursor::getNext() { pExpCtx->checkForInterrupt(); - if (_seenDocs.size() >= static_cast<size_t>(_size)) + if (_seenDocs->size() >= static_cast<size_t>(_size)) return {}; auto doc = getNextNonDuplicateDocument(); @@ -114,7 +114,7 @@ boost::optional<Document> DocumentSourceSampleFromRandomCursor::getNextNonDuplic << (*doc).toString(), !idField.missing()); - if (_seenDocs.insert(std::move(idField)).second) { + if (_seenDocs->insert(std::move(idField)).second) { return doc; } LOG(1) << "$sample encountered duplicate document: " << (*doc).toString() << std::endl; @@ -136,11 +136,18 @@ DocumentSource::GetDepsReturn DocumentSourceSampleFromRandomCursor::getDependenc return SEE_NEXT; } +void DocumentSourceSampleFromRandomCursor::doInjectExpressionContext() { + _seenDocs = pExpCtx->getValueComparator().makeUnorderedValueSet(); +} + intrusive_ptr<DocumentSourceSampleFromRandomCursor> DocumentSourceSampleFromRandomCursor::create( const intrusive_ptr<ExpressionContext>& expCtx, long long size, std::string idField, long long nDocsInCollection) { - return new DocumentSourceSampleFromRandomCursor(expCtx, size, idField, nDocsInCollection); + intrusive_ptr<DocumentSourceSampleFromRandomCursor> source( + new DocumentSourceSampleFromRandomCursor(expCtx, size, idField, nDocsInCollection)); + source->injectExpressionContext(expCtx); + return source; } } // mongo diff --git a/src/mongo/db/pipeline/document_source_skip.cpp b/src/mongo/db/pipeline/document_source_skip.cpp index 4b3d4ec168d..4957042f8a9 100644 --- a/src/mongo/db/pipeline/document_source_skip.cpp +++ b/src/mongo/db/pipeline/document_source_skip.cpp @@ -94,6 +94,7 @@ Pipeline::SourceContainer::iterator DocumentSourceSkip::optimizeAt( intrusive_ptr<DocumentSourceSkip> DocumentSourceSkip::create( const intrusive_ptr<ExpressionContext>& pExpCtx) { intrusive_ptr<DocumentSourceSkip> pSource(new DocumentSourceSkip(pExpCtx)); + pSource->injectExpressionContext(pExpCtx); return pSource; } diff --git a/src/mongo/db/pipeline/document_source_sort.cpp b/src/mongo/db/pipeline/document_source_sort.cpp index a1456d7c70a..345383750cc 100644 --- a/src/mongo/db/pipeline/document_source_sort.cpp +++ b/src/mongo/db/pipeline/document_source_sort.cpp @@ -170,6 +170,7 @@ intrusive_ptr<DocumentSource> DocumentSourceSort::createFromBson( intrusive_ptr<DocumentSourceSort> DocumentSourceSort::create( const intrusive_ptr<ExpressionContext>& pExpCtx, BSONObj sortOrder, long long limit) { intrusive_ptr<DocumentSourceSort> pSort = new DocumentSourceSort(pExpCtx); + pSort->injectExpressionContext(pExpCtx); pSort->_sort = sortOrder.getOwned(); /* check for then iterate over the sort object */ diff --git a/src/mongo/db/pipeline/document_source_test.cpp b/src/mongo/db/pipeline/document_source_test.cpp index b9c7210df33..cf13e1589b1 100644 --- a/src/mongo/db/pipeline/document_source_test.cpp +++ b/src/mongo/db/pipeline/document_source_test.cpp @@ -33,6 +33,7 @@ #include "mongo/db/operation_context_noop.h" #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document_source.h" +#include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/pipeline.h" #include "mongo/db/service_context.h" @@ -230,27 +231,27 @@ private: TEST(Mock, OneDoc) { auto doc = DOC("a" << 1); auto source = DocumentSourceMock::create(doc); - ASSERT_EQ(*source->getNext(), doc); + ASSERT_DOCUMENT_EQ(*source->getNext(), doc); ASSERT(!source->getNext()); } TEST(Mock, DequeDocuments) { auto source = DocumentSourceMock::create({DOC("a" << 1), DOC("a" << 2)}); - ASSERT_EQ(*source->getNext(), DOC("a" << 1)); - ASSERT_EQ(*source->getNext(), DOC("a" << 2)); + ASSERT_DOCUMENT_EQ(*source->getNext(), DOC("a" << 1)); + ASSERT_DOCUMENT_EQ(*source->getNext(), DOC("a" << 2)); ASSERT(!source->getNext()); } TEST(Mock, StringJSON) { auto source = DocumentSourceMock::create("{a : 1}"); - ASSERT_EQ(*source->getNext(), DOC("a" << 1)); + ASSERT_DOCUMENT_EQ(*source->getNext(), DOC("a" << 1)); ASSERT(!source->getNext()); } TEST(Mock, DequeStringJSONs) { auto source = DocumentSourceMock::create({"{a: 1}", "{a: 2}"}); - ASSERT_EQ(*source->getNext(), DOC("a" << 1)); - ASSERT_EQ(*source->getNext(), DOC("a" << 2)); + ASSERT_DOCUMENT_EQ(*source->getNext(), DOC("a" << 1)); + ASSERT_DOCUMENT_EQ(*source->getNext(), DOC("a" << 2)); ASSERT(!source->getNext()); } @@ -331,7 +332,7 @@ public: // The limit's result is as expected. boost::optional<Document> next = limit()->getNext(); ASSERT(bool(next)); - ASSERT_EQUALS(Value(1), next->getField("a")); + ASSERT_VALUE_EQ(Value(1), next->getField("a")); // The limit is exhausted. ASSERT(!limit()->getNext()); } @@ -373,7 +374,7 @@ public: // The limit is not exhauted. boost::optional<Document> next = limit()->getNext(); ASSERT(bool(next)); - ASSERT_EQUALS(Value(1), next->getField("a")); + ASSERT_VALUE_EQ(Value(1), next->getField("a")); // The limit is exhausted. ASSERT(!limit()->getNext()); } @@ -458,6 +459,7 @@ protected: expressionContext->tempDir = _tempDir.path(); _group = DocumentSourceGroup::createFromBson(specElement, expressionContext); + _group->injectExpressionContext(expressionContext); assertRoundTrips(_group); } DocumentSourceGroup* group() { @@ -1014,19 +1016,19 @@ public: auto res = group()->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("_id"), Value(0)); + ASSERT_VALUE_EQ(res->getField("_id"), Value(0)); ASSERT_TRUE(group()->isStreaming()); res = source->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("a"), Value(1)); + ASSERT_VALUE_EQ(res->getField("a"), Value(1)); assertExhausted(source); res = group()->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("_id"), Value(1)); + ASSERT_VALUE_EQ(res->getField("_id"), Value(1)); assertExhausted(group()); @@ -1049,20 +1051,20 @@ public: auto res = group()->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("_id")["x"], Value(1)); - ASSERT_EQUALS(res->getField("_id")["y"], Value(2)); + ASSERT_VALUE_EQ(res->getField("_id")["x"], Value(1)); + ASSERT_VALUE_EQ(res->getField("_id")["y"], Value(2)); ASSERT_TRUE(group()->isStreaming()); res = group()->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("_id")["x"], Value(1)); - ASSERT_EQUALS(res->getField("_id")["y"], Value(1)); + ASSERT_VALUE_EQ(res->getField("_id")["x"], Value(1)); + ASSERT_VALUE_EQ(res->getField("_id")["y"], Value(1)); res = source->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("a"), Value(2)); - ASSERT_EQUALS(res->getField("b"), Value(1)); + ASSERT_VALUE_EQ(res->getField("a"), Value(2)); + ASSERT_VALUE_EQ(res->getField("b"), Value(1)); assertExhausted(source); @@ -1089,13 +1091,13 @@ public: auto res = group()->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("_id")["x"]["y"]["z"], Value(3)); + ASSERT_VALUE_EQ(res->getField("_id")["x"]["y"]["z"], Value(3)); ASSERT_TRUE(group()->isStreaming()); res = source->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("a")["b"]["c"], Value(1)); + ASSERT_VALUE_EQ(res->getField("a")["b"]["c"], Value(1)); assertExhausted(source); @@ -1125,16 +1127,16 @@ public: auto res = group()->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("_id")["sub"]["x"], Value(1)); - ASSERT_EQUALS(res->getField("_id")["sub"]["y"], Value(1)); - ASSERT_EQUALS(res->getField("_id")["sub"]["z"], Value(1)); + ASSERT_VALUE_EQ(res->getField("_id")["sub"]["x"], Value(1)); + ASSERT_VALUE_EQ(res->getField("_id")["sub"]["y"], Value(1)); + ASSERT_VALUE_EQ(res->getField("_id")["sub"]["z"], Value(1)); ASSERT_TRUE(group()->isStreaming()); res = source->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("a"), Value(2)); - ASSERT_EQUALS(res->getField("b"), Value(3)); + ASSERT_VALUE_EQ(res->getField("a"), Value(2)); + ASSERT_VALUE_EQ(res->getField("b"), Value(3)); BSONObjSet outputSort = group()->getOutputSorts(); @@ -1160,16 +1162,16 @@ public: auto res = group()->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("_id")["sub"]["x"], Value(5)); - ASSERT_EQUALS(res->getField("_id")["sub"]["y"], Value(1)); - ASSERT_EQUALS(res->getField("_id")["sub"]["z"], Value("c")); + ASSERT_VALUE_EQ(res->getField("_id")["sub"]["x"], Value(5)); + ASSERT_VALUE_EQ(res->getField("_id")["sub"]["y"], Value(1)); + ASSERT_VALUE_EQ(res->getField("_id")["sub"]["z"], Value("c")); ASSERT_TRUE(group()->isStreaming()); res = source->getNext(); ASSERT_TRUE(bool(res)); - ASSERT_EQUALS(res->getField("a"), Value(3)); - ASSERT_EQUALS(res->getField("b"), Value(1)); + ASSERT_VALUE_EQ(res->getField("a"), Value(3)); + ASSERT_VALUE_EQ(res->getField("b"), Value(1)); BSONObjSet outputSort = group()->getOutputSorts(); ASSERT_EQUALS(outputSort.size(), 2U); @@ -1903,7 +1905,7 @@ public: vector<Value> arr; sort()->serializeToArray(arr); - ASSERT_EQUALS( + ASSERT_VALUE_EQ( Value(arr), DOC_ARRAY(DOC("$sort" << DOC("a" << 1)) << DOC("$limit" << sort()->getLimit()))); @@ -3043,8 +3045,8 @@ using std::unique_ptr; // Helpers to make a DocumentSourceMatch from a query object or json string intrusive_ptr<DocumentSourceMatch> makeMatch(const BSONObj& query) { - intrusive_ptr<DocumentSource> uncasted = - DocumentSourceMatch::createFromBson(BSON("$match" << query).firstElement(), NULL); + intrusive_ptr<DocumentSource> uncasted = DocumentSourceMatch::createFromBson( + BSON("$match" << query).firstElement(), new ExpressionContext()); return dynamic_cast<DocumentSourceMatch*>(uncasted.get()); } intrusive_ptr<DocumentSourceMatch> makeMatch(const string& queryJson) { @@ -3458,11 +3460,11 @@ public: ASSERT_EQUALS(explainedStages.size(), 2UL); auto groupExplain = explainedStages[0]; - ASSERT_EQ(groupExplain["$group"], expectedGroupExplain); + ASSERT_VALUE_EQ(groupExplain["$group"], expectedGroupExplain); auto sortExplain = explainedStages[1]; auto expectedSortExplain = Value{Document{{"sortKey", Document{{"count", -1}}}}}; - ASSERT_EQ(sortExplain["$sort"], expectedSortExplain); + ASSERT_VALUE_EQ(sortExplain["$sort"], expectedSortExplain); } }; @@ -3553,11 +3555,11 @@ public: Value{Document{{"_id", Document{{"$const", BSONNULL}}}, {countName, Document{{"$sum", Document{{"$const", 1}}}}}}}; auto groupExplain = explainedStages[0]; - ASSERT_EQ(groupExplain["$group"], expectedGroupExplain); + ASSERT_VALUE_EQ(groupExplain["$group"], expectedGroupExplain); Value expectedProjectExplain = Value{Document{{"_id", false}, {countName, true}}}; auto projectExplain = explainedStages[1]; - ASSERT_EQ(projectExplain["$project"], expectedProjectExplain); + ASSERT_VALUE_EQ(projectExplain["$project"], expectedProjectExplain); } }; @@ -3644,12 +3646,12 @@ public: ASSERT_EQUALS(explainedStages.size(), 2UL); auto groupExplain = explainedStages[0]; - ASSERT_EQ(groupExplain["$group"], expectedGroupExplain); + ASSERT_VALUE_EQ(groupExplain["$group"], expectedGroupExplain); auto sortExplain = explainedStages[1]; auto expectedSortExplain = Value{Document{{"sortKey", Document{{"_id", 1}}}}}; - ASSERT_EQ(sortExplain["$sort"], expectedSortExplain); + ASSERT_VALUE_EQ(sortExplain["$sort"], expectedSortExplain); } }; diff --git a/src/mongo/db/pipeline/document_source_unwind.cpp b/src/mongo/db/pipeline/document_source_unwind.cpp index b02723c1c94..82bb8f3adbe 100644 --- a/src/mongo/db/pipeline/document_source_unwind.cpp +++ b/src/mongo/db/pipeline/document_source_unwind.cpp @@ -174,11 +174,13 @@ intrusive_ptr<DocumentSourceUnwind> DocumentSourceUnwind::create( const string& unwindPath, bool preserveNullAndEmptyArrays, const boost::optional<string>& indexPath) { - return new DocumentSourceUnwind(expCtx, - FieldPath(unwindPath), - preserveNullAndEmptyArrays, - indexPath ? FieldPath(*indexPath) - : boost::optional<FieldPath>()); + intrusive_ptr<DocumentSourceUnwind> source( + new DocumentSourceUnwind(expCtx, + FieldPath(unwindPath), + preserveNullAndEmptyArrays, + indexPath ? FieldPath(*indexPath) : boost::optional<FieldPath>())); + source->injectExpressionContext(expCtx); + return source; } boost::optional<Document> DocumentSourceUnwind::getNext() { diff --git a/src/mongo/db/pipeline/document_value_test.cpp b/src/mongo/db/pipeline/document_value_test.cpp index 207039aa90c..b9e942f5659 100644 --- a/src/mongo/db/pipeline/document_value_test.cpp +++ b/src/mongo/db/pipeline/document_value_test.cpp @@ -31,6 +31,7 @@ #include "mongo/platform/basic.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/field_path.h" #include "mongo/db/pipeline/value.h" #include "mongo/dbtests/dbtests.h" @@ -67,7 +68,7 @@ void assertRoundTrips(const Document& document1) { Document document2 = fromBson(obj1); BSONObj obj2 = toBson(document2); ASSERT_EQUALS(obj1, obj2); - ASSERT_EQUALS(document1, document2); + ASSERT_DOCUMENT_EQ(document1, document2); } TEST(DocumentConstruction, Default) { @@ -185,26 +186,26 @@ public: md.remove("c"); ASSERT(md.peek().empty()); ASSERT_EQUALS(0U, md.peek().size()); - ASSERT_EQUALS(md.peek(), Document()); + ASSERT_DOCUMENT_EQ(md.peek(), Document()); ASSERT(!FieldIterator(md.peek()).more()); ASSERT(md.peek()["c"].missing()); assertRoundTrips(md.peek()); // Set a nested field using [] md["x"]["y"]["z"] = Value("nested"); - ASSERT_EQUALS(md.peek()["x"]["y"]["z"], Value("nested")); + ASSERT_VALUE_EQ(md.peek()["x"]["y"]["z"], Value("nested")); // Set a nested field using setNestedField FieldPath xxyyzz = string("xx.yy.zz"); md.setNestedField(xxyyzz, Value("nested")); - ASSERT_EQUALS(md.peek().getNestedField(xxyyzz), Value("nested")); + ASSERT_VALUE_EQ(md.peek().getNestedField(xxyyzz), Value("nested")); // Set a nested fields through an existing empty document md["xxx"] = Value(Document()); md["xxx"]["yyy"] = Value(Document()); FieldPath xxxyyyzzz = string("xxx.yyy.zzz"); md.setNestedField(xxxyyyzzz, Value("nested")); - ASSERT_EQUALS(md.peek().getNestedField(xxxyyyzzz), Value("nested")); + ASSERT_VALUE_EQ(md.peek().getNestedField(xxxyyyzzz), Value("nested")); // Make sure nothing moved ASSERT_EQUALS(apos, md.peek().positionOf("a")); @@ -269,7 +270,7 @@ public: MutableDocument cloneOnDemand(document); // Check equality. - ASSERT_EQUALS(document, cloneOnDemand.peek()); + ASSERT_DOCUMENT_EQ(document, cloneOnDemand.peek()); // Check pointer equality of sub document. ASSERT_EQUALS(document["a"].getDocument().getPtr(), cloneOnDemand.peek()["a"].getDocument().getPtr()); @@ -277,21 +278,21 @@ public: // Change field in clone and ensure the original document's field is unchanged. cloneOnDemand.setField(StringData("a"), Value(2)); - ASSERT_EQUALS(Value(1), document.getNestedField(FieldPath("a.b"))); + ASSERT_VALUE_EQ(Value(1), document.getNestedField(FieldPath("a.b"))); // setNestedField and ensure the original document is unchanged. cloneOnDemand.reset(document); vector<Position> path; - ASSERT_EQUALS(Value(1), document.getNestedField(FieldPath("a.b"), &path)); + ASSERT_VALUE_EQ(Value(1), document.getNestedField(FieldPath("a.b"), &path)); cloneOnDemand.setNestedField(path, Value(2)); - ASSERT_EQUALS(Value(1), document.getNestedField(FieldPath("a.b"))); - ASSERT_EQUALS(Value(2), cloneOnDemand.peek().getNestedField(FieldPath("a.b"))); - ASSERT_EQUALS(DOC("a" << DOC("b" << 1)), document); - ASSERT_EQUALS(DOC("a" << DOC("b" << 2)), cloneOnDemand.freeze()); + ASSERT_VALUE_EQ(Value(1), document.getNestedField(FieldPath("a.b"))); + ASSERT_VALUE_EQ(Value(2), cloneOnDemand.peek().getNestedField(FieldPath("a.b"))); + ASSERT_DOCUMENT_EQ(DOC("a" << DOC("b" << 1)), document); + ASSERT_DOCUMENT_EQ(DOC("a" << DOC("b" << 2)), cloneOnDemand.freeze()); } }; @@ -301,7 +302,7 @@ public: void run() { Document document = fromBson(fromjson("{a:1,b:['ra',4],c:{z:1},d:'lal'}")); Document clonedDocument = document.clone(); - ASSERT_EQUALS(document, clonedDocument); + ASSERT_DOCUMENT_EQ(document, clonedDocument); } }; @@ -404,7 +405,7 @@ public: // logical equality ASSERT_EQUALS(obj, obj2); - ASSERT_EQUALS(doc, doc2); + ASSERT_DOCUMENT_EQ(doc, doc2); // binary equality ASSERT_EQUALS(obj.objsize(), obj2.objsize()); @@ -482,7 +483,7 @@ protected: void assertRoundTrips(const Document& input) { // Round trip to/from a buffer. auto output = roundTrip(input); - ASSERT_EQ(output, input); + ASSERT_DOCUMENT_EQ(output, input); ASSERT_EQ(output.hasTextScore(), input.hasTextScore()); ASSERT_EQ(output.hasRandMetaField(), input.hasRandMetaField()); if (input.hasTextScore()) @@ -564,15 +565,15 @@ void assertRoundTrips(const Value& value1) { Value value2 = fromBson(obj1); BSONObj obj2 = toBson(value2); ASSERT_EQUALS(obj1, obj2); - ASSERT_EQUALS(value1, value2); + ASSERT_VALUE_EQ(value1, value2); ASSERT_EQUALS(value1.getType(), value2.getType()); } class BSONArrayTest { public: void run() { - ASSERT_EQUALS(Value(BSON_ARRAY(1 << 2 << 3)), DOC_ARRAY(1 << 2 << 3)); - ASSERT_EQUALS(Value(BSONArray()), Value(vector<Value>())); + ASSERT_VALUE_EQ(Value(BSON_ARRAY(1 << 2 << 3)), DOC_ARRAY(1 << 2 << 3)); + ASSERT_VALUE_EQ(Value(BSONArray()), Value(vector<Value>())); } }; @@ -1661,10 +1662,10 @@ public: arrayOfMissing.serializeForSorter(bb); BufReader reader(bb.buf(), bb.len()); - ASSERT_EQUALS(missing, - Value::deserializeForSorter(reader, Value::SorterDeserializeSettings())); - ASSERT_EQUALS(arrayOfMissing, - Value::deserializeForSorter(reader, Value::SorterDeserializeSettings())); + ASSERT_VALUE_EQ(missing, + Value::deserializeForSorter(reader, Value::SorterDeserializeSettings())); + ASSERT_VALUE_EQ(arrayOfMissing, + Value::deserializeForSorter(reader, Value::SorterDeserializeSettings())); } }; } // namespace Value diff --git a/src/mongo/db/pipeline/document_value_test_util.cpp b/src/mongo/db/pipeline/document_value_test_util.cpp new file mode 100644 index 00000000000..223ea4c8f56 --- /dev/null +++ b/src/mongo/db/pipeline/document_value_test_util.cpp @@ -0,0 +1,67 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/pipeline/document_value_test_util.h" + +namespace mongo { +namespace unittest { + +#define _GENERATE_DOCVAL_CMP_FUNC(DOCVAL, NAME, COMPARATOR, OPERATOR) \ + void assertComparison_##DOCVAL##NAME(const std::string& theFile, \ + unsigned theLine, \ + StringData aExpression, \ + StringData bExpression, \ + const DOCVAL& aValue, \ + const DOCVAL& bValue) { \ + if (!COMPARATOR().evaluate(aValue OPERATOR bValue)) { \ + std::ostringstream os; \ + os << "Expected [ " << aExpression << " " #OPERATOR " " << bExpression \ + << " ] but found [ " << aValue << " " #OPERATOR " " << bValue << "]"; \ + TestAssertionFailure(theFile, theLine, os.str()).stream(); \ + } \ + } + +_GENERATE_DOCVAL_CMP_FUNC(Value, EQ, ValueComparator, ==); +_GENERATE_DOCVAL_CMP_FUNC(Value, LT, ValueComparator, <); +_GENERATE_DOCVAL_CMP_FUNC(Value, LTE, ValueComparator, <=); +_GENERATE_DOCVAL_CMP_FUNC(Value, GT, ValueComparator, >); +_GENERATE_DOCVAL_CMP_FUNC(Value, GTE, ValueComparator, >=); +_GENERATE_DOCVAL_CMP_FUNC(Value, NE, ValueComparator, !=); + +_GENERATE_DOCVAL_CMP_FUNC(Document, EQ, DocumentComparator, ==); +_GENERATE_DOCVAL_CMP_FUNC(Document, LT, DocumentComparator, <); +_GENERATE_DOCVAL_CMP_FUNC(Document, LTE, DocumentComparator, <=); +_GENERATE_DOCVAL_CMP_FUNC(Document, GT, DocumentComparator, >); +_GENERATE_DOCVAL_CMP_FUNC(Document, GTE, DocumentComparator, >=); +_GENERATE_DOCVAL_CMP_FUNC(Document, NE, DocumentComparator, !=); +#undef _GENERATE_DOCVAL_CMP_FUNC + +} // namespace unittest +} // namespace mongo diff --git a/src/mongo/db/pipeline/document_value_test_util.h b/src/mongo/db/pipeline/document_value_test_util.h new file mode 100644 index 00000000000..03b26e696af --- /dev/null +++ b/src/mongo/db/pipeline/document_value_test_util.h @@ -0,0 +1,88 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/pipeline/document_comparator.h" +#include "mongo/db/pipeline/value_comparator.h" +#include "mongo/unittest/unittest.h" + +/** + * Use to compare two instances of type Value under the default ValueComparator in unit tests. + */ +#define ASSERT_VALUE_EQ(a, b) _ASSERT_DOCVAL_COMPARISON(ValueEQ, a, b) +#define ASSERT_VALUE_LT(a, b) _ASSERT_DOCVAL_COMPARISON(ValueLT, a, b) +#define ASSERT_VALUE_LTE(a, b) _ASSERT_DOCVAL_COMPARISON(ValueLTE, a, b) +#define ASSERT_VALUE_GT(a, b) _ASSERT_DOCVAL_COMPARISON(ValueGT, a, b) +#define ASSERT_VALUE_GTE(a, b) _ASSERT_DOCVAL_COMPARISON(ValueGTE, a, b) +#define ASSERT_VALUE_NE(a, b) _ASSERT_DOCVAL_COMPARISON(ValueNE, a, b) + +/** + * Use to compare two instances of type Document under the default DocumentComparator in unit tests. + */ +#define ASSERT_DOCUMENT_EQ(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentEQ, a, b) +#define ASSERT_DOCUMENT_LT(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentLT, a, b) +#define ASSERT_DOCUMENT_LTE(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentLTE, a, b) +#define ASSERT_DOCUMENT_GT(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentGT, a, b) +#define ASSERT_DOCUMENT_GTE(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentGTE, a, b) +#define ASSERT_DOCUMENT_NE(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentNE, a, b) + +/** + * Document/Value comparison utility macro. Do not use directly. + */ +#define _ASSERT_DOCVAL_COMPARISON(NAME, a, b) \ + ::mongo::unittest::assertComparison_##NAME(__FILE__, __LINE__, #a, #b, a, b) + +namespace mongo { +namespace unittest { + +#define _DECLARE_DOCVAL_CMP_FUNC(DOCVAL, NAME) \ + void assertComparison_##DOCVAL##NAME(const std::string& theFile, \ + unsigned theLine, \ + StringData aExpression, \ + StringData bExpression, \ + const DOCVAL& aValue, \ + const DOCVAL& bValue); + +_DECLARE_DOCVAL_CMP_FUNC(Value, EQ); +_DECLARE_DOCVAL_CMP_FUNC(Value, LT); +_DECLARE_DOCVAL_CMP_FUNC(Value, LTE); +_DECLARE_DOCVAL_CMP_FUNC(Value, GT); +_DECLARE_DOCVAL_CMP_FUNC(Value, GTE); +_DECLARE_DOCVAL_CMP_FUNC(Value, NE); + +_DECLARE_DOCVAL_CMP_FUNC(Document, EQ); +_DECLARE_DOCVAL_CMP_FUNC(Document, LT); +_DECLARE_DOCVAL_CMP_FUNC(Document, LTE); +_DECLARE_DOCVAL_CMP_FUNC(Document, GT); +_DECLARE_DOCVAL_CMP_FUNC(Document, GTE); +_DECLARE_DOCVAL_CMP_FUNC(Document, NE); +#undef _DECLARE_DOCVAL_CMP_FUNC + +} // namespace unittest +} // namespace mongo diff --git a/src/mongo/db/pipeline/document_value_test_util_self_test.cpp b/src/mongo/db/pipeline/document_value_test_util_self_test.cpp new file mode 100644 index 00000000000..cbe55346a29 --- /dev/null +++ b/src/mongo/db/pipeline/document_value_test_util_self_test.cpp @@ -0,0 +1,92 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/pipeline/document_value_test_util.h" + +#include "mongo/db/query/collation/collator_interface_mock.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { +namespace { + +TEST(DocumentValueTestUtilSelfTest, DocumentEQ) { + ASSERT_DOCUMENT_EQ(Document({{"foo", "bar"}}), Document({{"foo", "bar"}})); +} + +TEST(DocumentValueTestUtilSelfTest, DocumentNE) { + ASSERT_DOCUMENT_NE(Document({{"foo", "bar"}}), Document({{"foo", "baz"}})); +} + +TEST(DocumentValueTestUtilSelfTest, DocumentLT) { + ASSERT_DOCUMENT_LT(Document({{"foo", "bar"}}), Document({{"foo", "baz"}})); +} + +TEST(DocumentValueTestUtilSelfTest, DocumentLTE) { + ASSERT_DOCUMENT_LTE(Document({{"foo", "bar"}}), Document({{"foo", "baz"}})); + ASSERT_DOCUMENT_LTE(Document({{"foo", "bar"}}), Document({{"foo", "bar"}})); +} + +TEST(DocumentValueTestUtilSelfTest, DocumentGT) { + ASSERT_DOCUMENT_GT(Document({{"foo", "baz"}}), Document({{"foo", "bar"}})); +} + +TEST(DocumentValueTestUtilSelfTest, DocumentGTE) { + ASSERT_DOCUMENT_GTE(Document({{"foo", "baz"}}), Document({{"foo", "bar"}})); + ASSERT_DOCUMENT_GTE(Document({{"foo", "bar"}}), Document({{"foo", "bar"}})); +} + +TEST(DocumentValueTestUtilSelfTest, ValueEQ) { + ASSERT_VALUE_EQ(Value("bar"), Value("bar")); +} + +TEST(DocumentValueTestUtilSelfTest, ValueNE) { + ASSERT_VALUE_NE(Value("bar"), Value("baz")); +} + +TEST(DocumentValueTestUtilSelfTest, ValueLT) { + ASSERT_VALUE_LT(Value("bar"), Value("baz")); +} + +TEST(DocumentValueTestUtilSelfTest, ValueLTE) { + ASSERT_VALUE_LTE(Value("bar"), Value("baz")); + ASSERT_VALUE_LTE(Value("bar"), Value("bar")); +} + +TEST(DocumentValueTestUtilSelfTest, ValueGT) { + ASSERT_VALUE_GT(Value("baz"), Value("bar")); +} + +TEST(DocumentValueTestUtilSelfTest, ValueGTE) { + ASSERT_VALUE_GTE(Value("baz"), Value("bar")); + ASSERT_VALUE_GTE(Value("bar"), Value("bar")); +} + +} // namespace +} // namespace mongo diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index b8f81dd6329..40affa27bbe 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -456,7 +456,8 @@ intrusive_ptr<Expression> ExpressionAnd::optimize() { */ bool last = pConst->getValue().coerceToBool(); if (!last) { - intrusive_ptr<ExpressionConstant> pFinal(ExpressionConstant::create(Value(false))); + intrusive_ptr<ExpressionConstant> pFinal( + ExpressionConstant::create(getExpressionContext(), Value(false))); return pFinal; } @@ -467,7 +468,8 @@ intrusive_ptr<Expression> ExpressionAnd::optimize() { the result will be a boolean. */ if (n == 2) { - intrusive_ptr<Expression> pFinal(ExpressionCoerceToBool::create(pAnd->vpOperand[0])); + intrusive_ptr<Expression> pFinal( + ExpressionCoerceToBool::create(getExpressionContext(), pAnd->vpOperand[0])); return pFinal; } @@ -612,8 +614,9 @@ const char* ExpressionCeil::getOpName() const { /* -------------------- ExpressionCoerceToBool ------------------------- */ intrusive_ptr<ExpressionCoerceToBool> ExpressionCoerceToBool::create( - const intrusive_ptr<Expression>& pExpression) { + const intrusive_ptr<ExpressionContext>& expCtx, const intrusive_ptr<Expression>& pExpression) { intrusive_ptr<ExpressionCoerceToBool> pNew(new ExpressionCoerceToBool(pExpression)); + pNew->injectExpressionContext(expCtx); return pNew; } @@ -653,6 +656,11 @@ Value ExpressionCoerceToBool::serialize(bool explain) const { return Value(DOC(name << DOC_ARRAY(pExpression->serialize(explain)))); } +void ExpressionCoerceToBool::doInjectExpressionContext() { + // Inject our ExpressionContext into the operand. + pExpression->injectExpressionContext(getExpressionContext()); +} + /* ----------------------- ExpressionCompare --------------------------- */ REGISTER_EXPRESSION(cmp, @@ -854,8 +862,10 @@ intrusive_ptr<Expression> ExpressionConstant::parse(BSONElement exprElement, } -intrusive_ptr<ExpressionConstant> ExpressionConstant::create(const Value& pValue) { +intrusive_ptr<ExpressionConstant> ExpressionConstant::create( + const intrusive_ptr<ExpressionContext>& expCtx, const Value& pValue) { intrusive_ptr<ExpressionConstant> pEC(new ExpressionConstant(pValue)); + pEC->injectExpressionContext(expCtx); return pEC; } @@ -1084,6 +1094,10 @@ void ExpressionDateToString::addDependencies(DepsTracker* deps) const { _date->addDependencies(deps); } +void ExpressionDateToString::doInjectExpressionContext() { + _date->injectExpressionContext(getExpressionContext()); +} + /* ---------------------- ExpressionDayOfMonth ------------------------- */ Value ExpressionDayOfMonth::evaluateInternal(Variables* vars) const { @@ -1235,6 +1249,12 @@ Value ExpressionObject::serialize(bool explain) const { return outputDoc.freezeToValue(); } +void ExpressionObject::doInjectExpressionContext() { + for (auto&& pair : _expressions) { + pair.second->injectExpressionContext(getExpressionContext()); + } +} + /* --------------------- ExpressionFieldPath --------------------------- */ // this is the old deprecated version only used by tests not using variables @@ -1454,6 +1474,11 @@ void ExpressionFilter::addDependencies(DepsTracker* deps) const { _filter->addDependencies(deps); } +void ExpressionFilter::doInjectExpressionContext() { + _input->injectExpressionContext(getExpressionContext()); + _filter->injectExpressionContext(getExpressionContext()); +} + /* ------------------------- ExpressionFloor -------------------------- */ Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const { @@ -1569,6 +1594,10 @@ void ExpressionLet::addDependencies(DepsTracker* deps) const { _subExpression->addDependencies(deps); } +void ExpressionLet::doInjectExpressionContext() { + _subExpression->injectExpressionContext(getExpressionContext()); +} + /* ------------------------- ExpressionMap ----------------------------- */ @@ -1670,6 +1699,11 @@ void ExpressionMap::addDependencies(DepsTracker* deps) const { _each->addDependencies(deps); } +void ExpressionMap::doInjectExpressionContext() { + _input->injectExpressionContext(getExpressionContext()); + _each->injectExpressionContext(getExpressionContext()); +} + /* ------------------------- ExpressionMeta ----------------------------- */ REGISTER_EXPRESSION(meta, ExpressionMeta::parse); @@ -1917,7 +1951,7 @@ Value ExpressionIn::evaluateInternal(Variables* vars) const { << typeName(arrayOfValues.getType()), arrayOfValues.isArray()); for (auto&& value : arrayOfValues.getArray()) { - if (argument == value) { + if (getExpressionContext()->getValueComparator().evaluate(argument == value)) { return Value(true); } } @@ -1984,7 +2018,7 @@ Value ExpressionIndexOfArray::evaluateInternal(Variables* vars) const { } for (size_t i = startIndex; i < endIndex; i++) { - if (array[i] == searchItem) { + if (getExpressionContext()->getValueComparator().evaluate(array[i] == searchItem)) { return Value(static_cast<int>(i)); } } @@ -2260,7 +2294,8 @@ intrusive_ptr<Expression> ExpressionNary::optimize() { // expression. if (constOperandCount == vpOperand.size()) { Variables emptyVars; - return intrusive_ptr<Expression>(ExpressionConstant::create(evaluateInternal(&emptyVars))); + return intrusive_ptr<Expression>( + ExpressionConstant::create(getExpressionContext(), evaluateInternal(&emptyVars))); } // If the expression is associative, we can collapse all the consecutive constant operands into @@ -2304,8 +2339,8 @@ intrusive_ptr<Expression> ExpressionNary::optimize() { ExpressionVector vpOperandSave = std::move(vpOperand); vpOperand = std::move(constExpressions); Variables emptyVars; - optimizedOperands.emplace_back( - ExpressionConstant::create(evaluateInternal(&emptyVars))); + optimizedOperands.emplace_back(ExpressionConstant::create( + getExpressionContext(), evaluateInternal(&emptyVars))); vpOperand = std::move(vpOperandSave); } else { optimizedOperands.insert( @@ -2321,7 +2356,7 @@ intrusive_ptr<Expression> ExpressionNary::optimize() { vpOperand = std::move(constExpressions); Variables emptyVars; optimizedOperands.emplace_back( - ExpressionConstant::create(evaluateInternal(&emptyVars))); + ExpressionConstant::create(getExpressionContext(), evaluateInternal(&emptyVars))); } else { optimizedOperands.insert( optimizedOperands.end(), constExpressions.begin(), constExpressions.end()); @@ -2352,6 +2387,12 @@ Value ExpressionNary::serialize(bool explain) const { return Value(DOC(getOpName() << array)); } +void ExpressionNary::doInjectExpressionContext() { + for (auto&& operand : vpOperand) { + operand->injectExpressionContext(getExpressionContext()); + } +} + /* ------------------------- ExpressionNot ----------------------------- */ Value ExpressionNot::evaluateInternal(Variables* vars) const { @@ -2407,7 +2448,8 @@ intrusive_ptr<Expression> ExpressionOr::optimize() { */ bool last = pConst->getValue().coerceToBool(); if (last) { - intrusive_ptr<ExpressionConstant> pFinal(ExpressionConstant::create(Value(true))); + intrusive_ptr<ExpressionConstant> pFinal( + ExpressionConstant::create(getExpressionContext(), Value(true))); return pFinal; } @@ -2418,7 +2460,8 @@ intrusive_ptr<Expression> ExpressionOr::optimize() { the result will be a boolean. */ if (n == 2) { - intrusive_ptr<Expression> pFinal(ExpressionCoerceToBool::create(pOr->vpOperand[0])); + intrusive_ptr<Expression> pFinal( + ExpressionCoerceToBool::create(getExpressionContext(), pOr->vpOperand[0])); return pFinal; } @@ -2738,6 +2781,12 @@ Value ExpressionReduce::serialize(bool explain) const { {"in", _in->serialize(explain)}}}}); } +void ExpressionReduce::doInjectExpressionContext() { + _input->injectExpressionContext(getExpressionContext()); + _initial->injectExpressionContext(getExpressionContext()); + _in->injectExpressionContext(getExpressionContext()); +} + /* ------------------------ ExpressionReverseArray ------------------------ */ Value ExpressionReverseArray::evaluateInternal(Variables* vars) const { @@ -2779,9 +2828,11 @@ const char* ExpressionSecond::getOpName() const { } namespace { -ValueSet arrayToSet(const Value& val) { +ValueSet arrayToSet(const Value& val, const ValueComparator& valueComparator) { const vector<Value>& array = val.getArray(); - return ValueSet(array.begin(), array.end()); + ValueSet valueSet = valueComparator.makeOrderedValueSet(); + valueSet.insert(array.begin(), array.end()); + return valueSet; } } @@ -2806,7 +2857,7 @@ Value ExpressionSetDifference::evaluateInternal(Variables* vars) const { << typeName(rhs.getType()), rhs.isArray()); - ValueSet rhsSet = arrayToSet(rhs); + ValueSet rhsSet = arrayToSet(rhs, getExpressionContext()->getValueComparator()); const vector<Value>& lhsArray = lhs.getArray(); vector<Value> returnVec; @@ -2835,7 +2886,8 @@ void ExpressionSetEquals::validateArguments(const ExpressionVector& args) const Value ExpressionSetEquals::evaluateInternal(Variables* vars) const { const size_t n = vpOperand.size(); - std::set<Value> lhs; + const auto& valueComparator = getExpressionContext()->getValueComparator(); + ValueSet lhs = valueComparator.makeOrderedValueSet(); for (size_t i = 0; i < n; i++) { const Value nextEntry = vpOperand[i]->evaluateInternal(vars); @@ -2848,8 +2900,13 @@ Value ExpressionSetEquals::evaluateInternal(Variables* vars) const { if (i == 0) { lhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end()); } else { - const std::set<Value> rhs(nextEntry.getArray().begin(), nextEntry.getArray().end()); - if (lhs != rhs) { + ValueSet rhs = valueComparator.makeOrderedValueSet(); + rhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end()); + if (lhs.size() != rhs.size()) { + return Value(false); + } + + if (!std::equal(lhs.begin(), lhs.end(), rhs.begin(), valueComparator.getEqualTo())) { return Value(false); } } @@ -2866,7 +2923,8 @@ const char* ExpressionSetEquals::getOpName() const { Value ExpressionSetIntersection::evaluateInternal(Variables* vars) const { const size_t n = vpOperand.size(); - ValueSet currentIntersection; + const auto& valueComparator = getExpressionContext()->getValueComparator(); + ValueSet currentIntersection = valueComparator.makeOrderedValueSet(); for (size_t i = 0; i < n; i++) { const Value nextEntry = vpOperand[i]->evaluateInternal(vars); if (nextEntry.nullish()) { @@ -2881,7 +2939,7 @@ Value ExpressionSetIntersection::evaluateInternal(Variables* vars) const { if (i == 0) { currentIntersection.insert(nextEntry.getArray().begin(), nextEntry.getArray().end()); } else { - ValueSet nextSet = arrayToSet(nextEntry); + ValueSet nextSet = arrayToSet(nextEntry, valueComparator); if (currentIntersection.size() > nextSet.size()) { // to iterate over whichever is the smaller set nextSet.swap(currentIntersection); @@ -2939,7 +2997,8 @@ Value ExpressionSetIsSubset::evaluateInternal(Variables* vars) const { << typeName(rhs.getType()), rhs.isArray()); - return setIsSubsetHelper(lhs.getArray(), arrayToSet(rhs)); + return setIsSubsetHelper(lhs.getArray(), + arrayToSet(rhs, getExpressionContext()->getValueComparator())); } /** @@ -2988,7 +3047,10 @@ intrusive_ptr<Expression> ExpressionSetIsSubset::optimize() { << typeName(rhs.getType()), rhs.isArray()); - return new Optimized(arrayToSet(rhs), vpOperand); + intrusive_ptr<Expression> optimizedWithConstant(new Optimized( + arrayToSet(rhs, getExpressionContext()->getValueComparator()), vpOperand)); + optimizedWithConstant->injectExpressionContext(getExpressionContext()); + return optimizedWithConstant; } return optimized; } @@ -3001,7 +3063,7 @@ const char* ExpressionSetIsSubset::getOpName() const { /* ----------------------- ExpressionSetUnion ---------------------------- */ Value ExpressionSetUnion::evaluateInternal(Variables* vars) const { - ValueSet unionedSet; + ValueSet unionedSet = getExpressionContext()->getValueComparator().makeOrderedValueSet(); const size_t n = vpOperand.size(); for (size_t i = 0; i < n; i++) { const Value newEntries = vpOperand[i]->evaluateInternal(vars); @@ -3606,6 +3668,17 @@ Value ExpressionSwitch::serialize(bool explain) const { return Value(Document{{"$switch", Document{{"branches", Value(serializedBranches)}}}}); } +void ExpressionSwitch::doInjectExpressionContext() { + if (_default) { + _default->injectExpressionContext(getExpressionContext()); + } + + for (auto&& pair : _branches) { + pair.first->injectExpressionContext(getExpressionContext()); + pair.second->injectExpressionContext(getExpressionContext()); + } +} + /* ------------------------- ExpressionToLower ----------------------------- */ Value ExpressionToLower::evaluateInternal(Variables* vars) const { @@ -4032,6 +4105,16 @@ void ExpressionZip::addDependencies(DepsTracker* deps) const { }); } +void ExpressionZip::doInjectExpressionContext() { + for (auto&& expr : _inputs) { + expr->injectExpressionContext(getExpressionContext()); + } + + for (auto&& expr : _defaults) { + expr->injectExpressionContext(getExpressionContext()); + } +} + const char* ExpressionZip::getOpName() const { return "$zip"; } diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index c4ac4b07704..4a849675300 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -38,6 +38,7 @@ #include "mongo/base/init.h" #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/field_path.h" #include "mongo/db/pipeline/value.h" #include "mongo/stdx/functional.h" @@ -281,8 +282,31 @@ public: */ static void registerExpression(std::string key, Parser parser); + /** + * Injects the ExpressionContext so that it may be used during evaluation of the Expression. + * Construction of expressions is done at parse time, but the ExpressionContext isn't finalized + * until later, at which point it is injected using this method. + */ + void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { + _expCtx = expCtx; + doInjectExpressionContext(); + } + protected: typedef std::vector<boost::intrusive_ptr<Expression>> ExpressionVector; + + /** + * Expressions which need to update their internal state when attaching to a new + * ExpressionContext should override this method. + */ + virtual void doInjectExpressionContext() {} + + const boost::intrusive_ptr<ExpressionContext>& getExpressionContext() const { + return _expCtx; + } + +private: + boost::intrusive_ptr<ExpressionContext> _expCtx; }; @@ -322,6 +346,11 @@ public: static ExpressionVector parseArguments(BSONElement bsonExpr, const VariablesParseState& vps); + // TODO SERVER-23349: Currently there are subclasses which derive from this base class that + // require custom logic for expression context injection. Consider making those classes inherit + // directly from Expression so that this method can be marked 'final' rather than 'override'. + void doInjectExpressionContext() override; + protected: ExpressionNary() {} @@ -538,8 +567,11 @@ public: Value serialize(bool explain) const final; static boost::intrusive_ptr<ExpressionCoerceToBool> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const boost::intrusive_ptr<Expression>& pExpression); + void doInjectExpressionContext() final; + private: explicit ExpressionCoerceToBool(const boost::intrusive_ptr<Expression>& pExpression); @@ -619,7 +651,9 @@ public: const char* getOpName() const; - static boost::intrusive_ptr<ExpressionConstant> create(const Value& pValue); + static boost::intrusive_ptr<ExpressionConstant> create( + const boost::intrusive_ptr<ExpressionContext>& expCtx, const Value& pValue); + static boost::intrusive_ptr<Expression> parse(BSONElement bsonExpr, const VariablesParseState& vps); @@ -647,6 +681,8 @@ public: static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); + void doInjectExpressionContext() final; + private: ExpressionDateToString(const std::string& format, // the format string boost::intrusive_ptr<Expression> date); // the date to format @@ -777,6 +813,8 @@ public: static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); + void doInjectExpressionContext() final; + private: ExpressionFilter(std::string varName, Variables::Id varId, @@ -859,6 +897,8 @@ public: static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); + void doInjectExpressionContext() final; + struct NameAndExpression { NameAndExpression() {} NameAndExpression(std::string name, boost::intrusive_ptr<Expression> expression) @@ -901,6 +941,8 @@ public: static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps); + void doInjectExpressionContext() final; + private: ExpressionMap( const std::string& varName, // name of variable to set @@ -1026,6 +1068,8 @@ public: return _expressions; } + void doInjectExpressionContext() final; + private: ExpressionObject( std::vector<std::pair<std::string, boost::intrusive_ptr<Expression>>>&& expressions); @@ -1072,6 +1116,8 @@ public: const VariablesParseState& vpsIn); Value serialize(bool explain) const final; + void doInjectExpressionContext() final; + private: boost::intrusive_ptr<Expression> _input; boost::intrusive_ptr<Expression> _initial; @@ -1242,6 +1288,8 @@ public: Value serialize(bool explain) const final; const char* getOpName() const final; + void doInjectExpressionContext() final; + private: using ExpressionPair = std::pair<boost::intrusive_ptr<Expression>, boost::intrusive_ptr<Expression>>; @@ -1337,6 +1385,8 @@ public: Value serialize(bool explain) const final; const char* getOpName() const final; + void doInjectExpressionContext() final; + private: bool _useLongestLength = false; ExpressionVector _inputs; diff --git a/src/mongo/db/pipeline/expression_context.cpp b/src/mongo/db/pipeline/expression_context.cpp index d7c5a6708a0..5992b8aa708 100644 --- a/src/mongo/db/pipeline/expression_context.cpp +++ b/src/mongo/db/pipeline/expression_context.cpp @@ -45,7 +45,7 @@ ExpressionContext::ExpressionContext(OperationContext* opCtx, const AggregationR auto statusWithCollator = CollatorFactoryInterface::get(opCtx->getServiceContext())->makeFromBSON(collation); uassertStatusOK(statusWithCollator.getStatus()); - collator = std::move(statusWithCollator.getValue()); + setCollator(std::move(statusWithCollator.getValue())); } } @@ -56,4 +56,13 @@ void ExpressionContext::checkForInterrupt() { interruptCounter = kInterruptCheckPeriod; } } + +void ExpressionContext::setCollator(std::unique_ptr<CollatorInterface> coll) { + _collator = std::move(coll); + + // Document/Value comparisons must be aware of the collation. + _documentComparator = DocumentComparator(_collator.get()); + _valueComparator = ValueComparator(_collator.get()); +} + } // namespace mongo diff --git a/src/mongo/db/pipeline/expression_context.h b/src/mongo/db/pipeline/expression_context.h index 90d5bad3a90..285318f7f7a 100644 --- a/src/mongo/db/pipeline/expression_context.h +++ b/src/mongo/db/pipeline/expression_context.h @@ -35,6 +35,8 @@ #include "mongo/db/namespace_string.h" #include "mongo/db/operation_context.h" #include "mongo/db/pipeline/aggregation_request.h" +#include "mongo/db/pipeline/document_comparator.h" +#include "mongo/db/pipeline/value_comparator.h" #include "mongo/db/query/collation/collator_interface.h" #include "mongo/util/intrusive_counter.h" @@ -42,6 +44,8 @@ namespace mongo { struct ExpressionContext : public IntrusiveCounterUnsigned { public: + ExpressionContext() = default; + ExpressionContext(OperationContext* opCtx, const AggregationRequest& request); /** @@ -50,11 +54,25 @@ public: */ void checkForInterrupt(); - bool isExplain; - bool inShard; + void setCollator(std::unique_ptr<CollatorInterface> coll); + + const CollatorInterface* getCollator() const { + return _collator.get(); + } + + const DocumentComparator& getDocumentComparator() const { + return _documentComparator; + } + + const ValueComparator& getValueComparator() const { + return _valueComparator; + } + + bool isExplain = false; + bool inShard = false; bool inRouter = false; - bool extSortAllowed; - bool bypassDocumentValidation; + bool extSortAllowed = false; + bool bypassDocumentValidation = false; NamespaceString ns; std::string tempDir; // Defaults to empty to prevent external sorting in mongos. @@ -65,11 +83,17 @@ public: // collation. const BSONObj collation; + static const int kInterruptCheckPeriod = 128; + int interruptCounter = kInterruptCheckPeriod; // when 0, check interruptStatus + +private: // Collator used to compare elements. 'collator' is initialized from 'collation', except in the // case where 'collation' is empty and there is a collection default collation. - std::unique_ptr<CollatorInterface> collator; + std::unique_ptr<CollatorInterface> _collator; - static const int kInterruptCheckPeriod = 128; - int interruptCounter = kInterruptCheckPeriod; // when 0, check interruptStatus + // Used for all comparisons of Document/Value during execution of the aggregation operation. + DocumentComparator _documentComparator; + ValueComparator _valueComparator; }; -} + +} // namespace mongo diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp index 09c4685df4f..9828e719aab 100644 --- a/src/mongo/db/pipeline/expression_test.cpp +++ b/src/mongo/db/pipeline/expression_test.cpp @@ -32,7 +32,10 @@ #include "mongo/config.h" #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/expression.h" +#include "mongo/db/pipeline/expression_context.h" +#include "mongo/db/pipeline/value_comparator.h" #include "mongo/dbtests/dbtests.h" #include "mongo/unittest/unittest.h" @@ -57,11 +60,14 @@ static void assertExpectedResults(string expression, initializer_list<pair<vector<Value>, Value>> operations) { for (auto&& op : operations) { try { + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const BSONObj obj = BSON(expression << Value(op.first)); - Value result = Expression::parseExpression(obj, vps)->evaluate(Document()); - ASSERT_EQUALS(op.second, result); + auto expression = Expression::parseExpression(obj, vps); + expression->injectExpressionContext(expCtx); + Value result = expression->evaluate(Document()); + ASSERT_VALUE_EQ(op.second, result); ASSERT_EQUALS(op.second.getType(), result.getType()); } catch (...) { log() << "failed with arguments: " << Value(op.first); @@ -127,13 +133,13 @@ Value valueFromBson(BSONObj obj) { template <typename T> intrusive_ptr<ExpressionConstant> makeConstant(T&& val) { - return ExpressionConstant::create(Value(std::forward<T>(val))); + return ExpressionConstant::create(nullptr, Value(std::forward<T>(val))); } class ExpressionBaseTest : public unittest::Test { public: void addOperand(intrusive_ptr<ExpressionNary> expr, Value arg) { - expr->addOperand(ExpressionConstant::create(arg)); + expr->addOperand(ExpressionConstant::create(nullptr, arg)); } }; @@ -141,7 +147,7 @@ class ExpressionNaryTestOneArg : public ExpressionBaseTest { public: virtual void assertEvaluates(Value input, Value output) { addOperand(_expr, input); - ASSERT_EQUALS(output, _expr->evaluate(Document())); + ASSERT_VALUE_EQ(output, _expr->evaluate(Document())); ASSERT_EQUALS(output.getType(), _expr->evaluate(Document()).getType()); } @@ -231,7 +237,7 @@ protected: }; TEST_F(ExpressionNaryTest, AddedConstantOperandIsSerialized) { - _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(Value(9))); + _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(nullptr, Value(9))); assertContents(_notAssociativeNorCommutative, BSON_ARRAY(9)); } @@ -245,7 +251,7 @@ TEST_F(ExpressionNaryTest, ValidateEmptyDependencies) { } TEST_F(ExpressionNaryTest, ValidateConstantExpressionDependency) { - _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(Value(1))); + _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(nullptr, Value(1))); assertDependencies(_notAssociativeNorCommutative, BSONArray()); } @@ -269,13 +275,13 @@ TEST_F(ExpressionNaryTest, ValidateObjectExpressionDependency) { } TEST_F(ExpressionNaryTest, SerializationToBsonObj) { - _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(Value(5))); + _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(nullptr, Value(5))); ASSERT_EQUALS(BSON("foo" << BSON("$testable" << BSON_ARRAY(BSON("$const" << 5)))), BSON("foo" << _notAssociativeNorCommutative->serialize(false))); } TEST_F(ExpressionNaryTest, SerializationToBsonArr) { - _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(Value(5))); + _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(nullptr, Value(5))); ASSERT_EQUALS(constify(BSON_ARRAY(BSON("$testable" << BSON_ARRAY(5)))), BSON_ARRAY(_notAssociativeNorCommutative->serialize(false))); } @@ -925,7 +931,7 @@ class NullDocument { public: void run() { intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); - expression->addOperand(ExpressionConstant::create(Value(2))); + expression->addOperand(ExpressionConstant::create(nullptr, Value(2))); ASSERT_EQUALS(BSON("" << 2), toBson(expression->evaluate(Document()))); } }; @@ -943,7 +949,7 @@ class String { public: void run() { intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); - expression->addOperand(ExpressionConstant::create(Value("a"))); + expression->addOperand(ExpressionConstant::create(nullptr, Value("a"))); ASSERT_THROWS(expression->evaluate(Document()), UserException); } }; @@ -953,14 +959,14 @@ class Bool { public: void run() { intrusive_ptr<ExpressionNary> expression = new ExpressionAdd(); - expression->addOperand(ExpressionConstant::create(Value(true))); + expression->addOperand(ExpressionConstant::create(nullptr, Value(true))); ASSERT_THROWS(expression->evaluate(Document()), UserException); } }; class SingleOperandBase : public ExpectedResultBase { void populateOperands(intrusive_ptr<ExpressionNary>& expression) { - expression->addOperand(ExpressionConstant::create(valueFromBson(operand()))); + expression->addOperand(ExpressionConstant::create(nullptr, valueFromBson(operand()))); } BSONObj expectedResult() { return operand(); @@ -1031,9 +1037,9 @@ public: protected: void populateOperands(intrusive_ptr<ExpressionNary>& expression) { expression->addOperand( - ExpressionConstant::create(valueFromBson(_reverse ? operand2() : operand1()))); + ExpressionConstant::create(nullptr, valueFromBson(_reverse ? operand2() : operand1()))); expression->addOperand( - ExpressionConstant::create(valueFromBson(_reverse ? operand1() : operand2()))); + ExpressionConstant::create(nullptr, valueFromBson(_reverse ? operand1() : operand2()))); } virtual BSONObj operand1() = 0; virtual BSONObj operand2() = 0; @@ -1185,11 +1191,13 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); + expression->injectExpressionContext(expCtx); ASSERT_EQUALS(constify(spec()), expressionToBson(expression)); ASSERT_EQUALS(BSON("" << expectedResult()), toBson(expression->evaluate(fromBson(BSON("a" << 1))))); @@ -1207,11 +1215,13 @@ class OptimizeBase { public: virtual ~OptimizeBase() {} void run() { + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); BSONObj specObject = BSON("" << spec()); BSONElement specElement = specObject.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); + expression->injectExpressionContext(expCtx); ASSERT_EQUALS(constify(spec()), expressionToBson(expression)); intrusive_ptr<Expression> optimized = expression->optimize(); ASSERT_EQUALS(expectedOptimized(), expressionToBson(optimized)); @@ -1483,8 +1493,8 @@ namespace CoerceToBool { class EvaluateTrue { public: void run() { - intrusive_ptr<Expression> nested = ExpressionConstant::create(Value(5)); - intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nested); + intrusive_ptr<Expression> nested = ExpressionConstant::create(nullptr, Value(5)); + intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nullptr, nested); ASSERT(expression->evaluate(Document()).getBool()); } }; @@ -1493,8 +1503,8 @@ public: class EvaluateFalse { public: void run() { - intrusive_ptr<Expression> nested = ExpressionConstant::create(Value(0)); - intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nested); + intrusive_ptr<Expression> nested = ExpressionConstant::create(nullptr, Value(0)); + intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nullptr, nested); ASSERT(!expression->evaluate(Document()).getBool()); } }; @@ -1504,7 +1514,7 @@ class Dependencies { public: void run() { intrusive_ptr<Expression> nested = ExpressionFieldPath::create("a.b"); - intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nested); + intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nullptr, nested); DepsTracker dependencies; expression->addDependencies(&dependencies); ASSERT_EQUALS(1U, dependencies.fields.size()); @@ -1519,7 +1529,7 @@ class AddToBsonObj { public: void run() { intrusive_ptr<Expression> expression = - ExpressionCoerceToBool::create(ExpressionFieldPath::create("foo")); + ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create("foo")); // serialized as $and because CoerceToBool isn't an ExpressionNary assertBinaryEqual(fromjson("{field:{$and:['$foo']}}"), toBsonObj(expression)); @@ -1536,7 +1546,7 @@ class AddToBsonArray { public: void run() { intrusive_ptr<Expression> expression = - ExpressionCoerceToBool::create(ExpressionFieldPath::create("foo")); + ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create("foo")); // serialized as $and because CoerceToBool isn't an ExpressionNary assertBinaryEqual(BSON_ARRAY(fromjson("{$and:['$foo']}")), toBsonArray(expression)); @@ -1840,7 +1850,7 @@ public: VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); - ASSERT_EQUALS(expression->evaluate(Document()), Value(true)); + ASSERT_VALUE_EQ(expression->evaluate(Document()), Value(true)); } }; @@ -1971,7 +1981,7 @@ namespace Constant { class Create { public: void run() { - intrusive_ptr<Expression> expression = ExpressionConstant::create(Value(5)); + intrusive_ptr<Expression> expression = ExpressionConstant::create(nullptr, Value(5)); assertBinaryEqual(BSON("" << 5), toBson(expression->evaluate(Document()))); } }; @@ -1996,7 +2006,7 @@ public: class Optimize { public: void run() { - intrusive_ptr<Expression> expression = ExpressionConstant::create(Value(5)); + intrusive_ptr<Expression> expression = ExpressionConstant::create(nullptr, Value(5)); // An attempt to optimize returns the Expression itself. ASSERT_EQUALS(expression, expression->optimize()); } @@ -2006,7 +2016,7 @@ public: class Dependencies { public: void run() { - intrusive_ptr<Expression> expression = ExpressionConstant::create(Value(5)); + intrusive_ptr<Expression> expression = ExpressionConstant::create(nullptr, Value(5)); DepsTracker dependencies; expression->addDependencies(&dependencies); ASSERT_EQUALS(0U, dependencies.fields.size()); @@ -2019,7 +2029,7 @@ public: class AddToBsonObj { public: void run() { - intrusive_ptr<Expression> expression = ExpressionConstant::create(Value(5)); + intrusive_ptr<Expression> expression = ExpressionConstant::create(nullptr, Value(5)); // The constant is replaced with a $ expression. assertBinaryEqual(BSON("field" << BSON("$const" << 5)), toBsonObj(expression)); } @@ -2034,7 +2044,7 @@ private: class AddToBsonArray { public: void run() { - intrusive_ptr<Expression> expression = ExpressionConstant::create(Value(5)); + intrusive_ptr<Expression> expression = ExpressionConstant::create(nullptr, Value(5)); // The constant is copied out as is. assertBinaryEqual(constify(BSON_ARRAY(5)), toBsonArray(expression)); } @@ -2342,7 +2352,7 @@ TEST(ExpressionObjectParse, ShouldAcceptEmptyObject) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); auto object = ExpressionObject::parse(BSONObj(), vps); - ASSERT_EQUALS(Value(Document{}), object->serialize(false)); + ASSERT_VALUE_EQ(Value(Document{}), object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptLiteralsAsValues) { @@ -2355,7 +2365,7 @@ TEST(ExpressionObjectParse, ShouldAcceptLiteralsAsValues) { vps); auto expectedResult = Value(Document{{"a", literal(5)}, {"b", literal("string")}, {"c", literal(BSONNULL)}}); - ASSERT_EQUALS(expectedResult, object->serialize(false)); + ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAccept_idAsFieldName) { @@ -2363,7 +2373,7 @@ TEST(ExpressionObjectParse, ShouldAccept_idAsFieldName) { VariablesParseState vps(&idGen); auto object = ExpressionObject::parse(BSON("_id" << 5), vps); auto expectedResult = Value(Document{{"_id", literal(5)}}); - ASSERT_EQUALS(expectedResult, object->serialize(false)); + ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptFieldNameContainingDollar) { @@ -2371,7 +2381,7 @@ TEST(ExpressionObjectParse, ShouldAcceptFieldNameContainingDollar) { VariablesParseState vps(&idGen); auto object = ExpressionObject::parse(BSON("a$b" << 5), vps); auto expectedResult = Value(Document{{"a$b", literal(5)}}); - ASSERT_EQUALS(expectedResult, object->serialize(false)); + ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptNestedObjects) { @@ -2381,7 +2391,7 @@ TEST(ExpressionObjectParse, ShouldAcceptNestedObjects) { auto expectedResult = Value(Document{{"a", Document{{"b", literal(1)}}}, {"c", Document{{"d", Document{{"e", literal(1)}, {"f", literal(1)}}}}}}); - ASSERT_EQUALS(expectedResult, object->serialize(false)); + ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ExpressionObjectParse, ShouldAcceptArrays) { @@ -2390,14 +2400,15 @@ TEST(ExpressionObjectParse, ShouldAcceptArrays) { auto object = ExpressionObject::parse(fromjson("{a: [1, 2]}"), vps); auto expectedResult = Value(Document{{"a", vector<Value>{Value(literal(1)), Value(literal(2))}}}); - ASSERT_EQUALS(expectedResult, object->serialize(false)); + ASSERT_VALUE_EQ(expectedResult, object->serialize(false)); } TEST(ObjectParsing, ShouldAcceptExpressionAsValue) { VariablesIdGenerator idGen; VariablesParseState vps(&idGen); auto object = ExpressionObject::parse(BSON("a" << BSON("$and" << BSONArray())), vps); - ASSERT_EQ(object->serialize(false), Value(Document{{"a", Document{{"$and", BSONArray()}}}})); + ASSERT_VALUE_EQ(object->serialize(false), + Value(Document{{"a", Document{{"$and", BSONArray()}}}})); } // @@ -2455,31 +2466,31 @@ TEST(ParseObject, ShouldRejectExpressionAsTheSecondField) { TEST(ExpressionObjectEvaluate, EmptyObjectShouldEvaluateToEmptyDocument) { auto object = ExpressionObject::create({}); - ASSERT_EQUALS(Value(Document()), object->evaluate(Document())); - ASSERT_EQUALS(Value(Document()), object->evaluate(Document{{"a", 1}})); - ASSERT_EQUALS(Value(Document()), object->evaluate(Document{{"_id", "ID"}})); + ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document())); + ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document{{"a", 1}})); + ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document{{"_id", "ID"}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateEachField) { auto object = ExpressionObject::create({{"a", makeConstant(1)}, {"b", makeConstant(5)}}); - ASSERT_EQUALS(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document())); - ASSERT_EQUALS(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"a", 1}})); - ASSERT_EQUALS(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"_id", "ID"}})); + ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document())); + ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"a", 1}})); + ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"_id", "ID"}})); } TEST(ExpressionObjectEvaluate, OrderOfFieldsInOutputShouldMatchOrderInSpecification) { auto object = ExpressionObject::create({{"a", ExpressionFieldPath::create("a")}, {"b", ExpressionFieldPath::create("b")}, {"c", ExpressionFieldPath::create("c")}}); - ASSERT_EQUALS(Value(Document{{"a", "A"}, {"b", "B"}, {"c", "C"}}), - object->evaluate(Document{{"c", "C"}, {"a", "A"}, {"b", "B"}, {"_id", "ID"}})); + ASSERT_VALUE_EQ(Value(Document{{"a", "A"}, {"b", "B"}, {"c", "C"}}), + object->evaluate(Document{{"c", "C"}, {"a", "A"}, {"b", "B"}, {"_id", "ID"}})); } TEST(ExpressionObjectEvaluate, ShouldRemoveFieldsThatHaveMissingValues) { auto object = ExpressionObject::create( {{"a", ExpressionFieldPath::create("a.b")}, {"b", ExpressionFieldPath::create("missing")}}); - ASSERT_EQUALS(Value(Document{}), object->evaluate(Document())); - ASSERT_EQUALS(Value(Document{}), object->evaluate(Document{{"a", 1}})); + ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document())); + ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document{{"a", 1}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateFieldsWithinNestedObject) { @@ -2487,18 +2498,18 @@ TEST(ExpressionObjectEvaluate, ShouldEvaluateFieldsWithinNestedObject) { {{"a", ExpressionObject::create( {{"b", makeConstant(1)}, {"c", ExpressionFieldPath::create("_id")}})}}); - ASSERT_EQUALS(Value(Document{{"a", Document{{"b", 1}}}}), object->evaluate(Document())); - ASSERT_EQUALS(Value(Document{{"a", Document{{"b", 1}, {"c", "ID"}}}}), - object->evaluate(Document{{"_id", "ID"}})); + ASSERT_VALUE_EQ(Value(Document{{"a", Document{{"b", 1}}}}), object->evaluate(Document())); + ASSERT_VALUE_EQ(Value(Document{{"a", Document{{"b", 1}, {"c", "ID"}}}}), + object->evaluate(Document{{"_id", "ID"}})); } TEST(ExpressionObjectEvaluate, ShouldEvaluateToEmptyDocumentIfAllFieldsAreMissing) { auto object = ExpressionObject::create({{"a", ExpressionFieldPath::create("missing")}}); - ASSERT_EQUALS(Value(Document{}), object->evaluate(Document())); + ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document())); auto objectWithNestedObject = ExpressionObject::create({{"nested", object}}); - ASSERT_EQUALS(Value(Document{{"nested", Document{}}}), - objectWithNestedObject->evaluate(Document())); + ASSERT_VALUE_EQ(Value(Document{{"nested", Document{}}}), + objectWithNestedObject->evaluate(Document())); } // @@ -2542,7 +2553,7 @@ TEST(ExpressionObjectOptimizations, OptimizingAnObjectShouldOptimizeSubExpressio auto expConstant = dynamic_cast<ExpressionConstant*>(optimizedObject->getChildExpressions()[0].second.get()); ASSERT_TRUE(expConstant); - ASSERT_EQ(expConstant->evaluate(Document()), Value(3)); + ASSERT_VALUE_EQ(expConstant->evaluate(Document()), Value(3)); }; } // namespace Object @@ -2897,7 +2908,7 @@ TEST(ParseExpression, ShouldRecognizeConstExpression) { auto resultExpression = parseExpression(BSON("$const" << 5)); auto constExpression = dynamic_cast<ExpressionConstant*>(resultExpression.get()); ASSERT_TRUE(constExpression); - ASSERT_EQUALS(constExpression->serialize(false), Value(Document{{"$const", 5}})); + ASSERT_VALUE_EQ(constExpression->serialize(false), Value(Document{{"$const", 5}})); } TEST(ParseExpression, ShouldRejectUnknownExpression) { @@ -2931,15 +2942,15 @@ TEST(ParseExpression, ShouldParseExpressionWithMultipleArguments) { ASSERT_TRUE(strCaseCmpExpression); vector<Value> arguments = {Value(Document{{"$const", "foo"}}), Value(Document{{"$const", "FOO"}})}; - ASSERT_EQUALS(strCaseCmpExpression->serialize(false), - Value(Document{{"$strcasecmp", arguments}})); + ASSERT_VALUE_EQ(strCaseCmpExpression->serialize(false), + Value(Document{{"$strcasecmp", arguments}})); } TEST(ParseExpression, ShouldParseExpressionWithNoArguments) { auto resultExpression = parseExpression(BSON("$and" << BSONArray())); auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get()); ASSERT_TRUE(andExpression); - ASSERT_EQUALS(andExpression->serialize(false), Value(Document{{"$and", vector<Value>{}}})); + ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", vector<Value>{}}})); } TEST(ParseExpression, ShouldParseExpressionWithOneArgument) { @@ -2947,7 +2958,7 @@ TEST(ParseExpression, ShouldParseExpressionWithOneArgument) { auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get()); ASSERT_TRUE(andExpression); vector<Value> arguments = {Value(Document{{"$const", 1}})}; - ASSERT_EQUALS(andExpression->serialize(false), Value(Document{{"$and", arguments}})); + ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", arguments}})); } TEST(ParseExpression, ShouldAcceptArgumentWithoutArrayForVariadicExpressions) { @@ -2955,7 +2966,7 @@ TEST(ParseExpression, ShouldAcceptArgumentWithoutArrayForVariadicExpressions) { auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get()); ASSERT_TRUE(andExpression); vector<Value> arguments = {Value(Document{{"$const", 1}})}; - ASSERT_EQUALS(andExpression->serialize(false), Value(Document{{"$and", arguments}})); + ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", arguments}})); } TEST(ParseExpression, ShouldAcceptArgumentWithoutArrayAsSingleArgument) { @@ -2963,7 +2974,7 @@ TEST(ParseExpression, ShouldAcceptArgumentWithoutArrayAsSingleArgument) { auto notExpression = dynamic_cast<ExpressionNot*>(resultExpression.get()); ASSERT_TRUE(notExpression); vector<Value> arguments = {Value(Document{{"$const", 1}})}; - ASSERT_EQUALS(notExpression->serialize(false), Value(Document{{"$not", arguments}})); + ASSERT_VALUE_EQ(notExpression->serialize(false), Value(Document{{"$not", arguments}})); } TEST(ParseExpression, ShouldAcceptObjectAsSingleArgument) { @@ -2971,7 +2982,7 @@ TEST(ParseExpression, ShouldAcceptObjectAsSingleArgument) { auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get()); ASSERT_TRUE(andExpression); vector<Value> arguments = {Value(Document{{"$const", 1}})}; - ASSERT_EQUALS(andExpression->serialize(false), Value(Document{{"$and", arguments}})); + ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", arguments}})); } TEST(ParseExpression, ShouldAcceptObjectInsideArrayAsSingleArgument) { @@ -2979,7 +2990,7 @@ TEST(ParseExpression, ShouldAcceptObjectInsideArrayAsSingleArgument) { auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get()); ASSERT_TRUE(andExpression); vector<Value> arguments = {Value(Document{{"$const", 1}})}; - ASSERT_EQUALS(andExpression->serialize(false), Value(Document{{"$and", arguments}})); + ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", arguments}})); } } // namespace Expression @@ -3005,7 +3016,7 @@ TEST(ParseOperand, ShouldRecognizeFieldPath) { << "$field")); auto fieldPathExpression = dynamic_cast<ExpressionFieldPath*>(resultExpression.get()); ASSERT_TRUE(fieldPathExpression); - ASSERT_EQ(fieldPathExpression->serialize(false), Value("$field")); + ASSERT_VALUE_EQ(fieldPathExpression->serialize(false), Value("$field")); } TEST(ParseOperand, ShouldRecognizeStringLiteral) { @@ -3013,7 +3024,7 @@ TEST(ParseOperand, ShouldRecognizeStringLiteral) { << "foo")); auto constantExpression = dynamic_cast<ExpressionConstant*>(resultExpression.get()); ASSERT_TRUE(constantExpression); - ASSERT_EQ(constantExpression->serialize(false), Value(Document{{"$const", "foo"}})); + ASSERT_VALUE_EQ(constantExpression->serialize(false), Value(Document{{"$const", "foo"}})); } TEST(ParseOperand, ShouldRecognizeNestedArray) { @@ -3022,21 +3033,21 @@ TEST(ParseOperand, ShouldRecognizeNestedArray) { auto arrayExpression = dynamic_cast<ExpressionArray*>(resultExpression.get()); ASSERT_TRUE(arrayExpression); vector<Value> expectedSerializedArray = {Value(Document{{"$const", "foo"}}), Value("$field")}; - ASSERT_EQ(arrayExpression->serialize(false), Value(expectedSerializedArray)); + ASSERT_VALUE_EQ(arrayExpression->serialize(false), Value(expectedSerializedArray)); } TEST(ParseOperand, ShouldRecognizeNumberLiteral) { auto resultExpression = parseOperand(BSON("" << 5)); auto constantExpression = dynamic_cast<ExpressionConstant*>(resultExpression.get()); ASSERT_TRUE(constantExpression); - ASSERT_EQ(constantExpression->serialize(false), Value(Document{{"$const", 5}})); + ASSERT_VALUE_EQ(constantExpression->serialize(false), Value(Document{{"$const", 5}})); } TEST(ParseOperand, ShouldRecognizeNestedExpression) { auto resultExpression = parseOperand(BSON("" << BSON("$and" << BSONArray()))); auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get()); ASSERT_TRUE(andExpression); - ASSERT_EQ(andExpression->serialize(false), Value(Document{{"$and", vector<Value>{}}})); + ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", vector<Value>{}}})); } } // namespace Operand @@ -3049,7 +3060,8 @@ Value sortSet(Value set) { return Value(BSONNULL); } vector<Value> sortedSet = set.getArray(); - sort(sortedSet.begin(), sortedSet.end()); + ValueComparator valueComparator; + sort(sortedSet.begin(), sortedSet.end(), valueComparator.getLessThan()); return Value(sortedSet); } @@ -3057,6 +3069,7 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); const Document spec = getSpec(); const Value args = spec["input"]; if (!spec["expected"].missing()) { @@ -3068,11 +3081,12 @@ public: VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const intrusive_ptr<Expression> expr = Expression::parseExpression(obj, vps); + expr->injectExpressionContext(expCtx); Value result = expr->evaluate(Document()); if (result.getType() == Array) { result = sortSet(result); } - if (result != expected) { + if (ValueComparator().evaluate(result != expected)) { string errMsg = str::stream() << "for expression " << field.first.toString() << " with argument " << args.toString() << " full tree: " << expr->serialize(false).toString() @@ -3096,6 +3110,7 @@ public: // same const intrusive_ptr<Expression> expr = Expression::parseExpression(obj, vps); + expr->injectExpressionContext(expCtx); expr->evaluate(Document()); }, UserException); @@ -3372,11 +3387,13 @@ private: return BSON("$strcasecmp" << BSON_ARRAY(b() << a())); } void assertResult(int expectedResult, const BSONObj& spec) { + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); BSONObj specObj = BSON("" << spec); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); + expression->injectExpressionContext(expCtx); ASSERT_EQUALS(constify(spec), expressionToBson(expression)); ASSERT_EQUALS(BSON("" << expectedResult), toBson(expression->evaluate(Document()))); } @@ -3498,11 +3515,13 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); + expression->injectExpressionContext(expCtx); ASSERT_EQUALS(constify(spec()), expressionToBson(expression)); ASSERT_EQUALS(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3753,11 +3772,13 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); + expression->injectExpressionContext(expCtx); ASSERT_EQUALS(constify(spec()), expressionToBson(expression)); ASSERT_EQUALS(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3810,11 +3831,13 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); BSONObj specObj = BSON("" << spec()); BSONElement specElement = specObj.firstElement(); VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps); + expression->injectExpressionContext(expCtx); ASSERT_EQUALS(constify(spec()), expressionToBson(expression)); ASSERT_EQUALS(BSON("" << expectedResult()), toBson(expression->evaluate(Document()))); } @@ -3866,6 +3889,7 @@ class ExpectedResultBase { public: virtual ~ExpectedResultBase() {} void run() { + intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext()); const Document spec = getSpec(); const Value args = spec["input"]; if (!spec["expected"].missing()) { @@ -3877,8 +3901,9 @@ public: VariablesIdGenerator idGenerator; VariablesParseState vps(&idGenerator); const intrusive_ptr<Expression> expr = Expression::parseExpression(obj, vps); + expr->injectExpressionContext(expCtx); const Value result = expr->evaluate(Document()); - if (result != expected) { + if (ValueComparator().evaluate(result != expected)) { string errMsg = str::stream() << "for expression " << field.first.toString() << " with argument " << args.toString() << " full tree: " << expr->serialize(false).toString() @@ -3902,6 +3927,7 @@ public: // same const intrusive_ptr<Expression> expr = Expression::parseExpression(obj, vps); + expr->injectExpressionContext(expCtx); expr->evaluate(Document()); }, UserException); diff --git a/src/mongo/db/pipeline/lookup_set_cache.h b/src/mongo/db/pipeline/lookup_set_cache.h index 3150d7bc1af..7ca74f90964 100644 --- a/src/mongo/db/pipeline/lookup_set_cache.h +++ b/src/mongo/db/pipeline/lookup_set_cache.h @@ -41,6 +41,7 @@ #include "mongo/bson/bsonobj.h" #include "mongo/db/pipeline/value.h" +#include "mongo/db/pipeline/value_comparator.h" #include "mongo/stdx/functional.h" namespace mongo { @@ -56,6 +57,9 @@ using boost::multi_index::indexed_by; * limit, but includes the ability to evict down to both a specific number of elements, and down to * a specific amount of memory. Memory usage includes only the size of the elements in the cache at * the time of insertion, not the overhead incurred by the data structures in use. + * + * TODO SERVER-23349: This class must make all comparisons of user data using the aggregation + * operation's collation. */ class LookupSetCache { public: @@ -167,9 +171,12 @@ private: // a container of std::pair<Value, BSONObjSet>, that is both sequenced, and has a unique // index on the Value. From this, we are able to evict the least-recently-used member, and // maintain key uniqueness. - using IndexedContainer = multi_index_container< - Cached, - indexed_by<sequenced<>, hashed_unique<member<Cached, Value, &Cached::first>, Value::Hash>>>; + using IndexedContainer = + multi_index_container<Cached, + indexed_by<sequenced<>, + hashed_unique<member<Cached, Value, &Cached::first>, + Value::Hash, + ValueComparator::EqualTo>>>; IndexedContainer _container; diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection.h b/src/mongo/db/pipeline/parsed_aggregation_projection.h index 9149ff7a53c..a3420fdb954 100644 --- a/src/mongo/db/pipeline/parsed_aggregation_projection.h +++ b/src/mongo/db/pipeline/parsed_aggregation_projection.h @@ -28,13 +28,15 @@ #pragma once +#include <boost/intrusive_ptr.hpp> #include <memory> namespace mongo { class BSONObj; -struct DepsTracker; class Document; +struct DepsTracker; +struct ExpressionContext; namespace parsed_aggregation_projection { @@ -80,6 +82,11 @@ public: virtual void optimize() {} /** + * Inject the ExpressionContext into any expressions contained within this projection. + */ + virtual void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) {} + + /** * Add any dependencies needed by this projection or any sub-expressions to 'deps'. */ virtual void addDependencies(DepsTracker* deps) const {} diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp b/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp index 7703ac3960b..5768b46882a 100644 --- a/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp @@ -39,6 +39,7 @@ #include "mongo/bson/json.h" #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/death_test.h" #include "mongo/unittest/unittest.h" @@ -83,17 +84,17 @@ TEST(ExclusionProjection, ShouldSerializeToEquivalentProjection) { // fields is subject to change. auto serialization = exclusion.serialize(); ASSERT_EQ(serialization.size(), 4UL); - ASSERT_EQ(serialization["a"], Value(false)); - ASSERT_EQ(serialization["_id"], Value(false)); + ASSERT_VALUE_EQ(serialization["a"], Value(false)); + ASSERT_VALUE_EQ(serialization["_id"], Value(false)); ASSERT_EQ(serialization["b"].getType(), BSONType::Object); ASSERT_EQ(serialization["b"].getDocument().size(), 2UL); - ASSERT_EQ(serialization["b"].getDocument()["c"], Value(false)); - ASSERT_EQ(serialization["b"].getDocument()["d"], Value(false)); + ASSERT_VALUE_EQ(serialization["b"].getDocument()["c"], Value(false)); + ASSERT_VALUE_EQ(serialization["b"].getDocument()["d"], Value(false)); ASSERT_EQ(serialization["x"].getType(), BSONType::Object); ASSERT_EQ(serialization["x"].getDocument().size(), 1UL); - ASSERT_EQ(serialization["x"].getDocument()["y"], Value(false)); + ASSERT_VALUE_EQ(serialization["x"].getDocument()["y"], Value(false)); } TEST(ExclusionProjection, ShouldNotAddAnyDependencies) { @@ -127,22 +128,22 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeTopLevelField) { // More than one field in document. auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}}); auto expectedResult = Document{{"b", 2}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Specified field is the only field in the document. result = exclusion.applyProjection(Document{{"a", 1}}); expectedResult = Document{}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Specified field is not present in the document. result = exclusion.applyProjection(Document{{"c", 1}}); expectedResult = Document{{"c", 1}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // There are no fields in the document. result = exclusion.applyProjection(Document{}); expectedResult = Document{}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(ExclusionProjectionExecutionTest, ShouldCoerceNumericsToBools) { @@ -152,7 +153,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldCoerceNumericsToBools) { auto result = exclusion.applyProjection(Document{{"_id", "ID"}, {"a", 1}, {"b", 2}, {"c", 3}}); auto expectedResult = Document{{"_id", "ID"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(ExclusionProjectionExecutionTest, ShouldPreserveOrderOfExistingFields) { @@ -160,7 +161,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldPreserveOrderOfExistingFields) { exclusion.parse(BSON("second" << false)); auto result = exclusion.applyProjection(Document{{"first", 0}, {"second", 1}, {"third", 2}}); auto expectedResult = Document{{"first", 0}, {"third", 2}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(ExclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { @@ -168,7 +169,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { exclusion.parse(BSON("a" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"}}); auto expectedResult = Document{{"b", 2}, {"_id", "ID"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { @@ -176,7 +177,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { exclusion.parse(BSON("a" << false << "_id" << false)); auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"}}); auto expectedResult = Document{{"b", 2}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } // @@ -189,7 +190,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeSubFieldsOfId) { auto result = exclusion.applyProjection( Document{{"_id", Document{{"x", 1}, {"y", 2}, {"z", 3}}}, {"a", 1}}); auto expectedResult = Document{{"_id", Document{{"z", 3}}}, {"a", 1}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(ExclusionProjectionExecutionTest, ShouldExcludeSimpleDottedFieldFromSubDoc) { @@ -199,22 +200,22 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeSimpleDottedFieldFromSubDoc) // More than one field in sub document. auto result = exclusion.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); auto expectedResult = Document{{"a", Document{{"c", 2}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Specified field is the only field in the sub document. result = exclusion.applyProjection(Document{{"a", Document{{"b", 1}}}}); expectedResult = Document{{"a", Document{}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Specified field is not present in the sub document. result = exclusion.applyProjection(Document{{"a", Document{{"c", 1}}}}); expectedResult = Document{{"a", Document{{"c", 1}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // There are no fields in sub document. result = exclusion.applyProjection(Document{{"a", Document{}}}); expectedResult = Document{{"a", Document{}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(ExclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedExcludedFieldDoesNotExist) { @@ -224,12 +225,12 @@ TEST(ExclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedExcludedFiel // Should not add the path if it doesn't exist. auto result = exclusion.applyProjection(Document{}); auto expectedResult = Document{}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Should not replace non-documents with documents. result = exclusion.applyProjection(Document{{"sub", "notADocument"}}); expectedResult = Document{{"sub", "notADocument"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(ExclusionProjectionExecutionTest, ShouldApplyDottedExclusionToEachElementInArray) { @@ -252,7 +253,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldApplyDottedExclusionToEachElementIn Value(vector<Value>{Value(1), Value(Document{{"c", 1}})})}; auto result = exclusion.applyProjection(Document{{"a", nestedValues}}); auto expectedResult = Document{{"a", expectedNestedValues}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(ExclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { @@ -263,7 +264,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { auto result = exclusion.applyProjection( Document{{"a", Document{{"b", 1}, {"c", 2}, {"d", 3}, {"e", 4}, {"f", 5}}}}); auto expectedResult = Document{{"a", Document{{"f", 5}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(ExclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc) { @@ -279,7 +280,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc) MutableDocument expectedDoc(Document{{"_id", "ID"}}); expectedDoc.copyMetaDataFrom(inputDoc); - ASSERT_EQ(result, expectedDoc.freeze()); + ASSERT_DOCUMENT_EQ(result, expectedDoc.freeze()); } } // namespace diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection.cpp b/src/mongo/db/pipeline/parsed_inclusion_projection.cpp index 69d600f5551..4842500cd90 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection.cpp +++ b/src/mongo/db/pipeline/parsed_inclusion_projection.cpp @@ -54,6 +54,16 @@ void InclusionNode::optimize() { } } +void InclusionNode::injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) { + for (auto&& expressionIt : _expressions) { + expressionIt.second->injectExpressionContext(expCtx); + } + + for (auto&& childPair : _children) { + childPair.second->injectExpressionContext(expCtx); + } +} + void InclusionNode::serialize(MutableDocument* output, bool explain) const { // Always put "_id" first if it was included (implicitly or explicitly). if (_inclusions.find("_id") != _inclusions.end()) { diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection.h b/src/mongo/db/pipeline/parsed_inclusion_projection.h index 7236d8f85eb..d762d6d7ede 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection.h +++ b/src/mongo/db/pipeline/parsed_inclusion_projection.h @@ -33,6 +33,7 @@ #include <unordered_set> #include "mongo/db/pipeline/expression.h" +#include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/parsed_aggregation_projection.h" #include "mongo/stdx/memory.h" @@ -118,6 +119,8 @@ public: return _pathToNode; } + void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx); + private: // Helpers for the Document versions above. These will apply the transformation recursively to // each element of any arrays, and ensure non-documents are handled appropriately. @@ -202,6 +205,10 @@ public: _root->optimize(); } + void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) final { + _root->injectExpressionContext(expCtx); + } + void addDependencies(DepsTracker* deps) const final { _root->addDependencies(deps); } diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp b/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp index f4c1d6a4a0e..b2833638543 100644 --- a/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp +++ b/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp @@ -37,6 +37,7 @@ #include "mongo/bson/json.h" #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/value.h" #include "mongo/unittest/unittest.h" @@ -129,8 +130,8 @@ TEST(InclusionProjection, ShouldSerializeToEquivalentProjection) { "{_id: true, a: {$add: [\"$a\", {$const: 2}]}, b: {d: true}, x: {y: {$const: 4}}}")); // Should be the same if we're serializing for explain or for internal use. - ASSERT_EQ(expectedSerialization, inclusion.serialize(false)); - ASSERT_EQ(expectedSerialization, inclusion.serialize(true)); + ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(false)); + ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(true)); } TEST(InclusionProjection, ShouldSerializeExplicitExclusionOfId) { @@ -141,8 +142,8 @@ TEST(InclusionProjection, ShouldSerializeExplicitExclusionOfId) { auto expectedSerialization = Document{{"_id", false}, {"a", true}}; // Should be the same if we're serializing for explain or for internal use. - ASSERT_EQ(expectedSerialization, inclusion.serialize(false)); - ASSERT_EQ(expectedSerialization, inclusion.serialize(true)); + ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(false)); + ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(true)); } @@ -155,8 +156,8 @@ TEST(InclusionProjection, ShouldOptimizeTopLevelExpressions) { auto expectedSerialization = Document{{"_id", true}, {"a", Document{{"$const", 3}}}}; // Should be the same if we're serializing for explain or for internal use. - ASSERT_EQ(expectedSerialization, inclusion.serialize(false)); - ASSERT_EQ(expectedSerialization, inclusion.serialize(true)); + ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(false)); + ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(true)); } TEST(InclusionProjection, ShouldOptimizeNestedExpressions) { @@ -169,8 +170,8 @@ TEST(InclusionProjection, ShouldOptimizeNestedExpressions) { Document{{"_id", true}, {"a", Document{{"b", Document{{"$const", 3}}}}}}; // Should be the same if we're serializing for explain or for internal use. - ASSERT_EQ(expectedSerialization, inclusion.serialize(false)); - ASSERT_EQ(expectedSerialization, inclusion.serialize(true)); + ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(false)); + ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(true)); } // @@ -184,22 +185,22 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeTopLevelField) { // More than one field in document. auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}}); auto expectedResult = Document{{"a", 1}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Specified field is the only field in the document. result = inclusion.applyProjection(Document{{"a", 1}}); expectedResult = Document{{"a", 1}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Specified field is not present in the document. result = inclusion.applyProjection(Document{{"c", 1}}); expectedResult = Document{}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // There are no fields in the document. result = inclusion.applyProjection(Document{}); expectedResult = Document{}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldAddComputedTopLevelField) { @@ -207,12 +208,12 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedTopLevelField) { inclusion.parse(BSON("newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{{"newField", "computedVal"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Computed field should replace existing field. result = inclusion.applyProjection(Document{{"newField", "preExisting"}}); expectedResult = Document{{"newField", "computedVal"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldApplyBothInclusionsAndComputedFields) { @@ -220,7 +221,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyBothInclusionsAndComputedField inclusion.parse(BSON("a" << true << "newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{{"a", 1}}); auto expectedResult = Document{{"a", 1}, {"newField", "computedVal"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldIncludeFieldsInOrderOfInputDoc) { @@ -228,7 +229,7 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeFieldsInOrderOfInputDoc) { inclusion.parse(BSON("first" << true << "second" << true << "third" << true)); auto inputDoc = Document{{"second", 1}, {"first", 0}, {"third", 2}}; auto result = inclusion.applyProjection(inputDoc); - ASSERT_EQ(result, inputDoc); + ASSERT_DOCUMENT_EQ(result, inputDoc); } TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsInOrderSpecified) { @@ -237,7 +238,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsInOrderSpecified << wrapInLiteral("SECOND"))); auto result = inclusion.applyProjection(Document{{"first", 0}, {"second", 1}, {"third", 2}}); auto expectedResult = Document{{"firstComputed", "FIRST"}, {"secondComputed", "SECOND"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { @@ -245,12 +246,12 @@ TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) { inclusion.parse(BSON("a" << true)); auto result = inclusion.applyProjection(Document{{"_id", "ID"}, {"a", 1}, {"b", 2}}); auto expectedResult = Document{{"_id", "ID"}, {"a", 1}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Should leave the "_id" in the same place as in the original document. result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"}}); expectedResult = Document{{"a", 1}, {"_id", "ID"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeIdWithComputedFields) { @@ -258,7 +259,7 @@ TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeIdWithComputedFiel inclusion.parse(BSON("newField" << wrapInLiteral("computedVal"))); auto result = inclusion.applyProjection(Document{{"_id", "ID"}, {"a", 1}}); auto expectedResult = Document{{"_id", "ID"}, {"newField", "computedVal"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldIncludeIdIfExplicitlyIncluded) { @@ -266,7 +267,7 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeIdIfExplicitlyIncluded) { inclusion.parse(BSON("a" << true << "_id" << true << "b" << true)); auto result = inclusion.applyProjection(Document{{"_id", "ID"}, {"a", 1}, {"b", 2}, {"c", 3}}); auto expectedResult = Document{{"_id", "ID"}, {"a", 1}, {"b", 2}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { @@ -274,7 +275,7 @@ TEST(InclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) { inclusion.parse(BSON("a" << true << "_id" << false)); auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"}}); auto expectedResult = Document{{"a", 1}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldReplaceIdWithComputedId) { @@ -282,7 +283,7 @@ TEST(InclusionProjectionExecutionTest, ShouldReplaceIdWithComputedId) { inclusion.parse(BSON("_id" << wrapInLiteral("newId"))); auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"}}); auto expectedResult = Document{{"_id", "newId"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } // @@ -296,22 +297,22 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeSimpleDottedFieldFromSubDoc) // More than one field in sub document. auto result = inclusion.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}}); auto expectedResult = Document{{"a", Document{{"b", 1}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Specified field is the only field in the sub document. result = inclusion.applyProjection(Document{{"a", Document{{"b", 1}}}}); expectedResult = Document{{"a", Document{{"b", 1}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Specified field is not present in the sub document. result = inclusion.applyProjection(Document{{"a", Document{{"c", 1}}}}); expectedResult = Document{{"a", Document{}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // There are no fields in sub document. result = inclusion.applyProjection(Document{{"a", Document{}}}); expectedResult = Document{{"a", Document{}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedIncludedFieldDoesNotExist) { @@ -321,12 +322,12 @@ TEST(InclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedIncludedFiel // Should not add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Should not replace the first part of the path if that part exists. result = inclusion.applyProjection(Document{{"sub", "notADocument"}}); expectedResult = Document{}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldApplyDottedInclusionToEachElementInArray) { @@ -350,7 +351,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyDottedInclusionToEachElementIn Value(vector<Value>{Value(), Value(Document{})})}; auto result = inclusion.applyProjection(Document{{"a", nestedValues}}); auto expectedResult = Document{{"a", expectedNestedValues}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToSubDocument) { @@ -360,17 +361,17 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToSubDocument // Other fields exist in sub document, one of which is the specified field. auto result = inclusion.applyProjection(Document{{"sub", Document{{"target", 1}, {"c", 2}}}}); auto expectedResult = Document{{"sub", Document{{"target", "computedVal"}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Specified field is not present in the sub document. result = inclusion.applyProjection(Document{{"sub", Document{{"c", 1}}}}); expectedResult = Document{{"sub", Document{{"target", "computedVal"}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // There are no fields in sub document. result = inclusion.applyProjection(Document{{"sub", Document{}}}); expectedResult = Document{{"sub", Document{{"target", "computedVal"}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldCreateSubDocIfDottedComputedFieldDoesntExist) { @@ -380,11 +381,11 @@ TEST(InclusionProjectionExecutionTest, ShouldCreateSubDocIfDottedComputedFieldDo // Should add the path if it doesn't exist. auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{{"sub", Document{{"target", "computedVal"}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Should replace non-documents with documents. result = inclusion.applyProjection(Document{{"sub", "notADocument"}}); - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldCreateNestedSubDocumentsAllTheWayToComputedField) { @@ -395,11 +396,11 @@ TEST(InclusionProjectionExecutionTest, ShouldCreateNestedSubDocumentsAllTheWayTo auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{{"a", Document{{"b", Document{{"c", Document{{"d", "computedVal"}}}}}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // Should replace non-documents with documents. result = inclusion.applyProjection(Document{{"a", Document{{"b", "other"}}}}); - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToEachElementInArray) { @@ -421,7 +422,7 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToEachElement Value(Document{{"b", "COMPUTED"}})})}; auto result = inclusion.applyProjection(Document{{"a", nestedValues}}); auto expectedResult = Document{{"a", expectedNestedValues}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldApplyInclusionsAndAdditionsToEachElementInArray) { @@ -448,7 +449,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyInclusionsAndAdditionsToEachEl Value(Document{{"inc", 1}, {"comp", "COMPUTED"}})})}; auto result = inclusion.applyProjection(Document{{"a", nestedValues}}); auto expectedResult = Document{{"a", expectedNestedValues}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldAddOrIncludeSubFieldsOfId) { @@ -456,7 +457,7 @@ TEST(InclusionProjectionExecutionTest, ShouldAddOrIncludeSubFieldsOfId) { inclusion.parse(BSON("_id.X" << true << "_id.Z" << wrapInLiteral("NEW"))); auto result = inclusion.applyProjection(Document{{"_id", Document{{"X", 1}, {"Y", 2}}}}); auto expectedResult = Document{{"_id", Document{{"X", 1}, {"Z", "NEW"}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { @@ -479,7 +480,7 @@ TEST(InclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) { {"X", "X"}, {"Y", "Y"}, {"Z", "Z"}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldApplyNestedComputedFieldsInOrderSpecified) { @@ -487,7 +488,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyNestedComputedFieldsInOrderSpe inclusion.parse(BSON("a" << wrapInLiteral("FIRST") << "b.c" << wrapInLiteral("SECOND"))); auto result = inclusion.applyProjection(Document{}); auto expectedResult = Document{{"a", "FIRST"}, {"b", Document{{"c", "SECOND"}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusions) { @@ -495,10 +496,10 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusio inclusion.parse(BSON("b.c" << wrapInLiteral("NEW") << "a" << true)); auto result = inclusion.applyProjection(Document{{"a", 1}}); auto expectedResult = Document{{"a", 1}, {"b", Document{{"c", "NEW"}}}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); result = inclusion.applyProjection(Document{{"a", 1}, {"b", 4}}); - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); // In this case, the field 'b' shows up first and has a nested inclusion or computed field. Even // though it is a computed field, it will appear first in the output document. This is @@ -506,7 +507,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusio // recursively to each sub-document. result = inclusion.applyProjection(Document{{"b", 4}, {"a", 1}}); expectedResult = Document{{"b", Document{{"c", "NEW"}}}, {"a", 1}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } TEST(InclusionProjectionExecutionTest, ComputedFieldReplacingExistingShouldAppearAfterInclusions) { @@ -514,10 +515,10 @@ TEST(InclusionProjectionExecutionTest, ComputedFieldReplacingExistingShouldAppea inclusion.parse(BSON("b" << wrapInLiteral("NEW") << "a" << true)); auto result = inclusion.applyProjection(Document{{"b", 1}, {"a", 1}}); auto expectedResult = Document{{"a", 1}, {"b", "NEW"}}; - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); result = inclusion.applyProjection(Document{{"a", 1}, {"b", 4}}); - ASSERT_EQ(result, expectedResult); + ASSERT_DOCUMENT_EQ(result, expectedResult); } // @@ -537,7 +538,7 @@ TEST(InclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc) MutableDocument expectedDoc(inputDoc); expectedDoc.copyMetaDataFrom(inputDoc); - ASSERT_EQ(result, expectedDoc.freeze()); + ASSERT_DOCUMENT_EQ(result, expectedDoc.freeze()); } } // namespace diff --git a/src/mongo/db/pipeline/pipeline.cpp b/src/mongo/db/pipeline/pipeline.cpp index d954991ce22..45b7ee58e51 100644 --- a/src/mongo/db/pipeline/pipeline.cpp +++ b/src/mongo/db/pipeline/pipeline.cpp @@ -221,11 +221,11 @@ void Pipeline::reattachToOperationContext(OperationContext* opCtx) { } } -void Pipeline::setCollator(std::unique_ptr<CollatorInterface> collator) { - pCtx->collator = std::move(collator); - - // TODO SERVER-23349: If the pipeline has any DocumentSourceMatch sources, ask them to - // re-parse their predicates. +void Pipeline::injectExpressionContext(const intrusive_ptr<ExpressionContext>& expCtx) { + pCtx = expCtx; + for (auto&& stage : _sources) { + stage->injectExpressionContext(pCtx); + } } intrusive_ptr<Pipeline> Pipeline::splitForSharded() { @@ -461,4 +461,5 @@ DepsTracker Pipeline::getDependencies(DepsTracker::MetadataAvailable metadataAva return deps; } + } // namespace mongo diff --git a/src/mongo/db/pipeline/pipeline.h b/src/mongo/db/pipeline/pipeline.h index d577b5fc151..531411fae60 100644 --- a/src/mongo/db/pipeline/pipeline.h +++ b/src/mongo/db/pipeline/pipeline.h @@ -107,15 +107,6 @@ public: void reattachToOperationContext(OperationContext* opCtx); /** - * Sets the collator on this Pipeline. parseCommand() will return a Pipeline with its collator - * already set from the parsed request (if applicable), but this setter method can be used to - * later override the Pipeline's original collator. - * - * The Pipeline's collator can be retrieved with getContext()->collator. - */ - void setCollator(std::unique_ptr<CollatorInterface> collator); - - /** Split the current Pipeline into a Pipeline for each shard, and a Pipeline that combines the results within mongos. @@ -142,6 +133,12 @@ public: void optimizePipeline(); /** + * Propagates a reference to the ExpressionContext to all of the pipeline's contained stages and + * expressions. + */ + void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx); + + /** * Returns any other collections involved in the pipeline in addition to the collection the * aggregation is run on. */ diff --git a/src/mongo/db/pipeline/pipeline_d.cpp b/src/mongo/db/pipeline/pipeline_d.cpp index b8e7c1bcd1a..3ebc683676f 100644 --- a/src/mongo/db/pipeline/pipeline_d.cpp +++ b/src/mongo/db/pipeline/pipeline_d.cpp @@ -226,8 +226,8 @@ StatusWith<std::unique_ptr<PlanExecutor>> attemptToGetExecutor( // // If pipeline has a null collator (representing the "simple" collation), we simply set the // collation option to the original user BSON. - qr->setCollation(pExpCtx->collator ? pExpCtx->collator->getSpec().toBSON() - : pExpCtx->collation); + qr->setCollation(pExpCtx->getCollator() ? pExpCtx->getCollator()->getSpec().toBSON() + : pExpCtx->collation); const ExtensionsCallbackReal extensionsCallback(pExpCtx->opCtx, &pExpCtx->ns); diff --git a/src/mongo/db/pipeline/pipeline_test.cpp b/src/mongo/db/pipeline/pipeline_test.cpp index 4cdf064ace6..864680bcbb9 100644 --- a/src/mongo/db/pipeline/pipeline_test.cpp +++ b/src/mongo/db/pipeline/pipeline_test.cpp @@ -37,6 +37,7 @@ #include "mongo/db/pipeline/dependencies.h" #include "mongo/db/pipeline/document.h" #include "mongo/db/pipeline/document_source.h" +#include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/db/pipeline/expression_context.h" #include "mongo/db/pipeline/field_path.h" #include "mongo/db/pipeline/pipeline.h" @@ -94,8 +95,9 @@ public: auto outputPipe = uassertStatusOK(Pipeline::parse(request.getPipeline(), ctx)); outputPipe->optimizePipeline(); - ASSERT_EQUALS(Value(outputPipe->writeExplainOps()), Value(outputPipeExpected["pipeline"])); - ASSERT_EQUALS(Value(outputPipe->serialize()), Value(serializePipeExpected["pipeline"])); + ASSERT_VALUE_EQ(Value(outputPipe->writeExplainOps()), + Value(outputPipeExpected["pipeline"])); + ASSERT_VALUE_EQ(Value(outputPipe->serialize()), Value(serializePipeExpected["pipeline"])); } virtual ~Base() {} @@ -744,8 +746,8 @@ public: shardPipe = mergePipe->splitForSharded(); ASSERT(shardPipe != nullptr); - ASSERT_EQUALS(Value(shardPipe->writeExplainOps()), Value(shardPipeExpected["pipeline"])); - ASSERT_EQUALS(Value(mergePipe->writeExplainOps()), Value(mergePipeExpected["pipeline"])); + ASSERT_VALUE_EQ(Value(shardPipe->writeExplainOps()), Value(shardPipeExpected["pipeline"])); + ASSERT_VALUE_EQ(Value(mergePipe->writeExplainOps()), Value(mergePipeExpected["pipeline"])); } virtual ~Base() {} @@ -1092,9 +1094,9 @@ TEST(PipelineInitialSource, ParseCollation) { ASSERT_OK(request.getStatus()); intrusive_ptr<ExpressionContext> ctx = new ExpressionContext(opCtx.get(), request.getValue()); - ASSERT(ctx->collator.get()); + ASSERT(ctx->getCollator()); CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); - ASSERT_TRUE(CollatorInterface::collatorsMatch(ctx->collator.get(), &collator)); + ASSERT_TRUE(CollatorInterface::collatorsMatch(ctx->getCollator(), &collator)); } namespace Dependencies { diff --git a/src/mongo/db/pipeline/tee_buffer_test.cpp b/src/mongo/db/pipeline/tee_buffer_test.cpp index 9fb2b56f1fc..1f218ea1d89 100644 --- a/src/mongo/db/pipeline/tee_buffer_test.cpp +++ b/src/mongo/db/pipeline/tee_buffer_test.cpp @@ -31,6 +31,7 @@ #include "mongo/db/pipeline/tee_buffer.h" #include "mongo/db/pipeline/document.h" +#include "mongo/db/pipeline/document_value_test_util.h" #include "mongo/unittest/unittest.h" #include "mongo/util/assert_util.h" @@ -63,7 +64,7 @@ TEST(TeeBufferTest, ShouldProvideIteratorOverSingleDocument) { // Should be able to establish an iterator and get the document back. auto it = teeBuffer->begin(); ASSERT(it != teeBuffer->end()); - ASSERT_EQ(*it, inputDoc); + ASSERT_DOCUMENT_EQ(*it, inputDoc); ++it; ASSERT(it == teeBuffer->end()); } @@ -77,10 +78,10 @@ TEST(TeeBufferTest, ShouldProvideIteratorOverTwoDocuments) { auto it = teeBuffer->begin(); ASSERT(it != teeBuffer->end()); - ASSERT_EQ(*it, inputDocs.front()); + ASSERT_DOCUMENT_EQ(*it, inputDocs.front()); ++it; ASSERT(it != teeBuffer->end()); - ASSERT_EQ(*it, inputDocs.back()); + ASSERT_DOCUMENT_EQ(*it, inputDocs.back()); ++it; ASSERT(it == teeBuffer->end()); } @@ -97,18 +98,18 @@ TEST(TeeBufferTest, ShouldBeAbleToProvideMultipleIteratorsOverTheSameInputs) { // Advance both once. ASSERT(firstIt != teeBuffer->end()); - ASSERT_EQ(*firstIt, inputDocs.front()); + ASSERT_DOCUMENT_EQ(*firstIt, inputDocs.front()); ++firstIt; ASSERT(secondIt != teeBuffer->end()); - ASSERT_EQ(*secondIt, inputDocs.front()); + ASSERT_DOCUMENT_EQ(*secondIt, inputDocs.front()); ++secondIt; // Advance them both again. ASSERT(firstIt != teeBuffer->end()); - ASSERT_EQ(*firstIt, inputDocs.back()); + ASSERT_DOCUMENT_EQ(*firstIt, inputDocs.back()); ++firstIt; ASSERT(secondIt != teeBuffer->end()); - ASSERT_EQ(*secondIt, inputDocs.back()); + ASSERT_DOCUMENT_EQ(*secondIt, inputDocs.back()); ++secondIt; // Assert they've both reached the end. diff --git a/src/mongo/db/pipeline/value.cpp b/src/mongo/db/pipeline/value.cpp index 4334af8260c..2d76ebe2bb6 100644 --- a/src/mongo/db/pipeline/value.cpp +++ b/src/mongo/db/pipeline/value.cpp @@ -638,7 +638,9 @@ inline static int cmp(const T& left, const T& right) { } } -int Value::compare(const Value& rL, const Value& rR) { +int Value::compare(const Value& rL, + const Value& rR, + const StringData::ComparatorInterface* stringComparator) { // Note, this function needs to behave identically to BSON's compareElementValues(). // Additionally, any changes here must be replicated in hash_combine(). BSONType lType = rL.getType(); @@ -739,13 +741,20 @@ int Value::compare(const Value& rL, const Value& rR) { case jstOID: return memcmp(rL._storage.oid, rR._storage.oid, OID::kOIDSize); + case String: { + if (!stringComparator) { + return rL.getStringData().compare(rR.getStringData()); + } + + return stringComparator->compare(rL.getStringData(), rR.getStringData()); + } + case Code: case Symbol: - case String: return rL.getStringData().compare(rR.getStringData()); case Object: - return Document::compare(rL.getDocument(), rR.getDocument()); + return Document::compare(rL.getDocument(), rR.getDocument(), stringComparator); case Array: { const vector<Value>& lArr = rL.getArray(); @@ -754,7 +763,7 @@ int Value::compare(const Value& rL, const Value& rR) { const size_t elems = std::min(lArr.size(), rArr.size()); for (size_t i = 0; i < elems; i++) { // compare the two corresponding elements - ret = Value::compare(lArr[i], rArr[i]); + ret = Value::compare(lArr[i], rArr[i], stringComparator); if (ret) return ret; // values are unequal } diff --git a/src/mongo/db/pipeline/value.h b/src/mongo/db/pipeline/value.h index bda0723d471..e151e58c813 100644 --- a/src/mongo/db/pipeline/value.h +++ b/src/mongo/db/pipeline/value.h @@ -28,7 +28,7 @@ #pragma once - +#include "mongo/base/string_data.h" #include "mongo/db/pipeline/value_internal.h" #include "mongo/platform/unordered_set.h" @@ -57,6 +57,28 @@ class BSONElement; */ class Value { public: + /** + * Operator overloads for relops return a DeferredComparison which can subsequently be evaluated + * by a ValueComparator. + */ + struct DeferredComparison { + enum class Type { + kLT, + kLTE, + kEQ, + kGT, + kGTE, + kNE, + }; + + DeferredComparison(Type type, const Value& lhs, const Value& rhs) + : type(type), lhs(lhs), rhs(rhs) {} + + Type type; + const Value& lhs; + const Value& rhs; + }; + /** Construct a Value * * All types not listed will be rejected rather than converted (see private for why) @@ -200,28 +222,54 @@ public: time_t coerceToTimeT() const; tm coerceToTm() const; // broken-out time struct (see man gmtime) + // + // Comparison API. + // + // Value instances can be compared either using Value::compare() or via operator overloads. + // Most callers should prefer operator overloads. Note that the operator overloads return a + // DeferredComparison, which must be subsequently evaluated by a ValueComparator. See + // value_comparator.h for details. + // - /** Compare two Values. + /** + * Compare two Values. Most Values should prefer to use ValueComparator instead. See + * value_comparator.h for details. + * + * Pass a non-null StringData::ComparatorInterface if special string comparison semantics are + * required. If the comparator is null, then a simple binary compare is used for strings. This + * comparator is only used for string *values*; field names are always compared using simple + * binary compare. + * * @returns an integer less than zero, zero, or an integer greater than * zero, depending on whether lhs < rhs, lhs == rhs, or lhs > rhs * Warning: may return values other than -1, 0, or 1 */ - static int compare(const Value& lhs, const Value& rhs); - - friend bool operator==(const Value& v1, const Value& v2) { - if (v1._storage.identical(v2._storage)) { - // Simple case - return true; - } - return (Value::compare(v1, v2) == 0); + static int compare(const Value& lhs, + const Value& rhs, + const StringData::ComparatorInterface* stringComparator = nullptr); + + friend DeferredComparison operator==(const Value& lhs, const Value& rhs) { + return DeferredComparison(DeferredComparison::Type::kEQ, lhs, rhs); + } + + friend DeferredComparison operator!=(const Value& lhs, const Value& rhs) { + return DeferredComparison(DeferredComparison::Type::kNE, lhs, rhs); + } + + friend DeferredComparison operator<(const Value& lhs, const Value& rhs) { + return DeferredComparison(DeferredComparison::Type::kLT, lhs, rhs); + } + + friend DeferredComparison operator<=(const Value& lhs, const Value& rhs) { + return DeferredComparison(DeferredComparison::Type::kLTE, lhs, rhs); } - friend bool operator!=(const Value& v1, const Value& v2) { - return !(v1 == v2); + friend DeferredComparison operator>(const Value& lhs, const Value& rhs) { + return DeferredComparison(DeferredComparison::Type::kGT, lhs, rhs); } - friend bool operator<(const Value& lhs, const Value& rhs) { - return (Value::compare(lhs, rhs) < 0); + friend DeferredComparison operator>=(const Value& lhs, const Value& rhs) { + return DeferredComparison(DeferredComparison::Type::kGTE, lhs, rhs); } /// This is for debugging, logging, etc. See getString() for how to extract a string. @@ -289,8 +337,6 @@ private: }; static_assert(sizeof(Value) == 16, "sizeof(Value) == 16"); -typedef unordered_set<Value, Value::Hash> ValueSet; - inline void swap(mongo::Value& lhs, mongo::Value& rhs) { lhs.swap(rhs); } diff --git a/src/mongo/db/pipeline/value_comparator.cpp b/src/mongo/db/pipeline/value_comparator.cpp new file mode 100644 index 00000000000..a20a9460b6d --- /dev/null +++ b/src/mongo/db/pipeline/value_comparator.cpp @@ -0,0 +1,57 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/pipeline/value_comparator.h" + +#include "mongo/util/assert_util.h" + +namespace mongo { + +bool ValueComparator::evaluate(Value::DeferredComparison deferredComparison) const { + int cmp = Value::compare(deferredComparison.lhs, deferredComparison.rhs, _stringComparator); + switch (deferredComparison.type) { + case Value::DeferredComparison::Type::kLT: + return cmp < 0; + case Value::DeferredComparison::Type::kLTE: + return cmp <= 0; + case Value::DeferredComparison::Type::kEQ: + return cmp == 0; + case Value::DeferredComparison::Type::kGTE: + return cmp >= 0; + case Value::DeferredComparison::Type::kGT: + return cmp > 0; + case Value::DeferredComparison::Type::kNE: + return cmp != 0; + } + + MONGO_UNREACHABLE; +} + +} // namespace mongo diff --git a/src/mongo/db/pipeline/value_comparator.h b/src/mongo/db/pipeline/value_comparator.h new file mode 100644 index 00000000000..cbf423ccf92 --- /dev/null +++ b/src/mongo/db/pipeline/value_comparator.h @@ -0,0 +1,177 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include <map> +#include <set> +#include <unordered_map> +#include <unordered_set> + +#include "mongo/base/string_data.h" +#include "mongo/db/pipeline/value.h" + +namespace mongo { + +class ValueComparator { +public: + /** + * Functor compatible for use with unordered STL containers. + * + * TODO SERVER-23349: Remove the no-arguments constructor. + */ + class EqualTo { + public: + EqualTo() = default; + + explicit EqualTo(const ValueComparator* comparator) : _comparator(comparator) {} + + bool operator()(const Value& lhs, const Value& rhs) const { + return _comparator ? _comparator->compare(lhs, rhs) == 0 + : ValueComparator().compare(lhs, rhs) == 0; + } + + private: + const ValueComparator* _comparator = nullptr; + }; + + /** + * Functor compatible for use with ordered STL containers. + */ + class LessThan { + public: + explicit LessThan(const ValueComparator* comparator) : _comparator(comparator) {} + + bool operator()(const Value& lhs, const Value& rhs) const { + return _comparator->compare(lhs, rhs) < 0; + } + + private: + const ValueComparator* _comparator; + }; + + /** + * Constructs a value comparator with simple comparison semantics. + */ + ValueComparator() = default; + + /** + * Constructs a value comparator with special string comparison semantics. + */ + ValueComparator(const StringData::ComparatorInterface* stringComparator) + : _stringComparator(stringComparator) {} + + /** + * Returns <0 if 'lhs' is less than 'rhs', 0 if 'lhs' is equal to 'rhs', and >0 if 'lhs' is + * greater than 'rhs'. + */ + int compare(const Value& lhs, const Value& rhs) const { + return Value::compare(lhs, rhs, _stringComparator); + } + + /** + * Evaluates a deferred comparison object that was generated by invoking one of the comparison + * operators on the Value class. + */ + bool evaluate(Value::DeferredComparison deferredComparison) const; + + /** + * Returns a function object which computes whether one Value is equal to another under this + * comparator. This comparator must outlive the returned function object. + */ + EqualTo getEqualTo() const { + return EqualTo(this); + } + + /** + * Returns a function object which computes whether one Value is less than another under this + * comparator. This comparator must outlive the returned function object. + */ + LessThan getLessThan() const { + return LessThan(this); + } + + /** + * Construct an empty ordered set of Value whose ordering and equivalence classes are given by + * this comparator. This comparator must outlive the returned set. + */ + std::set<Value, LessThan> makeOrderedValueSet() const { + return std::set<Value, LessThan>(LessThan(this)); + } + + /** + * Construct an empty unordered set of Value whose equivalence classes are given by this + * comparator. This comparator must outlive the returned set. + * + * TODO SERVER-23990: Make Value::Hash use the collation. The returned set won't be correctly + * collation-aware until this work is done. + */ + std::unordered_set<Value, Value::Hash, EqualTo> makeUnorderedValueSet() const { + return std::unordered_set<Value, Value::Hash, EqualTo>(0, Value::Hash(), EqualTo(this)); + } + + /** + * Construct an empty ordered map from Value to type T whose ordering and equivalence classes + * are given by this comparator. This comparator must outlive the returned set. + */ + template <typename T> + std::map<Value, T, LessThan> makeOrderedValueMap() const { + return std::map<Value, T, LessThan>(LessThan(this)); + } + + /** + * Construct an empty unordered map from Value to type T whose equivalence classes are given by + * this comparator. This comparator must outlive the returned set. + * + * TODO SERVER-23990: Make Value::Hash use the collation. The returned map won't be correctly + * collation-aware until this work is done. + */ + template <typename T> + std::unordered_map<Value, T, Value::Hash, EqualTo> makeUnorderedValueMap() const { + return std::unordered_map<Value, T, Value::Hash, EqualTo>(0, Value::Hash(), EqualTo(this)); + } + +private: + const StringData::ComparatorInterface* _stringComparator = nullptr; +}; + +// +// Type aliases for sets and maps of Value for use by clients of the Document/Value library. +// + +using ValueSet = std::set<Value, ValueComparator::LessThan>; + +using ValueUnorderedSet = std::unordered_set<Value, Value::Hash, ValueComparator::EqualTo>; + +template <typename T> +using ValueMap = std::map<Value, T, ValueComparator::LessThan>; + +template <typename T> +using ValueUnorderedMap = std::unordered_map<Value, T, Value::Hash, ValueComparator::EqualTo>; + +} // namespace mongo diff --git a/src/mongo/db/pipeline/value_comparator_test.cpp b/src/mongo/db/pipeline/value_comparator_test.cpp new file mode 100644 index 00000000000..05386061680 --- /dev/null +++ b/src/mongo/db/pipeline/value_comparator_test.cpp @@ -0,0 +1,223 @@ +/** + * Copyright (C) 2016 MongoDB Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License, version 3, + * as published by the Free Software Foundation. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the GNU Affero General Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/pipeline/value_comparator.h" + +#include "mongo/db/pipeline/document_value_test_util.h" +#include "mongo/db/query/collation/collator_interface_mock.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { +namespace { + +TEST(ValueComparatorTest, EqualToEvaluatesCorrectly) { + Value val1("bar"); + Value val2("bar"); + Value val3("baz"); + ASSERT_TRUE(ValueComparator().evaluate(val1 == val2)); + ASSERT_FALSE(ValueComparator().evaluate(val1 == val3)); +} + +TEST(ValueComparatorTest, EqualToEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual); + Value val1("abc"); + Value val2("def"); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 == val2)); +} + +TEST(ValueComparatorTest, EqualToFunctorEvaluatesCorrectly) { + ValueComparator valueComparator; + auto equalFunc = valueComparator.getEqualTo(); + Value val1("bar"); + Value val2("bar"); + Value val3("baz"); + ASSERT_TRUE(equalFunc(val1, val2)); + ASSERT_FALSE(equalFunc(val1, val3)); +} + +TEST(ValueComparatorTest, EqualToFunctorEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual); + ValueComparator valueComparator(&collator); + auto equalFunc = valueComparator.getEqualTo(); + Value val1("abc"); + Value val2("def"); + ASSERT_TRUE(equalFunc(val1, val2)); +} + +TEST(ValueComparatorTest, NotEqualEvaluatesCorrectly) { + Value val1("bar"); + Value val2("bar"); + Value val3("baz"); + ASSERT_FALSE(ValueComparator().evaluate(val1 != val2)); + ASSERT_TRUE(ValueComparator().evaluate(val1 != val3)); +} + +TEST(ValueComparatorTest, NotEqualEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual); + Value val1("abc"); + Value val2("def"); + ASSERT_FALSE(ValueComparator(&collator).evaluate(val1 != val2)); +} + +TEST(ValueComparatorTest, LessThanEvaluatesCorrectly) { + Value val1("a"); + Value val2("b"); + ASSERT_TRUE(ValueComparator().evaluate(val1 < val2)); + ASSERT_FALSE(ValueComparator().evaluate(val2 < val1)); +} + +TEST(ValueComparatorTest, LessThanEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); + Value val1("za"); + Value val2("yb"); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 < val2)); + ASSERT_FALSE(ValueComparator(&collator).evaluate(val2 < val1)); +} + +TEST(ValueComparatorTest, LessThanFunctorEvaluatesCorrectly) { + ValueComparator valueComparator; + auto lessThanFunc = valueComparator.getLessThan(); + Value val1("a"); + Value val2("b"); + ASSERT_TRUE(lessThanFunc(val1, val2)); + ASSERT_FALSE(lessThanFunc(val2, val1)); +} + +TEST(ValueComparatorTest, LessThanFunctorEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); + ValueComparator valueComparator(&collator); + auto lessThanFunc = valueComparator.getLessThan(); + Value val1("za"); + Value val2("yb"); + ASSERT_TRUE(lessThanFunc(val1, val2)); + ASSERT_FALSE(lessThanFunc(val2, val1)); +} + +TEST(ValueComparatorTest, LessThanOrEqualEvaluatesCorrectly) { + Value val1("a"); + Value val2("a"); + Value val3("b"); + ASSERT_TRUE(ValueComparator().evaluate(val1 <= val2)); + ASSERT_TRUE(ValueComparator().evaluate(val2 <= val1)); + ASSERT_TRUE(ValueComparator().evaluate(val1 <= val3)); + ASSERT_FALSE(ValueComparator().evaluate(val3 <= val1)); +} + +TEST(ValueComparatorTest, LessThanOrEqualEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); + Value val1("za"); + Value val2("za"); + Value val3("yb"); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 <= val2)); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val2 <= val1)); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 <= val3)); + ASSERT_FALSE(ValueComparator(&collator).evaluate(val3 <= val1)); +} + +TEST(ValueComparatorTest, GreaterThanEvaluatesCorrectly) { + Value val1("b"); + Value val2("a"); + ASSERT_TRUE(ValueComparator().evaluate(val1 > val2)); + ASSERT_FALSE(ValueComparator().evaluate(val2 > val1)); +} + +TEST(ValueComparatorTest, GreaterThanEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); + Value val1("yb"); + Value val2("za"); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 > val2)); + ASSERT_FALSE(ValueComparator(&collator).evaluate(val2 > val1)); +} + +TEST(ValueComparatorTest, GreaterThanOrEqualEvaluatesCorrectly) { + Value val1("b"); + Value val2("b"); + Value val3("a"); + ASSERT_TRUE(ValueComparator().evaluate(val1 >= val2)); + ASSERT_TRUE(ValueComparator().evaluate(val2 >= val1)); + ASSERT_TRUE(ValueComparator().evaluate(val1 >= val3)); + ASSERT_FALSE(ValueComparator().evaluate(val3 >= val1)); +} + +TEST(ValueComparatorTest, GreaterThanOrEqualEvaluatesCorrectlyWithNonSimpleCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); + Value val1("yb"); + Value val2("yb"); + Value val3("za"); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 >= val2)); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val2 >= val1)); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 >= val3)); + ASSERT_FALSE(ValueComparator(&collator).evaluate(val3 >= val1)); +} + +TEST(ValueComparatorTest, OrderedValueSetRespectsTheComparator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString); + ValueComparator valueComparator(&collator); + ValueSet set = valueComparator.makeOrderedValueSet(); + set.insert(Value("yb")); + set.insert(Value("za")); + + auto it = set.begin(); + ASSERT_VALUE_EQ(*it, Value("za")); + ++it; + ASSERT_VALUE_EQ(*it, Value("yb")); + ++it; + ASSERT(it == set.end()); +} + +TEST(ValueComparatorTest, EqualToEvaluatesCorrectlyWithNumbers) { + Value val1(88); + Value val2(88); + Value val3(99); + ASSERT_TRUE(ValueComparator().evaluate(val1 == val2)); + ASSERT_FALSE(ValueComparator().evaluate(val1 == val3)); +} + +TEST(ValueComparatorTest, NestedObjectEqualityRespectsCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual); + Value val1(Document{{"foo", "abc"}}); + Value val2(Document{{"foo", "def"}}); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 == val2)); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val2 == val1)); +} + +TEST(ValueComparatorTest, NestedArrayEqualityRespectsCollator) { + CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual); + Value val1(std::vector<Value>{Value("a"), Value("b")}); + Value val2(std::vector<Value>{Value("c"), Value("d")}); + Value val3(std::vector<Value>{Value("c"), Value("d"), Value("e")}); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 == val2)); + ASSERT_TRUE(ValueComparator(&collator).evaluate(val2 == val1)); + ASSERT_FALSE(ValueComparator(&collator).evaluate(val1 == val3)); + ASSERT_FALSE(ValueComparator(&collator).evaluate(val3 == val1)); +} + +} // namespace +} // namespace mongo |