summaryrefslogtreecommitdiff
path: root/src/mongo
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo')
-rw-r--r--src/mongo/db/commands/pipeline_command.cpp35
-rw-r--r--src/mongo/db/pipeline/SConscript41
-rw-r--r--src/mongo/db/pipeline/accumulator.h34
-rw-r--r--src/mongo/db/pipeline/accumulator_add_to_set.cpp13
-rw-r--r--src/mongo/db/pipeline/accumulator_test.cpp7
-rw-r--r--src/mongo/db/pipeline/aggregation_request_test.cpp7
-rw-r--r--src/mongo/db/pipeline/document.cpp6
-rw-r--r--src/mongo/db/pipeline/document.h75
-rw-r--r--src/mongo/db/pipeline/document_comparator.cpp57
-rw-r--r--src/mongo/db/pipeline/document_comparator.h59
-rw-r--r--src/mongo/db/pipeline/document_comparator_test.cpp169
-rw-r--r--src/mongo/db/pipeline/document_source.h76
-rw-r--r--src/mongo/db/pipeline/document_source_bucket.cpp16
-rw-r--r--src/mongo/db/pipeline/document_source_cursor.cpp10
-rw-r--r--src/mongo/db/pipeline/document_source_facet.cpp6
-rw-r--r--src/mongo/db/pipeline/document_source_facet.h5
-rw-r--r--src/mongo/db/pipeline/document_source_facet_test.cpp19
-rw-r--r--src/mongo/db/pipeline/document_source_geo_near.cpp4
-rw-r--r--src/mongo/db/pipeline/document_source_graph_lookup.cpp53
-rw-r--r--src/mongo/db/pipeline/document_source_group.cpp56
-rw-r--r--src/mongo/db/pipeline/document_source_limit.cpp4
-rw-r--r--src/mongo/db/pipeline/document_source_match.cpp11
-rw-r--r--src/mongo/db/pipeline/document_source_merge_cursors.cpp5
-rw-r--r--src/mongo/db/pipeline/document_source_project.cpp5
-rw-r--r--src/mongo/db/pipeline/document_source_redact.cpp11
-rw-r--r--src/mongo/db/pipeline/document_source_sample.cpp4
-rw-r--r--src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp13
-rw-r--r--src/mongo/db/pipeline/document_source_skip.cpp1
-rw-r--r--src/mongo/db/pipeline/document_source_sort.cpp1
-rw-r--r--src/mongo/db/pipeline/document_source_test.cpp78
-rw-r--r--src/mongo/db/pipeline/document_source_unwind.cpp12
-rw-r--r--src/mongo/db/pipeline/document_value_test.cpp45
-rw-r--r--src/mongo/db/pipeline/document_value_test_util.cpp67
-rw-r--r--src/mongo/db/pipeline/document_value_test_util.h88
-rw-r--r--src/mongo/db/pipeline/document_value_test_util_self_test.cpp92
-rw-r--r--src/mongo/db/pipeline/expression.cpp129
-rw-r--r--src/mongo/db/pipeline/expression.h52
-rw-r--r--src/mongo/db/pipeline/expression_context.cpp11
-rw-r--r--src/mongo/db/pipeline/expression_context.h40
-rw-r--r--src/mongo/db/pipeline/expression_test.cpp164
-rw-r--r--src/mongo/db/pipeline/lookup_set_cache.h13
-rw-r--r--src/mongo/db/pipeline/parsed_aggregation_projection.h9
-rw-r--r--src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp47
-rw-r--r--src/mongo/db/pipeline/parsed_inclusion_projection.cpp10
-rw-r--r--src/mongo/db/pipeline/parsed_inclusion_projection.h7
-rw-r--r--src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp97
-rw-r--r--src/mongo/db/pipeline/pipeline.cpp11
-rw-r--r--src/mongo/db/pipeline/pipeline.h15
-rw-r--r--src/mongo/db/pipeline/pipeline_d.cpp4
-rw-r--r--src/mongo/db/pipeline/pipeline_test.cpp14
-rw-r--r--src/mongo/db/pipeline/tee_buffer_test.cpp15
-rw-r--r--src/mongo/db/pipeline/value.cpp17
-rw-r--r--src/mongo/db/pipeline/value.h78
-rw-r--r--src/mongo/db/pipeline/value_comparator.cpp57
-rw-r--r--src/mongo/db/pipeline/value_comparator.h177
-rw-r--r--src/mongo/db/pipeline/value_comparator_test.cpp223
-rw-r--r--src/mongo/dbtests/SConscript3
-rw-r--r--src/mongo/dbtests/documentsourcetests.cpp7
58 files changed, 1966 insertions, 419 deletions
diff --git a/src/mongo/db/commands/pipeline_command.cpp b/src/mongo/db/commands/pipeline_command.cpp
index ae0a5f80f1b..97aff68c473 100644
--- a/src/mongo/db/commands/pipeline_command.cpp
+++ b/src/mongo/db/commands/pipeline_command.cpp
@@ -188,6 +188,7 @@ boost::intrusive_ptr<Pipeline> reparsePipeline(
fassertFailedWithStatusNoTrace(40175, reparsedPipeline.getStatus());
}
+ reparsedPipeline.getValue()->injectExpressionContext(expCtx);
reparsedPipeline.getValue()->optimizePipeline();
return reparsedPipeline.getValue();
}
@@ -262,15 +263,6 @@ public:
return appendCommandStatus(result, statusWithPipeline.getStatus());
}
auto pipeline = std::move(statusWithPipeline.getValue());
- pipeline->optimizePipeline();
-
- if (kDebugBuild && !expCtx->isExplain && !expCtx->inShard) {
- // Make sure all operations round-trip through Pipeline::serialize() correctly by
- // re-parsing every command in debug builds. This is important because sharded
- // aggregations rely on this ability. Skipping when inShard because this has already
- // been through the transformation (and this un-sets expCtx->inShard).
- pipeline = reparsePipeline(pipeline, request.getValue(), expCtx);
- }
unique_ptr<ClientCursorPin> pin; // either this OR the exec will be non-null
unique_ptr<PlanExecutor> exec;
@@ -307,11 +299,30 @@ public:
return this->run(txn, db, viewCmd, options, errmsg, result);
}
- // If the pipeline does not have a user-specified collation, set it from the
- // collection default.
+ // If the pipeline does not have a user-specified collation, set it from the collection
+ // default.
if (request.getValue().getCollation().isEmpty() && collection &&
collection->getDefaultCollator()) {
- pipeline->setCollator(collection->getDefaultCollator()->clone());
+ invariant(!expCtx->getCollator());
+ expCtx->setCollator(collection->getDefaultCollator()->clone());
+ }
+
+ // Propagate the ExpressionContext throughout all of the pipeline's stages and
+ // expressions.
+ pipeline->injectExpressionContext(expCtx);
+
+ // The pipeline must be optimized after the correct collator has been set on it (by
+ // injecting the ExpressionContext containing the collator). This is necessary because
+ // optimization may make string comparisons, e.g. optimizing {$eq: [<str1>, <str2>]} to
+ // a constant.
+ pipeline->optimizePipeline();
+
+ if (kDebugBuild && !expCtx->isExplain && !expCtx->inShard) {
+ // Make sure all operations round-trip through Pipeline::serialize() correctly by
+ // re-parsing every command in debug builds. This is important because sharded
+ // aggregations rely on this ability. Skipping when inShard because this has
+ // already been through the transformation (and this un-sets expCtx->inShard).
+ pipeline = reparsePipeline(pipeline, request.getValue(), expCtx);
}
// This does mongod-specific stuff like creating the input PlanExecutor and adding
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript
index 28d8aa0361f..f1a36c2bedd 100644
--- a/src/mongo/db/pipeline/SConscript
+++ b/src/mongo/db/pipeline/SConscript
@@ -37,7 +37,9 @@ env.Library(
target='document_value',
source=[
'document.cpp',
+ 'document_comparator.cpp',
'value.cpp',
+ 'value_comparator.cpp',
],
LIBDEPS=[
'field_path',
@@ -47,11 +49,29 @@ env.Library(
]
)
+env.Library(
+ target='document_value_test_util',
+ source=[
+ 'document_value_test_util.cpp',
+ ],
+ LIBDEPS=[
+ '$BUILD_DIR/mongo/unittest/unittest',
+ 'document_value',
+ ],
+)
+
env.CppUnitTest(
target='document_value_test',
- source='document_value_test.cpp',
+ source=[
+ 'document_comparator_test.cpp',
+ 'document_value_test.cpp',
+ 'document_value_test_util_self_test.cpp',
+ 'value_comparator_test.cpp',
+ ],
LIBDEPS=[
+ '$BUILD_DIR/mongo/db/query/collation/collator_interface_mock',
'document_value',
+ 'document_value_test_util',
],
)
@@ -75,7 +95,8 @@ env.CppUnitTest(
target='aggregation_request_test',
source='aggregation_request_test.cpp',
LIBDEPS=[
- 'aggregation_request',
+ 'aggregation_request',
+ 'document_value_test_util',
],
)
@@ -97,6 +118,7 @@ env.CppUnitTest(
source='document_source_test.cpp',
LIBDEPS=[
'document_source',
+ 'document_value_test_util',
'$BUILD_DIR/mongo/db/service_context',
'$BUILD_DIR/mongo/util/clock_source_mock',
'$BUILD_DIR/mongo/executor/thread_pool_task_executor',
@@ -124,6 +146,7 @@ env.Library(
LIBDEPS=[
'dependencies',
'document_value',
+ 'expression_context',
'$BUILD_DIR/mongo/util/summation',
]
)
@@ -182,7 +205,6 @@ docSourceEnv.Library(
'dependencies',
'document_value',
'expression',
- 'expression_context',
'parsed_aggregation_projection',
'$BUILD_DIR/mongo/client/clientdriver',
'$BUILD_DIR/mongo/db/bson/dotted_path_support',
@@ -238,10 +260,11 @@ env.CppUnitTest(
target='document_source_facet_test',
source='document_source_facet_test.cpp',
LIBDEPS=[
- 'document_source_facet',
'$BUILD_DIR/mongo/db/auth/authorization_manager_mock_init',
'$BUILD_DIR/mongo/db/query/query_test_service_context',
'$BUILD_DIR/mongo/db/service_context_noop_init',
+ 'document_source_facet',
+ 'document_value_test_util',
],
)
@@ -249,9 +272,10 @@ env.CppUnitTest(
target='tee_buffer_test',
source='tee_buffer_test.cpp',
LIBDEPS=[
- 'document_source_facet',
'$BUILD_DIR/mongo/db/auth/authorization_manager_mock_init',
'$BUILD_DIR/mongo/db/service_context_noop_init',
+ 'document_source_facet',
+ 'document_value_test_util',
],
)
@@ -260,6 +284,7 @@ env.CppUnitTest(
source='expression_test.cpp',
LIBDEPS=[
'accumulator',
+ 'document_value_test_util',
'expression',
],
)
@@ -269,6 +294,7 @@ env.CppUnitTest(
source='accumulator_test.cpp',
LIBDEPS=[
'accumulator',
+ 'document_value_test_util',
],
)
@@ -276,12 +302,13 @@ env.CppUnitTest(
target='pipeline_test',
source='pipeline_test.cpp',
LIBDEPS=[
- 'pipeline',
'$BUILD_DIR/mongo/db/auth/authorization_manager_mock_init',
'$BUILD_DIR/mongo/db/query/collation/collator_interface_mock',
'$BUILD_DIR/mongo/db/query/query_test_service_context',
'$BUILD_DIR/mongo/db/service_context',
'$BUILD_DIR/mongo/db/service_context_noop_init',
+ 'document_value_test_util',
+ 'pipeline',
],
)
@@ -314,6 +341,7 @@ env.CppUnitTest(
target='parsed_exclusion_projection_test',
source='parsed_exclusion_projection_test.cpp',
LIBDEPS=[
+ 'document_value_test_util',
'parsed_aggregation_projection',
],
)
@@ -330,6 +358,7 @@ env.CppUnitTest(
target='parsed_inclusion_projection_test',
source='parsed_inclusion_projection_test.cpp',
LIBDEPS=[
+ 'document_value_test_util',
'parsed_aggregation_projection',
],
)
diff --git a/src/mongo/db/pipeline/accumulator.h b/src/mongo/db/pipeline/accumulator.h
index eb09459f942..81b84bf1982 100644
--- a/src/mongo/db/pipeline/accumulator.h
+++ b/src/mongo/db/pipeline/accumulator.h
@@ -31,12 +31,15 @@
#include "mongo/platform/basic.h"
#include <boost/intrusive_ptr.hpp>
+#include <boost/optional.hpp>
#include <unordered_set>
#include <vector>
#include "mongo/base/init.h"
#include "mongo/bson/bsontypes.h"
+#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/value.h"
+#include "mongo/db/pipeline/value_comparator.h"
#include "mongo/stdx/functional.h"
#include "mongo/util/summation.h"
@@ -108,12 +111,35 @@ public:
return false;
}
+ /**
+ * Injects the ExpressionContext so that it may be used during evaluation of the Accumulator.
+ * Construction of accumulators is done at parse time, but the ExpressionContext isn't finalized
+ * until later, at which point it is injected using this method.
+ */
+ void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) {
+ _expCtx = expCtx;
+ doInjectExpressionContext();
+ }
+
protected:
/// Update subclass's internal state based on input
virtual void processInternal(const Value& input, bool merging) = 0;
+ /**
+ * Accumulators which need to update their internal state when attaching to a new
+ * ExpressionContext should override this method.
+ */
+ virtual void doInjectExpressionContext() {}
+
+ const boost::intrusive_ptr<ExpressionContext>& getExpressionContext() const {
+ return _expCtx;
+ }
+
/// subclasses are expected to update this as necessary
int _memUsageBytes = 0;
+
+private:
+ boost::intrusive_ptr<ExpressionContext> _expCtx;
};
@@ -136,9 +162,13 @@ public:
return true;
}
+ void doInjectExpressionContext() final;
+
private:
- typedef std::unordered_set<Value, Value::Hash> SetType;
- SetType set;
+ // We use boost::optional to defer initialization until the ExpressionContext containing the
+ // correct comparator is injected, since this set must use the comparator's definition of
+ // equality.
+ boost::optional<ValueUnorderedSet> _set;
};
diff --git a/src/mongo/db/pipeline/accumulator_add_to_set.cpp b/src/mongo/db/pipeline/accumulator_add_to_set.cpp
index 313624bfdf0..7c780fcc7ee 100644
--- a/src/mongo/db/pipeline/accumulator_add_to_set.cpp
+++ b/src/mongo/db/pipeline/accumulator_add_to_set.cpp
@@ -46,7 +46,7 @@ const char* AccumulatorAddToSet::getOpName() const {
void AccumulatorAddToSet::processInternal(const Value& input, bool merging) {
if (!merging) {
if (!input.missing()) {
- bool inserted = set.insert(input).second;
+ bool inserted = _set->insert(input).second;
if (inserted) {
_memUsageBytes += input.getApproximateSize();
}
@@ -60,7 +60,7 @@ void AccumulatorAddToSet::processInternal(const Value& input, bool merging) {
const vector<Value>& array = input.getArray();
for (size_t i = 0; i < array.size(); i++) {
- bool inserted = set.insert(array[i]).second;
+ bool inserted = _set->insert(array[i]).second;
if (inserted) {
_memUsageBytes += array[i].getApproximateSize();
}
@@ -69,7 +69,7 @@ void AccumulatorAddToSet::processInternal(const Value& input, bool merging) {
}
Value AccumulatorAddToSet::getValue(bool toBeMerged) const {
- return Value(vector<Value>(set.begin(), set.end()));
+ return Value(vector<Value>(_set->begin(), _set->end()));
}
AccumulatorAddToSet::AccumulatorAddToSet() {
@@ -77,11 +77,16 @@ AccumulatorAddToSet::AccumulatorAddToSet() {
}
void AccumulatorAddToSet::reset() {
- SetType().swap(set);
+ _set = getExpressionContext()->getValueComparator().makeUnorderedValueSet();
_memUsageBytes = sizeof(*this);
}
intrusive_ptr<Accumulator> AccumulatorAddToSet::create() {
return new AccumulatorAddToSet();
}
+
+void AccumulatorAddToSet::doInjectExpressionContext() {
+ _set = getExpressionContext()->getValueComparator().makeUnorderedValueSet();
}
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp
index e354d71d8f9..8fad2b0095d 100644
--- a/src/mongo/db/pipeline/accumulator_test.cpp
+++ b/src/mongo/db/pipeline/accumulator_test.cpp
@@ -30,6 +30,7 @@
#include "mongo/db/pipeline/accumulator.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/dbtests/dbtests.h"
@@ -57,7 +58,7 @@ static void assertExpectedResults(
accum->process(val, false);
}
Value result = accum->getValue(false);
- ASSERT_EQUALS(op.second, result);
+ ASSERT_VALUE_EQ(op.second, result);
ASSERT_EQUALS(op.second.getType(), result.getType());
}
@@ -70,7 +71,7 @@ static void assertExpectedResults(
}
accum->process(shard->getValue(true), true);
Value result = accum->getValue(false);
- ASSERT_EQUALS(op.second, result);
+ ASSERT_VALUE_EQ(op.second, result);
ASSERT_EQUALS(op.second.getType(), result.getType());
}
@@ -83,7 +84,7 @@ static void assertExpectedResults(
accum->process(shard->getValue(true), true);
}
Value result = accum->getValue(false);
- ASSERT_EQUALS(op.second, result);
+ ASSERT_VALUE_EQ(op.second, result);
ASSERT_EQUALS(op.second.getType(), result.getType());
}
} catch (...) {
diff --git a/src/mongo/db/pipeline/aggregation_request_test.cpp b/src/mongo/db/pipeline/aggregation_request_test.cpp
index e7dc1dd0c4e..b9fe685961f 100644
--- a/src/mongo/db/pipeline/aggregation_request_test.cpp
+++ b/src/mongo/db/pipeline/aggregation_request_test.cpp
@@ -36,6 +36,7 @@
#include "mongo/db/catalog/document_validation.h"
#include "mongo/db/namespace_string.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/value.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/assert_util.h"
@@ -75,7 +76,7 @@ TEST(AggregationRequestTest, ShouldOnlySerializeRequiredFieldsIfNoOptionalFields
auto expectedSerialization =
Document{{AggregationRequest::kCommandName, nss.coll()},
{AggregationRequest::kPipelineName, Value(std::vector<Value>{})}};
- ASSERT_EQ(request.serializeToCommandObj(), expectedSerialization);
+ ASSERT_DOCUMENT_EQ(request.serializeToCommandObj(), expectedSerialization);
}
TEST(AggregationRequestTest, ShouldNotSerializeOptionalValuesIfEquivalentToDefault) {
@@ -90,7 +91,7 @@ TEST(AggregationRequestTest, ShouldNotSerializeOptionalValuesIfEquivalentToDefau
auto expectedSerialization =
Document{{AggregationRequest::kCommandName, nss.coll()},
{AggregationRequest::kPipelineName, Value(std::vector<Value>{})}};
- ASSERT_EQ(request.serializeToCommandObj(), expectedSerialization);
+ ASSERT_DOCUMENT_EQ(request.serializeToCommandObj(), expectedSerialization);
}
TEST(AggregationRequestTest, ShouldSerializeOptionalValuesIfSet) {
@@ -112,7 +113,7 @@ TEST(AggregationRequestTest, ShouldSerializeOptionalValuesIfSet) {
{AggregationRequest::kFromRouterName, true},
{bypassDocumentValidationCommandOption(), true},
{AggregationRequest::kCollationName, collationObj}};
- ASSERT_EQ(request.serializeToCommandObj(), expectedSerialization);
+ ASSERT_DOCUMENT_EQ(request.serializeToCommandObj(), expectedSerialization);
}
TEST(AggregationRequestTest, ShouldSetBatchSizeToDefaultOnEmptyCursorObject) {
diff --git a/src/mongo/db/pipeline/document.cpp b/src/mongo/db/pipeline/document.cpp
index 7c1d04f1a9f..df2eb8facdf 100644
--- a/src/mongo/db/pipeline/document.cpp
+++ b/src/mongo/db/pipeline/document.cpp
@@ -381,7 +381,9 @@ void Document::hash_combine(size_t& seed) const {
}
}
-int Document::compare(const Document& rL, const Document& rR) {
+int Document::compare(const Document& rL,
+ const Document& rR,
+ const StringData::ComparatorInterface* stringComparator) {
DocumentStorageIterator lIt = rL.storage().iterator();
DocumentStorageIterator rIt = rR.storage().iterator();
@@ -410,7 +412,7 @@ int Document::compare(const Document& rL, const Document& rR) {
if (nameCmp)
return nameCmp; // field names are unequal
- const int valueCmp = Value::compare(lField.val, rField.val);
+ const int valueCmp = Value::compare(lField.val, rField.val, stringComparator);
if (valueCmp)
return valueCmp; // fields are unequal
diff --git a/src/mongo/db/pipeline/document.h b/src/mongo/db/pipeline/document.h
index 840c81315f0..0a0220964bd 100644
--- a/src/mongo/db/pipeline/document.h
+++ b/src/mongo/db/pipeline/document.h
@@ -33,6 +33,7 @@
#include <boost/functional/hash.hpp>
#include <boost/intrusive_ptr.hpp>
+#include "mongo/base/string_data.h"
#include "mongo/bson/util/builder.h"
namespace mongo {
@@ -66,6 +67,28 @@ class Position;
*/
class Document {
public:
+ /**
+ * Operator overloads for relops return a DeferredComparison which can subsequently be evaluated
+ * by a DocumentComparator.
+ */
+ struct DeferredComparison {
+ enum class Type {
+ kLT,
+ kLTE,
+ kEQ,
+ kGT,
+ kGTE,
+ kNE,
+ };
+
+ DeferredComparison(Type type, const Document& lhs, const Document& rhs)
+ : type(type), lhs(lhs), rhs(rhs) {}
+
+ Type type;
+ const Document& lhs;
+ const Document& rhs;
+ };
+
/// Empty Document (does no allocation)
Document() {}
@@ -130,20 +153,29 @@ public:
*/
size_t getApproximateSize() const;
- /** Compare two documents.
+ /**
+ * Compare two documents. Most callers should prefer using DocumentComparator instead. See
+ * document_comparator.h for details.
*
* BSON document field order is significant, so this just goes through
* the fields in order. The comparison is done in roughly the same way
* as strings are compared, but comparing one field at a time instead
* of one character at a time.
*
+ * Pass a non-null StringData::ComparatorInterface if special string comparison semantics are
+ * required. If the comparator is null, then a simple binary compare is used for strings. This
+ * comparator is only used for string *values*; field names are always compared using simple
+ * binary compare.
+ *
* Note: This does not consider metadata when comparing documents.
*
* @returns an integer less than zero, zero, or an integer greater than
* zero, depending on whether lhs < rhs, lhs == rhs, or lhs > rhs
* Warning: may return values other than -1, 0, or 1
*/
- static int compare(const Document& lhs, const Document& rhs);
+ static int compare(const Document& lhs,
+ const Document& rhs,
+ const StringData::ComparatorInterface* stringComparator = nullptr);
std::string toString() const;
@@ -246,25 +278,38 @@ private:
boost::intrusive_ptr<const DocumentStorage> _storage;
};
-inline bool operator==(const Document& l, const Document& r) {
- return Document::compare(l, r) == 0;
+//
+// Comparison API.
+//
+// Document instances can be compared either using Document::compare() or via operator overloads.
+// Most callers should prefer operator overloads. Note that the operator overloads return a
+// DeferredComparison, which must be subsequently evaluated by a DocumentComparator. See
+// document_comparator.h for details.
+//
+
+inline Document::DeferredComparison operator==(const Document& lhs, const Document& rhs) {
+ return Document::DeferredComparison(Document::DeferredComparison::Type::kEQ, lhs, rhs);
}
-inline bool operator!=(const Document& l, const Document& r) {
- return Document::compare(l, r) != 0;
-}
-inline bool operator<(const Document& l, const Document& r) {
- return Document::compare(l, r) < 0;
+
+inline Document::DeferredComparison operator!=(const Document& lhs, const Document& rhs) {
+ return Document::DeferredComparison(Document::DeferredComparison::Type::kNE, lhs, rhs);
}
-inline bool operator<=(const Document& l, const Document& r) {
- return Document::compare(l, r) <= 0;
+
+inline Document::DeferredComparison operator<(const Document& lhs, const Document& rhs) {
+ return Document::DeferredComparison(Document::DeferredComparison::Type::kLT, lhs, rhs);
}
-inline bool operator>(const Document& l, const Document& r) {
- return Document::compare(l, r) > 0;
+
+inline Document::DeferredComparison operator<=(const Document& lhs, const Document& rhs) {
+ return Document::DeferredComparison(Document::DeferredComparison::Type::kLTE, lhs, rhs);
}
-inline bool operator>=(const Document& l, const Document& r) {
- return Document::compare(l, r) >= 0;
+
+inline Document::DeferredComparison operator>(const Document& lhs, const Document& rhs) {
+ return Document::DeferredComparison(Document::DeferredComparison::Type::kGT, lhs, rhs);
}
+inline Document::DeferredComparison operator>=(const Document& lhs, const Document& rhs) {
+ return Document::DeferredComparison(Document::DeferredComparison::Type::kGTE, lhs, rhs);
+}
/** This class is returned by MutableDocument to allow you to modify its values.
* You are not allowed to hold variables of this type (enforced by the type system).
diff --git a/src/mongo/db/pipeline/document_comparator.cpp b/src/mongo/db/pipeline/document_comparator.cpp
new file mode 100644
index 00000000000..a11eb092122
--- /dev/null
+++ b/src/mongo/db/pipeline/document_comparator.cpp
@@ -0,0 +1,57 @@
+/**
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/pipeline/document_comparator.h"
+
+#include "mongo/util/assert_util.h"
+
+namespace mongo {
+
+bool DocumentComparator::evaluate(Document::DeferredComparison deferredComparison) const {
+ int cmp = Document::compare(deferredComparison.lhs, deferredComparison.rhs, _stringComparator);
+ switch (deferredComparison.type) {
+ case Document::DeferredComparison::Type::kLT:
+ return cmp < 0;
+ case Document::DeferredComparison::Type::kLTE:
+ return cmp <= 0;
+ case Document::DeferredComparison::Type::kEQ:
+ return cmp == 0;
+ case Document::DeferredComparison::Type::kGTE:
+ return cmp >= 0;
+ case Document::DeferredComparison::Type::kGT:
+ return cmp > 0;
+ case Document::DeferredComparison::Type::kNE:
+ return cmp != 0;
+ }
+
+ MONGO_UNREACHABLE;
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/document_comparator.h b/src/mongo/db/pipeline/document_comparator.h
new file mode 100644
index 00000000000..91c1e08da58
--- /dev/null
+++ b/src/mongo/db/pipeline/document_comparator.h
@@ -0,0 +1,59 @@
+/**
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/base/string_data.h"
+#include "mongo/db/pipeline/document.h"
+
+namespace mongo {
+
+class DocumentComparator {
+public:
+ /**
+ * Constructs a document comparator with simple comparison semantics.
+ */
+ DocumentComparator() = default;
+
+ /**
+ * Constructs a document comparator with special string comparison semantics.
+ */
+ DocumentComparator(const StringData::ComparatorInterface* stringComparator)
+ : _stringComparator(stringComparator) {}
+
+ /**
+ * Evaluates a deferred comparison object that was generated by invoking one of the comparison
+ * operators on the Document class.
+ */
+ bool evaluate(Document::DeferredComparison deferredComparison) const;
+
+private:
+ const StringData::ComparatorInterface* _stringComparator = nullptr;
+};
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/document_comparator_test.cpp b/src/mongo/db/pipeline/document_comparator_test.cpp
new file mode 100644
index 00000000000..40229e982c1
--- /dev/null
+++ b/src/mongo/db/pipeline/document_comparator_test.cpp
@@ -0,0 +1,169 @@
+/**
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/pipeline/document_comparator.h"
+
+#include "mongo/db/query/collation/collator_interface_mock.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+namespace {
+
+TEST(DocumentComparatorTest, EqualToEvaluatesCorrectly) {
+ const Document doc1{{"foo", "bar"}};
+ const Document doc2{{"foo", "bar"}};
+ const Document doc3{{"foo", "baz"}};
+ ASSERT_TRUE(DocumentComparator().evaluate(doc1 == doc2));
+ ASSERT_FALSE(DocumentComparator().evaluate(doc1 == doc3));
+}
+
+TEST(DocumentComparatorTest, EqualToEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual);
+ const Document doc1{{"foo", "abc"}};
+ const Document doc2{{"foo", "def"}};
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 == doc2));
+}
+
+TEST(DocumentComparatorTest, NotEqualEvaluatesCorrectly) {
+ const Document doc1{{"foo", "bar"}};
+ const Document doc2{{"foo", "bar"}};
+ const Document doc3{{"foo", "baz"}};
+ ASSERT_FALSE(DocumentComparator().evaluate(doc1 != doc2));
+ ASSERT_TRUE(DocumentComparator().evaluate(doc1 != doc3));
+}
+
+TEST(DocumentComparatorTest, NotEqualEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual);
+ const Document doc1{{"foo", "abc"}};
+ const Document doc2{{"foo", "def"}};
+ ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc1 != doc2));
+}
+
+TEST(DocumentComparatorTest, LessThanEvaluatesCorrectly) {
+ const Document doc1{{"foo", "a"}};
+ const Document doc2{{"foo", "b"}};
+ ASSERT_TRUE(DocumentComparator().evaluate(doc1 < doc2));
+ ASSERT_FALSE(DocumentComparator().evaluate(doc2 < doc1));
+}
+
+TEST(DocumentComparatorTest, LessThanEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
+ const Document doc1{{"foo", "za"}};
+ const Document doc2{{"foo", "yb"}};
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 < doc2));
+ ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc2 < doc1));
+}
+
+TEST(DocumentComparatorTest, LessThanOrEqualEvaluatesCorrectly) {
+ const Document doc1{{"foo", "a"}};
+ const Document doc2{{"foo", "a"}};
+ const Document doc3{{"foo", "b"}};
+ ASSERT_TRUE(DocumentComparator().evaluate(doc1 <= doc2));
+ ASSERT_TRUE(DocumentComparator().evaluate(doc2 <= doc1));
+ ASSERT_TRUE(DocumentComparator().evaluate(doc1 <= doc3));
+ ASSERT_FALSE(DocumentComparator().evaluate(doc3 <= doc1));
+}
+
+TEST(DocumentComparatorTest, LessThanOrEqualEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
+ const Document doc1{{"foo", "za"}};
+ const Document doc2{{"foo", "za"}};
+ const Document doc3{{"foo", "yb"}};
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 <= doc2));
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc2 <= doc1));
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 <= doc3));
+ ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc3 <= doc1));
+}
+
+TEST(DocumentComparatorTest, GreaterThanEvaluatesCorrectly) {
+ const Document doc1{{"foo", "b"}};
+ const Document doc2{{"foo", "a"}};
+ ASSERT_TRUE(DocumentComparator().evaluate(doc1 > doc2));
+ ASSERT_FALSE(DocumentComparator().evaluate(doc2 > doc1));
+}
+
+TEST(DocumentComparatorTest, GreaterThanEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
+ const Document doc1{{"foo", "yb"}};
+ const Document doc2{{"foo", "za"}};
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 > doc2));
+ ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc2 > doc1));
+}
+
+TEST(DocumentComparatorTest, GreaterThanOrEqualEvaluatesCorrectly) {
+ const Document doc1{{"foo", "b"}};
+ const Document doc2{{"foo", "b"}};
+ const Document doc3{{"foo", "a"}};
+ ASSERT_TRUE(DocumentComparator().evaluate(doc1 >= doc2));
+ ASSERT_TRUE(DocumentComparator().evaluate(doc2 >= doc1));
+ ASSERT_TRUE(DocumentComparator().evaluate(doc1 >= doc3));
+ ASSERT_FALSE(DocumentComparator().evaluate(doc3 >= doc1));
+}
+
+TEST(DocumentComparatorTest, GreaterThanOrEqualEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
+ const Document doc1{{"foo", "yb"}};
+ const Document doc2{{"foo", "yb"}};
+ const Document doc3{{"foo", "za"}};
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 >= doc2));
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc2 >= doc1));
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 >= doc3));
+ ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc3 >= doc1));
+}
+
+TEST(DocumentComparatorTest, EqualToEvaluatesCorrectlyWithNumbers) {
+ const Document doc1{{"foo", 88}};
+ const Document doc2{{"foo", 88}};
+ const Document doc3{{"foo", 99}};
+ ASSERT_TRUE(DocumentComparator().evaluate(doc1 == doc2));
+ ASSERT_FALSE(DocumentComparator().evaluate(doc1 == doc3));
+}
+
+TEST(DocumentComparatorTest, NestedObjectEqualityRespectsCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual);
+ const Document doc1{{"foo", Document{{"foo", "abc"}}}};
+ const Document doc2{{"foo", Document{{"foo", "def"}}}};
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 == doc2));
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc2 == doc1));
+}
+
+TEST(DocumentComparatorTest, NestedArrayEqualityRespectsCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual);
+ const Document doc1{{"foo", std::vector<Value>{Value("a"), Value("b")}}};
+ const Document doc2{{"foo", std::vector<Value>{Value("c"), Value("d")}}};
+ const Document doc3{{"foo", std::vector<Value>{Value("c"), Value("d"), Value("e")}}};
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc1 == doc2));
+ ASSERT_TRUE(DocumentComparator(&collator).evaluate(doc2 == doc1));
+ ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc1 == doc3));
+ ASSERT_FALSE(DocumentComparator(&collator).evaluate(doc3 == doc1));
+}
+
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/document_source.h b/src/mongo/db/pipeline/document_source.h
index 0ffded07edb..f6c8aa91545 100644
--- a/src/mongo/db/pipeline/document_source.h
+++ b/src/mongo/db/pipeline/document_source.h
@@ -30,6 +30,7 @@
#include "mongo/platform/basic.h"
+#include <boost/optional.hpp>
#include <deque>
#include <list>
#include <string>
@@ -52,6 +53,7 @@
#include "mongo/db/pipeline/parsed_aggregation_projection.h"
#include "mongo/db/pipeline/pipeline.h"
#include "mongo/db/pipeline/value.h"
+#include "mongo/db/pipeline/value_comparator.h"
#include "mongo/db/query/plan_summary_stats.h"
#include "mongo/db/sorter/sorter.h"
#include "mongo/stdx/functional.h"
@@ -230,6 +232,18 @@ public:
virtual void reattachToOperationContext(OperationContext* opCtx) {}
/**
+ * Injects a new ExpressionContext into this DocumentSource and propagates the ExpressionContext
+ * to all child expressions, accumulators, etc.
+ *
+ * Stages which require work to propagate the ExpressionContext to their private execution
+ * machinery should override doInjectExpressionContext().
+ */
+ void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) {
+ pExpCtx = expCtx;
+ doInjectExpressionContext();
+ }
+
+ /**
* Create a DocumentSource pipeline stage from 'stageObj'.
*/
static std::vector<boost::intrusive_ptr<DocumentSource>> parse(
@@ -264,6 +278,15 @@ protected:
*/
explicit DocumentSource(const boost::intrusive_ptr<ExpressionContext>& pExpCtx);
+ /**
+ * DocumentSources which need to update their internal state when attaching to a new
+ * ExpressionContext should override this method.
+ *
+ * Any stage subclassing from DocumentSource should override this method if it contains
+ * expressions or accumulators which need to attach to the newly injected ExpressionContext.
+ */
+ virtual void doInjectExpressionContext() {}
+
/*
Most DocumentSources have an underlying source they get their data
from. This is a convenience for them.
@@ -491,6 +514,9 @@ public:
const PlanSummaryStats& getPlanSummaryStats() const;
+protected:
+ void doInjectExpressionContext() final;
+
private:
DocumentSourceCursor(const std::string& ns,
const std::shared_ptr<PlanExecutor>& exec,
@@ -524,7 +550,7 @@ private:
class DocumentSourceGroup final : public DocumentSource, public SplittableDocumentSource {
public:
using Accumulators = std::vector<boost::intrusive_ptr<Accumulator>>;
- using GroupsMap = std::unordered_map<Value, Accumulators, Value::Hash>;
+ using GroupsMap = ValueUnorderedMap<Accumulators>;
// Virtuals from DocumentSource.
boost::intrusive_ptr<DocumentSource> optimize() final;
@@ -573,6 +599,9 @@ public:
boost::intrusive_ptr<DocumentSource> getShardSource() final;
boost::intrusive_ptr<DocumentSource> getMergeSource() final;
+protected:
+ void doInjectExpressionContext() final;
+
private:
explicit DocumentSourceGroup(const boost::intrusive_ptr<ExpressionContext>& pExpCtx);
@@ -647,7 +676,10 @@ private:
Value _currentId;
Accumulators _currentAccumulators;
- GroupsMap groups;
+ // We use boost::optional to defer initialization until the ExpressionContext containing the
+ // correct comparator is injected, since the groups must be built using the comparator's
+ // definition of equality.
+ boost::optional<GroupsMap> _groups;
bool _spilled;
@@ -789,6 +821,8 @@ public:
const std::string& path,
boost::intrusive_ptr<ExpressionContext> expCtx);
+ void doInjectExpressionContext();
+
private:
DocumentSourceMatch(const BSONObj& query,
const boost::intrusive_ptr<ExpressionContext>& pExpCtx);
@@ -914,11 +948,17 @@ public:
return this;
}
+ void doInjectExpressionContext() override {
+ isExpCtxInjected = true;
+ }
+
// Return documents from front of queue.
std::deque<Document> queue;
+
bool isDisposed = false;
bool isDetachedFromOpCtx = false;
bool isOptimized = false;
+ bool isExpCtxInjected = false;
BSONObjSet sorts;
};
@@ -1001,6 +1041,8 @@ public:
*/
boost::intrusive_ptr<DocumentSource> optimize() final;
+ void doInjectExpressionContext() final;
+
/**
* Parse the projection from the user-supplied BSON.
*/
@@ -1028,6 +1070,8 @@ public:
Pipeline::SourceContainer::iterator optimizeAt(Pipeline::SourceContainer::iterator itr,
Pipeline::SourceContainer* container) final;
+ void doInjectExpressionContext() final;
+
static boost::intrusive_ptr<DocumentSource> createFromBson(
BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx);
@@ -1063,6 +1107,8 @@ public:
return _size;
}
+ void doInjectExpressionContext() final;
+
static boost::intrusive_ptr<DocumentSource> createFromBson(
BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& expCtx);
@@ -1086,6 +1132,8 @@ public:
Value serialize(bool explain = false) const final;
GetDepsReturn getDependencies(DepsTracker* deps) const final;
+ void doInjectExpressionContext() final;
+
static boost::intrusive_ptr<DocumentSourceSampleFromRandomCursor> create(
const boost::intrusive_ptr<ExpressionContext>& expCtx,
long long size,
@@ -1111,8 +1159,9 @@ private:
std::string _idField;
// Keeps track of the documents that have been returned, since a random cursor is allowed to
- // return duplicates.
- ValueSet _seenDocs;
+ // return duplicates. We use boost::optional to defer initialization until the ExpressionContext
+ // containing the correct comparator is injected.
+ boost::optional<ValueUnorderedSet> _seenDocs;
// The approximate number of documents in the collection (includes orphans).
const long long _nDocsInColl;
@@ -1188,6 +1237,7 @@ private:
long long count;
};
+// TODO SERVER-23349: Make aggregation sort respect the collation.
class DocumentSourceSort final : public DocumentSource, public SplittableDocumentSource {
public:
// virtuals from DocumentSource
@@ -1473,6 +1523,7 @@ private:
std::unique_ptr<Unwinder> _unwinder;
};
+// TODO SERVER-23349: Make geoNear agg stage respect the collation.
class DocumentSourceGeoNear : public DocumentSourceNeedsMongod, public SplittableDocumentSource {
public:
static const long long kDefaultLimit;
@@ -1543,6 +1594,8 @@ private:
/**
* Queries separate collection for equality matches with documents in the pipeline collection.
* Adds matching documents to a new array field in the input document.
+ *
+ * TODO SERVER-23349: Make $lookup respect the collation.
*/
class DocumentSourceLookUp final : public DocumentSourceNeedsMongod,
public SplittableDocumentSource {
@@ -1620,6 +1673,7 @@ private:
boost::optional<Document> _input;
};
+// TODO SERVER-23349: Make $graphLookup respect the collation.
class DocumentSourceGraphLookUp final : public DocumentSourceNeedsMongod {
public:
boost::optional<Document> getNext() final;
@@ -1647,6 +1701,8 @@ public:
collections->push_back(_from);
}
+ void doInjectExpressionContext() final;
+
static boost::intrusive_ptr<DocumentSource> createFromBson(
BSONElement elem, const boost::intrusive_ptr<ExpressionContext>& pExpCtx);
@@ -1698,7 +1754,7 @@ private:
* Updates '_cache' with 'result' appropriately, given that 'result' was retrieved when querying
* for 'queried'.
*/
- void addToCache(const BSONObj& result, const unordered_set<Value, Value::Hash>& queried);
+ void addToCache(const BSONObj& result, const ValueUnorderedSet& queried);
/**
* Assert that '_visited' and '_frontier' have not exceeded the maximum meory usage, and then
@@ -1731,11 +1787,15 @@ private:
size_t _frontierUsageBytes = 0;
// Only used during the breadth-first search, tracks the set of values on the current frontier.
- std::unordered_set<Value, Value::Hash> _frontier;
+ // We use boost::optional to defer initialization until the ExpressionContext containing the
+ // correct comparator is injected.
+ boost::optional<ValueUnorderedSet> _frontier;
// Tracks nodes that have been discovered for a given input. Keys are the '_id' value of the
- // document from the foreign collection, value is the document itself.
- std::unordered_map<Value, BSONObj, Value::Hash> _visited;
+ // document from the foreign collection, value is the document itself. We use boost::optional
+ // to defer initialization until the ExpressionContext containing the correct comparator is
+ // injected.
+ boost::optional<ValueUnorderedMap<BSONObj>> _visited;
// Caches query results to avoid repeating any work. This structure is maintained across calls
// to getNext().
diff --git a/src/mongo/db/pipeline/document_source_bucket.cpp b/src/mongo/db/pipeline/document_source_bucket.cpp
index 6018402113b..d895b532156 100644
--- a/src/mongo/db/pipeline/document_source_bucket.cpp
+++ b/src/mongo/db/pipeline/document_source_bucket.cpp
@@ -121,6 +121,8 @@ vector<intrusive_ptr<DocumentSource>> DocumentSourceBucket::createFromBson(
<< typeName(upper.getType())
<< ".",
lowerCanonicalType == upperCanonicalType);
+ // TODO SERVER-25038: This check must be deferred so that it respects the final
+ // collator, which is not necessarily the same as the collator at parse time.
uassert(40194,
str::stream()
<< "The 'boundaries' option to $bucket must be sorted, but elements "
@@ -132,7 +134,7 @@ vector<intrusive_ptr<DocumentSource>> DocumentSourceBucket::createFromBson(
<< " is not less than "
<< upper.toString()
<< ").",
- lower < upper);
+ pExpCtx->getValueComparator().evaluate(lower < upper));
}
} else if ("default" == argName) {
// If there is a default, make sure that it parses to a constant expression then add
@@ -173,11 +175,15 @@ vector<intrusive_ptr<DocumentSource>> DocumentSourceBucket::createFromBson(
Value upperValue = boundaryValues.back();
if (canonicalizeBSONType(defaultValue.getType()) ==
canonicalizeBSONType(lowerValue.getType())) {
- // If the default has the same canonical type as the bucket's boundaries, then make sure
- // the default is less than the lowest boundary or greater than or equal to the highest
+ // If the default has the same canonical type as the bucket's boundaries, then make sure the
+ // default is less than the lowest boundary or greater than or equal to the highest
// boundary.
- const bool hasValidDefault =
- defaultValue < lowerValue || Value::compare(defaultValue, upperValue) >= 0;
+ //
+ // TODO SERVER-25038: This check must be deferred so that it respects the final collator,
+ // which is not necessarily the same as the collator at parse time.
+ const auto& valueCmp = pExpCtx->getValueComparator();
+ const bool hasValidDefault = valueCmp.evaluate(defaultValue < lowerValue) ||
+ valueCmp.evaluate(defaultValue >= upperValue);
uassert(40199,
"The $bucket 'default' field must be less than the lowest boundary or greater than "
"or equal to the highest boundary.",
diff --git a/src/mongo/db/pipeline/document_source_cursor.cpp b/src/mongo/db/pipeline/document_source_cursor.cpp
index c9a0d0673e3..aef76129244 100644
--- a/src/mongo/db/pipeline/document_source_cursor.cpp
+++ b/src/mongo/db/pipeline/document_source_cursor.cpp
@@ -221,6 +221,12 @@ Value DocumentSourceCursor::serialize(bool explain) const {
return Value(DOC(getSourceName() << out.freezeToValue()));
}
+void DocumentSourceCursor::doInjectExpressionContext() {
+ if (_limit) {
+ _limit->injectExpressionContext(pExpCtx);
+ }
+}
+
DocumentSourceCursor::DocumentSourceCursor(const string& ns,
const std::shared_ptr<PlanExecutor>& exec,
const intrusive_ptr<ExpressionContext>& pCtx)
@@ -239,7 +245,9 @@ intrusive_ptr<DocumentSourceCursor> DocumentSourceCursor::create(
const string& ns,
const std::shared_ptr<PlanExecutor>& exec,
const intrusive_ptr<ExpressionContext>& pExpCtx) {
- return new DocumentSourceCursor(ns, exec, pExpCtx);
+ intrusive_ptr<DocumentSourceCursor> source(new DocumentSourceCursor(ns, exec, pExpCtx));
+ source->injectExpressionContext(pExpCtx);
+ return source;
}
void DocumentSourceCursor::setProjection(const BSONObj& projection,
diff --git a/src/mongo/db/pipeline/document_source_facet.cpp b/src/mongo/db/pipeline/document_source_facet.cpp
index 45c24a1480b..0c3e2981d79 100644
--- a/src/mongo/db/pipeline/document_source_facet.cpp
+++ b/src/mongo/db/pipeline/document_source_facet.cpp
@@ -122,6 +122,12 @@ intrusive_ptr<DocumentSource> DocumentSourceFacet::optimize() {
return this;
}
+void DocumentSourceFacet::doInjectExpressionContext() {
+ for (auto&& facet : _facetPipelines) {
+ facet.second->injectExpressionContext(pExpCtx);
+ }
+}
+
void DocumentSourceFacet::doInjectMongodInterface(std::shared_ptr<MongodInterface> mongod) {
for (auto&& facet : _facetPipelines) {
for (auto&& stage : facet.second->getSources()) {
diff --git a/src/mongo/db/pipeline/document_source_facet.h b/src/mongo/db/pipeline/document_source_facet.h
index 583c09ea4b5..a6ac631c327 100644
--- a/src/mongo/db/pipeline/document_source_facet.h
+++ b/src/mongo/db/pipeline/document_source_facet.h
@@ -77,6 +77,11 @@ public:
boost::intrusive_ptr<DocumentSource> optimize() final;
/**
+ * Injects the expression context into inner pipelines.
+ */
+ void doInjectExpressionContext() final;
+
+ /**
* Takes a union of all sub-pipelines, and adds them to 'deps'.
*/
GetDepsReturn getDependencies(DepsTracker* deps) const final;
diff --git a/src/mongo/db/pipeline/document_source_facet_test.cpp b/src/mongo/db/pipeline/document_source_facet_test.cpp
index 7deb8127949..519b8d0a222 100644
--- a/src/mongo/db/pipeline/document_source_facet_test.cpp
+++ b/src/mongo/db/pipeline/document_source_facet_test.cpp
@@ -38,6 +38,7 @@
#include "mongo/bson/json.h"
#include "mongo/db/pipeline/aggregation_context_fixture.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/util/assert_util.h"
namespace mongo {
@@ -194,7 +195,7 @@ TEST_F(DocumentSourceFacetTest, SingleFacetShouldReceiveAllDocuments) {
auto output = facetStage->getNext();
ASSERT_TRUE(output);
- ASSERT_EQ(*output, Document(fromjson("{results: [{_id: 0}, {_id: 1}]}")));
+ ASSERT_DOCUMENT_EQ(*output, Document(fromjson("{results: [{_id: 0}, {_id: 1}]}")));
// Should be exhausted now.
ASSERT_FALSE(facetStage->getNext());
@@ -224,8 +225,8 @@ TEST_F(DocumentSourceFacetTest, MultipleFacetsShouldSeeTheSameDocuments) {
std::vector<Value> expectedOutputs(inputs.begin(), inputs.end());
ASSERT_TRUE(output);
ASSERT_EQ((*output).size(), 2UL);
- ASSERT_EQ((*output)["first"], Value(expectedOutputs));
- ASSERT_EQ((*output)["second"], Value(expectedOutputs));
+ ASSERT_VALUE_EQ((*output)["first"], Value(expectedOutputs));
+ ASSERT_VALUE_EQ((*output)["second"], Value(expectedOutputs));
// Should be exhausted now.
ASSERT_FALSE(facetStage->getNext());
@@ -248,7 +249,7 @@ TEST_F(DocumentSourceFacetTest, ShouldBeAbleToEvaluateMultipleStagesWithinOneSub
auto output = facetStage->getNext();
ASSERT_TRUE(output);
- ASSERT_EQ(*output, Document(fromjson("{subPipe: [{_id: 0}, {_id: 1}]}")));
+ ASSERT_DOCUMENT_EQ(*output, Document(fromjson("{subPipe: [{_id: 0}, {_id: 1}]}")));
}
//
@@ -287,10 +288,10 @@ TEST_F(DocumentSourceFacetTest, ShouldBeAbleToReParseSerializedStage) {
// Should have two fields: "skippedOne" and "skippedTwo".
auto serializedStage = serialization[0].getDocument()["$facet"].getDocument();
ASSERT_EQ(serializedStage.size(), 2UL);
- ASSERT_EQ(serializedStage["skippedOne"],
- Value(std::vector<Value>{Value(Document{{"$skip", 1}})}));
- ASSERT_EQ(serializedStage["skippedTwo"],
- Value(std::vector<Value>{Value(Document{{"$skip", 2}})}));
+ ASSERT_VALUE_EQ(serializedStage["skippedOne"],
+ Value(std::vector<Value>{Value(Document{{"$skip", 1}})}));
+ ASSERT_VALUE_EQ(serializedStage["skippedTwo"],
+ Value(std::vector<Value>{Value(Document{{"$skip", 2}})}));
auto serializedBson = serialization[0].getDocument().toBson();
auto roundTripped = DocumentSourceFacet::createFromBson(serializedBson.firstElement(), ctx);
@@ -300,7 +301,7 @@ TEST_F(DocumentSourceFacetTest, ShouldBeAbleToReParseSerializedStage) {
roundTripped->serializeToArray(newSerialization);
ASSERT_EQ(newSerialization.size(), 1UL);
- ASSERT_EQ(newSerialization[0], serialization[0]);
+ ASSERT_VALUE_EQ(newSerialization[0], serialization[0]);
}
TEST_F(DocumentSourceFacetTest, ShouldOptimizeInnerPipelines) {
diff --git a/src/mongo/db/pipeline/document_source_geo_near.cpp b/src/mongo/db/pipeline/document_source_geo_near.cpp
index 5f4e6ae0552..0107b41087c 100644
--- a/src/mongo/db/pipeline/document_source_geo_near.cpp
+++ b/src/mongo/db/pipeline/document_source_geo_near.cpp
@@ -165,7 +165,9 @@ void DocumentSourceGeoNear::runCommand() {
intrusive_ptr<DocumentSourceGeoNear> DocumentSourceGeoNear::create(
const intrusive_ptr<ExpressionContext>& pCtx) {
- return new DocumentSourceGeoNear(pCtx);
+ intrusive_ptr<DocumentSourceGeoNear> source(new DocumentSourceGeoNear(pCtx));
+ source->injectExpressionContext(pCtx);
+ return source;
}
intrusive_ptr<DocumentSource> DocumentSourceGeoNear::createFromBson(
diff --git a/src/mongo/db/pipeline/document_source_graph_lookup.cpp b/src/mongo/db/pipeline/document_source_graph_lookup.cpp
index 0c5ce1d4c21..01887184512 100644
--- a/src/mongo/db/pipeline/document_source_graph_lookup.cpp
+++ b/src/mongo/db/pipeline/document_source_graph_lookup.cpp
@@ -73,11 +73,11 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNext() {
performSearch();
std::vector<Value> results;
- while (!_visited.empty()) {
+ while (!_visited->empty()) {
// Remove elements one at a time to avoid consuming more memory.
- auto it = _visited.begin();
+ auto it = _visited->begin();
results.push_back(Value(it->second));
- _visited.erase(it);
+ _visited->erase(it);
}
MutableDocument output(*_input);
@@ -85,7 +85,7 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNext() {
_visitedUsageBytes = 0;
- invariant(_visited.empty());
+ invariant(_visited->empty());
return output.freeze();
}
@@ -96,7 +96,7 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNextUnwound() {
// If the unwind is not preserving empty arrays, we might have to process multiple inputs before
// we get one that will produce an output.
while (true) {
- if (_visited.empty()) {
+ if (_visited->empty()) {
// No results are left for the current input, so we should move on to the next one and
// perform a new search.
if (!(_input = pSource->getNext())) {
@@ -110,7 +110,7 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNextUnwound() {
}
MutableDocument unwound(*_input);
- if (_visited.empty()) {
+ if (_visited->empty()) {
if ((*_unwind)->preserveNullAndEmptyArrays()) {
// Since "preserveNullAndEmptyArrays" was specified, output a document even though
// we had no result.
@@ -124,13 +124,13 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNextUnwound() {
continue;
}
} else {
- auto it = _visited.begin();
+ auto it = _visited->begin();
unwound.setNestedField(_as, Value(it->second));
if (indexPath) {
unwound.setNestedField(*indexPath, Value(_outputIndex));
++_outputIndex;
}
- _visited.erase(it);
+ _visited->erase(it);
}
return unwound.freeze();
@@ -139,8 +139,8 @@ boost::optional<Document> DocumentSourceGraphLookUp::getNextUnwound() {
void DocumentSourceGraphLookUp::dispose() {
_cache.clear();
- _frontier.clear();
- _visited.clear();
+ _frontier->clear();
+ _visited->clear();
pSource->dispose();
}
@@ -154,8 +154,8 @@ void DocumentSourceGraphLookUp::doBreadthFirstSearch() {
BSONObjSet cached;
auto query = constructQuery(&cached);
- std::unordered_set<Value, Value::Hash> queried;
- _frontier.swap(queried);
+ ValueUnorderedSet queried = pExpCtx->getValueComparator().makeUnorderedValueSet();
+ _frontier->swap(queried);
_frontierUsageBytes = 0;
// Process cached values, populating '_frontier' for the next iteration of search.
@@ -186,7 +186,7 @@ void DocumentSourceGraphLookUp::doBreadthFirstSearch() {
} while (shouldPerformAnotherQuery && depth < std::numeric_limits<long long>::max() &&
(!_maxDepth || depth <= *_maxDepth));
- _frontier.clear();
+ _frontier->clear();
_frontierUsageBytes = 0;
}
@@ -204,7 +204,7 @@ BSONObj addDepthFieldToObject(const std::string& field, long long depth, BSONObj
bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long long depth) {
Value _id = Value(result.getField("_id"));
- if (_visited.find(_id) != _visited.end()) {
+ if (_visited->find(_id) != _visited->end()) {
// We've already seen this object, don't repeat any work.
return false;
}
@@ -215,7 +215,7 @@ bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long lon
_depthField ? addDepthFieldToObject(_depthField->fullPath(), depth, result) : result;
// Add the object to our '_visited' list.
- _visited[_id] = fullObject;
+ (*_visited)[_id] = fullObject;
// Update the size of '_visited' appropriately.
_visitedUsageBytes += _id.getApproximateSize();
@@ -231,12 +231,12 @@ bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long lon
Value recurseOn = Value(elem);
if (recurseOn.isArray()) {
for (auto&& subElem : recurseOn.getArray()) {
- _frontier.insert(subElem);
+ _frontier->insert(subElem);
_frontierUsageBytes += subElem.getApproximateSize();
}
} else if (!recurseOn.missing()) {
// Don't recurse on a missing value.
- _frontier.insert(recurseOn);
+ _frontier->insert(recurseOn);
_frontierUsageBytes += recurseOn.getApproximateSize();
}
}
@@ -246,7 +246,7 @@ bool DocumentSourceGraphLookUp::addToVisitedAndFrontier(BSONObj result, long lon
}
void DocumentSourceGraphLookUp::addToCache(const BSONObj& result,
- const unordered_set<Value, Value::Hash>& queried) {
+ const ValueUnorderedSet& queried) {
BSONElementSet cacheByValues;
dps::extractAllElementsAlongPath(result, _connectToField.fullPath(), cacheByValues);
@@ -271,13 +271,13 @@ void DocumentSourceGraphLookUp::addToCache(const BSONObj& result,
boost::optional<BSONObj> DocumentSourceGraphLookUp::constructQuery(BSONObjSet* cached) {
// Add any cached values to 'cached' and remove them from '_frontier'.
- for (auto it = _frontier.begin(); it != _frontier.end();) {
+ for (auto it = _frontier->begin(); it != _frontier->end();) {
if (auto entry = _cache[*it]) {
for (auto&& obj : *entry) {
cached->insert(obj);
}
size_t valueSize = it->getApproximateSize();
- it = _frontier.erase(it);
+ it = _frontier->erase(it);
// If the cached value increased in size while in the cache, we don't want to underflow
// '_frontierUsageBytes'.
@@ -302,7 +302,7 @@ boost::optional<BSONObj> DocumentSourceGraphLookUp::constructQuery(BSONObjSet* c
BSONObjBuilder subObj(connectToObj.subobjStart(_connectToField.fullPath()));
{
BSONArrayBuilder in(subObj.subarrayStart("$in"));
- for (auto&& value : _frontier) {
+ for (auto&& value : *_frontier) {
in << value;
}
}
@@ -310,7 +310,7 @@ boost::optional<BSONObj> DocumentSourceGraphLookUp::constructQuery(BSONObjSet* c
}
}
- return _frontier.empty() ? boost::none : boost::optional<BSONObj>(query.obj());
+ return _frontier->empty() ? boost::none : boost::optional<BSONObj>(query.obj());
}
void DocumentSourceGraphLookUp::performSearch() {
@@ -324,11 +324,11 @@ void DocumentSourceGraphLookUp::performSearch() {
// If _startWith evaluates to an array, treat each value as a separate starting point.
if (startingValue.isArray()) {
for (auto value : startingValue.getArray()) {
- _frontier.insert(value);
+ _frontier->insert(value);
_frontierUsageBytes += value.getApproximateSize();
}
} else {
- _frontier.insert(startingValue);
+ _frontier->insert(startingValue);
_frontierUsageBytes += startingValue.getApproximateSize();
}
@@ -410,6 +410,11 @@ void DocumentSourceGraphLookUp::serializeToArray(std::vector<Value>& array, bool
}
}
+void DocumentSourceGraphLookUp::doInjectExpressionContext() {
+ _frontier = pExpCtx->getValueComparator().makeUnorderedValueSet();
+ _visited = pExpCtx->getValueComparator().makeUnorderedValueMap<BSONObj>();
+}
+
DocumentSourceGraphLookUp::DocumentSourceGraphLookUp(
NamespaceString from,
std::string as,
diff --git a/src/mongo/db/pipeline/document_source_group.cpp b/src/mongo/db/pipeline/document_source_group.cpp
index a356986671e..ed7ab95e575 100644
--- a/src/mongo/db/pipeline/document_source_group.cpp
+++ b/src/mongo/db/pipeline/document_source_group.cpp
@@ -75,7 +75,7 @@ boost::optional<Document> DocumentSourceGroup::getNextSpilled() {
_currentId = _firstPartOfNextGroup.first;
const size_t numAccumulators = vpAccumulatorFactory.size();
- while (_currentId == _firstPartOfNextGroup.first) {
+ while (pExpCtx->getValueComparator().evaluate(_currentId == _firstPartOfNextGroup.first)) {
// Inside of this loop, _firstPartOfNextGroup is the current data being processed.
// At loop exit, it is the first value to be processed in the next group.
switch (numAccumulators) { // mirrors switch in spill()
@@ -104,12 +104,12 @@ boost::optional<Document> DocumentSourceGroup::getNextSpilled() {
boost::optional<Document> DocumentSourceGroup::getNextStandard() {
// Not spilled, and not streaming.
- if (groups.empty())
+ if (_groups->empty())
return boost::none;
Document out = makeDocument(groupsIterator->first, groupsIterator->second, pExpCtx->inShard);
- if (++groupsIterator == groups.end())
+ if (++groupsIterator == _groups->end())
dispose();
return out;
@@ -146,7 +146,7 @@ boost::optional<Document> DocumentSourceGroup::getNextStreaming() {
// Compute the id. If it does not match _currentId, we will exit the loop, leaving
// _firstDocOfNextGroup set for the next time getNext() is called.
id = computeId(_variables.get());
- } while (_currentId == id);
+ } while (pExpCtx->getValueComparator().evaluate(_currentId == id));
Document out = makeDocument(_currentId, _currentAccumulators, pExpCtx->inShard);
_currentId = std::move(id);
@@ -156,11 +156,11 @@ boost::optional<Document> DocumentSourceGroup::getNextStreaming() {
void DocumentSourceGroup::dispose() {
// Free our resources.
- GroupsMap().swap(groups);
+ GroupsMap().swap(*_groups);
_sorterIterator.reset();
// Make us look done.
- groupsIterator = groups.end();
+ groupsIterator = _groups->end();
_firstDocOfNextGroup = boost::none;
@@ -183,6 +183,23 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::optimize() {
return this;
}
+void DocumentSourceGroup::doInjectExpressionContext() {
+ // Groups map must respect new comparator.
+ _groups = pExpCtx->getValueComparator().makeUnorderedValueMap<Accumulators>();
+
+ for (auto&& idExpr : _idExpressions) {
+ idExpr->injectExpressionContext(pExpCtx);
+ }
+
+ for (auto&& expr : vpExpression) {
+ expr->injectExpressionContext(pExpCtx);
+ }
+
+ for (auto&& accum : _currentAccumulators) {
+ accum->injectExpressionContext(pExpCtx);
+ }
+}
+
Value DocumentSourceGroup::serialize(bool explain) const {
MutableDocument insides;
@@ -237,7 +254,9 @@ DocumentSource::GetDepsReturn DocumentSourceGroup::getDependencies(DepsTracker*
intrusive_ptr<DocumentSourceGroup> DocumentSourceGroup::create(
const intrusive_ptr<ExpressionContext>& pExpCtx) {
- return new DocumentSourceGroup(pExpCtx);
+ intrusive_ptr<DocumentSourceGroup> source(new DocumentSourceGroup(pExpCtx));
+ source->injectExpressionContext(pExpCtx);
+ return source;
}
DocumentSourceGroup::DocumentSourceGroup(const intrusive_ptr<ExpressionContext>& pExpCtx)
@@ -435,6 +454,7 @@ void DocumentSourceGroup::initialize() {
_currentAccumulators.reserve(numAccumulators);
for (size_t i = 0; i < numAccumulators; i++) {
_currentAccumulators.push_back(vpAccumulatorFactory[i]());
+ _currentAccumulators.back()->injectExpressionContext(pExpCtx);
}
// We only need to load the first document.
@@ -477,9 +497,9 @@ void DocumentSourceGroup::initialize() {
Look for the _id value in the map; if it's not there, add a
new entry with a blank accumulator.
*/
- const size_t oldSize = groups.size();
- vector<intrusive_ptr<Accumulator>>& group = groups[id];
- const bool inserted = groups.size() != oldSize;
+ const size_t oldSize = _groups->size();
+ vector<intrusive_ptr<Accumulator>>& group = (*_groups)[id];
+ const bool inserted = _groups->size() != oldSize;
if (inserted) {
memoryUsageBytes += id.getApproximateSize();
@@ -488,6 +508,7 @@ void DocumentSourceGroup::initialize() {
group.reserve(numAccumulators);
for (size_t i = 0; i < numAccumulators; i++) {
group.push_back(vpAccumulatorFactory[i]());
+ group.back()->injectExpressionContext(pExpCtx);
}
} else {
for (size_t i = 0; i < numAccumulators; i++) {
@@ -524,12 +545,12 @@ void DocumentSourceGroup::initialize() {
// These blocks do any final steps necessary to prepare to output results.
if (!sortedFiles.empty()) {
_spilled = true;
- if (!groups.empty()) {
+ if (!_groups->empty()) {
sortedFiles.push_back(spill());
}
// We won't be using groups again so free its memory.
- GroupsMap().swap(groups);
+ GroupsMap().swap(*_groups);
_sorterIterator.reset(
Sorter<Value, Value>::Iterator::merge(sortedFiles, SortOptions(), SorterComparator()));
@@ -538,20 +559,21 @@ void DocumentSourceGroup::initialize() {
_currentAccumulators.reserve(numAccumulators);
for (size_t i = 0; i < numAccumulators; i++) {
_currentAccumulators.push_back(vpAccumulatorFactory[i]());
+ _currentAccumulators.back()->injectExpressionContext(pExpCtx);
}
verify(_sorterIterator->more()); // we put data in, we should get something out.
_firstPartOfNextGroup = _sorterIterator->next();
} else {
// start the group iterator
- groupsIterator = groups.begin();
+ groupsIterator = _groups->begin();
}
}
shared_ptr<Sorter<Value, Value>::Iterator> DocumentSourceGroup::spill() {
vector<const GroupsMap::value_type*> ptrs; // using pointers to speed sorting
- ptrs.reserve(groups.size());
- for (GroupsMap::const_iterator it = groups.begin(), end = groups.end(); it != end; ++it) {
+ ptrs.reserve(_groups->size());
+ for (GroupsMap::const_iterator it = _groups->begin(), end = _groups->end(); it != end; ++it) {
ptrs.push_back(&*it);
}
@@ -583,7 +605,7 @@ shared_ptr<Sorter<Value, Value>::Iterator> DocumentSourceGroup::spill() {
break;
}
- groups.clear();
+ _groups->clear();
return shared_ptr<Sorter<Value, Value>::Iterator>(writer.done());
}
@@ -761,7 +783,7 @@ void DocumentSourceGroup::parseIdExpression(BSONElement groupField,
_idExpressions.push_back(ExpressionFieldPath::parse(groupField.str(), vps));
} else {
// constant id - single group
- _idExpressions.push_back(ExpressionConstant::create(Value(groupField)));
+ _idExpressions.push_back(ExpressionConstant::create(pExpCtx, Value(groupField)));
}
}
diff --git a/src/mongo/db/pipeline/document_source_limit.cpp b/src/mongo/db/pipeline/document_source_limit.cpp
index 401afc945ff..915a28ce8b6 100644
--- a/src/mongo/db/pipeline/document_source_limit.cpp
+++ b/src/mongo/db/pipeline/document_source_limit.cpp
@@ -81,7 +81,9 @@ Value DocumentSourceLimit::serialize(bool explain) const {
intrusive_ptr<DocumentSourceLimit> DocumentSourceLimit::create(
const intrusive_ptr<ExpressionContext>& pExpCtx, long long limit) {
uassert(15958, "the limit must be positive", limit > 0);
- return new DocumentSourceLimit(pExpCtx, limit);
+ intrusive_ptr<DocumentSourceLimit> source(new DocumentSourceLimit(pExpCtx, limit));
+ source->injectExpressionContext(pExpCtx);
+ return source;
}
intrusive_ptr<DocumentSource> DocumentSourceLimit::createFromBson(
diff --git a/src/mongo/db/pipeline/document_source_match.cpp b/src/mongo/db/pipeline/document_source_match.cpp
index 6067acb3597..8035c237eae 100644
--- a/src/mongo/db/pipeline/document_source_match.cpp
+++ b/src/mongo/db/pipeline/document_source_match.cpp
@@ -357,9 +357,8 @@ bool DocumentSourceMatch::isTextQuery(const BSONObj& query) {
void DocumentSourceMatch::joinMatchWith(intrusive_ptr<DocumentSourceMatch> other) {
_predicate = BSON("$and" << BSON_ARRAY(_predicate << other->getQuery()));
- // TODO SERVER-23349: Pass the appropriate CollatorInterface* instead of nullptr.
StatusWithMatchExpression status = uassertStatusOK(
- MatchExpressionParser::parse(_predicate, ExtensionsCallbackNoop(), nullptr));
+ MatchExpressionParser::parse(_predicate, ExtensionsCallbackNoop(), pExpCtx->getCollator()));
_expression = std::move(status.getValue());
}
@@ -486,14 +485,18 @@ void DocumentSourceMatch::addDependencies(DepsTracker* deps) const {
});
}
+void DocumentSourceMatch::doInjectExpressionContext() {
+ _expression->setCollator(pExpCtx->getCollator());
+}
+
DocumentSourceMatch::DocumentSourceMatch(const BSONObj& query,
const intrusive_ptr<ExpressionContext>& pExpCtx)
: DocumentSource(pExpCtx), _predicate(query.getOwned()), _isTextQuery(isTextQuery(query)) {
- // TODO SERVER-23349: Pass the appropriate CollatorInterface* instead of nullptr.
StatusWithMatchExpression status = uassertStatusOK(
- MatchExpressionParser::parse(_predicate, ExtensionsCallbackNoop(), nullptr));
+ MatchExpressionParser::parse(_predicate, ExtensionsCallbackNoop(), pExpCtx->getCollator()));
_expression = std::move(status.getValue());
getDependencies(&_dependencies);
}
+
} // namespace mongo
diff --git a/src/mongo/db/pipeline/document_source_merge_cursors.cpp b/src/mongo/db/pipeline/document_source_merge_cursors.cpp
index 01f11cb0c9f..8e7e3babbcd 100644
--- a/src/mongo/db/pipeline/document_source_merge_cursors.cpp
+++ b/src/mongo/db/pipeline/document_source_merge_cursors.cpp
@@ -52,7 +52,10 @@ const char* DocumentSourceMergeCursors::getSourceName() const {
intrusive_ptr<DocumentSource> DocumentSourceMergeCursors::create(
std::vector<CursorDescriptor> cursorDescriptors,
const intrusive_ptr<ExpressionContext>& pExpCtx) {
- return new DocumentSourceMergeCursors(std::move(cursorDescriptors), pExpCtx);
+ intrusive_ptr<DocumentSourceMergeCursors> source(
+ new DocumentSourceMergeCursors(std::move(cursorDescriptors), pExpCtx));
+ source->injectExpressionContext(pExpCtx);
+ return source;
}
intrusive_ptr<DocumentSource> DocumentSourceMergeCursors::createFromBson(
diff --git a/src/mongo/db/pipeline/document_source_project.cpp b/src/mongo/db/pipeline/document_source_project.cpp
index f6b70b840bd..5ad0876a7b8 100644
--- a/src/mongo/db/pipeline/document_source_project.cpp
+++ b/src/mongo/db/pipeline/document_source_project.cpp
@@ -116,4 +116,9 @@ DocumentSource::GetDepsReturn DocumentSourceProject::getDependencies(DepsTracker
return SEE_NEXT;
}
}
+
+void DocumentSourceProject::doInjectExpressionContext() {
+ _parsedProject->injectExpressionContext(pExpCtx);
}
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/document_source_redact.cpp b/src/mongo/db/pipeline/document_source_redact.cpp
index 38f1e9fa8b8..3155037aa20 100644
--- a/src/mongo/db/pipeline/document_source_redact.cpp
+++ b/src/mongo/db/pipeline/document_source_redact.cpp
@@ -127,11 +127,12 @@ Value DocumentSourceRedact::redactValue(const Value& in) {
boost::optional<Document> DocumentSourceRedact::redactObject() {
const Value expressionResult = _expression->evaluate(_variables.get());
- if (expressionResult == keepVal) {
+ ValueComparator simpleValueCmp;
+ if (simpleValueCmp.evaluate(expressionResult == keepVal)) {
return _variables->getDocument(_currentId);
- } else if (expressionResult == pruneVal) {
+ } else if (simpleValueCmp.evaluate(expressionResult == pruneVal)) {
return boost::optional<Document>();
- } else if (expressionResult == descendVal) {
+ } else if (simpleValueCmp.evaluate(expressionResult == descendVal)) {
const Document in = _variables->getDocument(_currentId);
MutableDocument out;
out.copyMetaDataFrom(in);
@@ -160,6 +161,10 @@ intrusive_ptr<DocumentSource> DocumentSourceRedact::optimize() {
return this;
}
+void DocumentSourceRedact::doInjectExpressionContext() {
+ _expression->injectExpressionContext(pExpCtx);
+}
+
Value DocumentSourceRedact::serialize(bool explain) const {
return Value(DOC(getSourceName() << _expression.get()->serialize(explain)));
}
diff --git a/src/mongo/db/pipeline/document_source_sample.cpp b/src/mongo/db/pipeline/document_source_sample.cpp
index 6a7f96f5ed5..121b6e04d36 100644
--- a/src/mongo/db/pipeline/document_source_sample.cpp
+++ b/src/mongo/db/pipeline/document_source_sample.cpp
@@ -73,6 +73,10 @@ Value DocumentSourceSample::serialize(bool explain) const {
return Value(DOC(getSourceName() << DOC("size" << _size)));
}
+void DocumentSourceSample::doInjectExpressionContext() {
+ _sortStage->injectExpressionContext(pExpCtx);
+}
+
namespace {
const BSONObj randSortSpec = BSON("$rand" << BSON("$meta"
<< "randVal"));
diff --git a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp
index dafae4ed111..5a80a530734 100644
--- a/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp
+++ b/src/mongo/db/pipeline/document_source_sample_from_random_cursor.cpp
@@ -76,7 +76,7 @@ double smallestFromSampleOfUniform(PseudoRandom* prng, size_t N) {
boost::optional<Document> DocumentSourceSampleFromRandomCursor::getNext() {
pExpCtx->checkForInterrupt();
- if (_seenDocs.size() >= static_cast<size_t>(_size))
+ if (_seenDocs->size() >= static_cast<size_t>(_size))
return {};
auto doc = getNextNonDuplicateDocument();
@@ -114,7 +114,7 @@ boost::optional<Document> DocumentSourceSampleFromRandomCursor::getNextNonDuplic
<< (*doc).toString(),
!idField.missing());
- if (_seenDocs.insert(std::move(idField)).second) {
+ if (_seenDocs->insert(std::move(idField)).second) {
return doc;
}
LOG(1) << "$sample encountered duplicate document: " << (*doc).toString() << std::endl;
@@ -136,11 +136,18 @@ DocumentSource::GetDepsReturn DocumentSourceSampleFromRandomCursor::getDependenc
return SEE_NEXT;
}
+void DocumentSourceSampleFromRandomCursor::doInjectExpressionContext() {
+ _seenDocs = pExpCtx->getValueComparator().makeUnorderedValueSet();
+}
+
intrusive_ptr<DocumentSourceSampleFromRandomCursor> DocumentSourceSampleFromRandomCursor::create(
const intrusive_ptr<ExpressionContext>& expCtx,
long long size,
std::string idField,
long long nDocsInCollection) {
- return new DocumentSourceSampleFromRandomCursor(expCtx, size, idField, nDocsInCollection);
+ intrusive_ptr<DocumentSourceSampleFromRandomCursor> source(
+ new DocumentSourceSampleFromRandomCursor(expCtx, size, idField, nDocsInCollection));
+ source->injectExpressionContext(expCtx);
+ return source;
}
} // mongo
diff --git a/src/mongo/db/pipeline/document_source_skip.cpp b/src/mongo/db/pipeline/document_source_skip.cpp
index 4b3d4ec168d..4957042f8a9 100644
--- a/src/mongo/db/pipeline/document_source_skip.cpp
+++ b/src/mongo/db/pipeline/document_source_skip.cpp
@@ -94,6 +94,7 @@ Pipeline::SourceContainer::iterator DocumentSourceSkip::optimizeAt(
intrusive_ptr<DocumentSourceSkip> DocumentSourceSkip::create(
const intrusive_ptr<ExpressionContext>& pExpCtx) {
intrusive_ptr<DocumentSourceSkip> pSource(new DocumentSourceSkip(pExpCtx));
+ pSource->injectExpressionContext(pExpCtx);
return pSource;
}
diff --git a/src/mongo/db/pipeline/document_source_sort.cpp b/src/mongo/db/pipeline/document_source_sort.cpp
index a1456d7c70a..345383750cc 100644
--- a/src/mongo/db/pipeline/document_source_sort.cpp
+++ b/src/mongo/db/pipeline/document_source_sort.cpp
@@ -170,6 +170,7 @@ intrusive_ptr<DocumentSource> DocumentSourceSort::createFromBson(
intrusive_ptr<DocumentSourceSort> DocumentSourceSort::create(
const intrusive_ptr<ExpressionContext>& pExpCtx, BSONObj sortOrder, long long limit) {
intrusive_ptr<DocumentSourceSort> pSort = new DocumentSourceSort(pExpCtx);
+ pSort->injectExpressionContext(pExpCtx);
pSort->_sort = sortOrder.getOwned();
/* check for then iterate over the sort object */
diff --git a/src/mongo/db/pipeline/document_source_test.cpp b/src/mongo/db/pipeline/document_source_test.cpp
index b9c7210df33..cf13e1589b1 100644
--- a/src/mongo/db/pipeline/document_source_test.cpp
+++ b/src/mongo/db/pipeline/document_source_test.cpp
@@ -33,6 +33,7 @@
#include "mongo/db/operation_context_noop.h"
#include "mongo/db/pipeline/dependencies.h"
#include "mongo/db/pipeline/document_source.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/pipeline.h"
#include "mongo/db/service_context.h"
@@ -230,27 +231,27 @@ private:
TEST(Mock, OneDoc) {
auto doc = DOC("a" << 1);
auto source = DocumentSourceMock::create(doc);
- ASSERT_EQ(*source->getNext(), doc);
+ ASSERT_DOCUMENT_EQ(*source->getNext(), doc);
ASSERT(!source->getNext());
}
TEST(Mock, DequeDocuments) {
auto source = DocumentSourceMock::create({DOC("a" << 1), DOC("a" << 2)});
- ASSERT_EQ(*source->getNext(), DOC("a" << 1));
- ASSERT_EQ(*source->getNext(), DOC("a" << 2));
+ ASSERT_DOCUMENT_EQ(*source->getNext(), DOC("a" << 1));
+ ASSERT_DOCUMENT_EQ(*source->getNext(), DOC("a" << 2));
ASSERT(!source->getNext());
}
TEST(Mock, StringJSON) {
auto source = DocumentSourceMock::create("{a : 1}");
- ASSERT_EQ(*source->getNext(), DOC("a" << 1));
+ ASSERT_DOCUMENT_EQ(*source->getNext(), DOC("a" << 1));
ASSERT(!source->getNext());
}
TEST(Mock, DequeStringJSONs) {
auto source = DocumentSourceMock::create({"{a: 1}", "{a: 2}"});
- ASSERT_EQ(*source->getNext(), DOC("a" << 1));
- ASSERT_EQ(*source->getNext(), DOC("a" << 2));
+ ASSERT_DOCUMENT_EQ(*source->getNext(), DOC("a" << 1));
+ ASSERT_DOCUMENT_EQ(*source->getNext(), DOC("a" << 2));
ASSERT(!source->getNext());
}
@@ -331,7 +332,7 @@ public:
// The limit's result is as expected.
boost::optional<Document> next = limit()->getNext();
ASSERT(bool(next));
- ASSERT_EQUALS(Value(1), next->getField("a"));
+ ASSERT_VALUE_EQ(Value(1), next->getField("a"));
// The limit is exhausted.
ASSERT(!limit()->getNext());
}
@@ -373,7 +374,7 @@ public:
// The limit is not exhauted.
boost::optional<Document> next = limit()->getNext();
ASSERT(bool(next));
- ASSERT_EQUALS(Value(1), next->getField("a"));
+ ASSERT_VALUE_EQ(Value(1), next->getField("a"));
// The limit is exhausted.
ASSERT(!limit()->getNext());
}
@@ -458,6 +459,7 @@ protected:
expressionContext->tempDir = _tempDir.path();
_group = DocumentSourceGroup::createFromBson(specElement, expressionContext);
+ _group->injectExpressionContext(expressionContext);
assertRoundTrips(_group);
}
DocumentSourceGroup* group() {
@@ -1014,19 +1016,19 @@ public:
auto res = group()->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("_id"), Value(0));
+ ASSERT_VALUE_EQ(res->getField("_id"), Value(0));
ASSERT_TRUE(group()->isStreaming());
res = source->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("a"), Value(1));
+ ASSERT_VALUE_EQ(res->getField("a"), Value(1));
assertExhausted(source);
res = group()->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("_id"), Value(1));
+ ASSERT_VALUE_EQ(res->getField("_id"), Value(1));
assertExhausted(group());
@@ -1049,20 +1051,20 @@ public:
auto res = group()->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("_id")["x"], Value(1));
- ASSERT_EQUALS(res->getField("_id")["y"], Value(2));
+ ASSERT_VALUE_EQ(res->getField("_id")["x"], Value(1));
+ ASSERT_VALUE_EQ(res->getField("_id")["y"], Value(2));
ASSERT_TRUE(group()->isStreaming());
res = group()->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("_id")["x"], Value(1));
- ASSERT_EQUALS(res->getField("_id")["y"], Value(1));
+ ASSERT_VALUE_EQ(res->getField("_id")["x"], Value(1));
+ ASSERT_VALUE_EQ(res->getField("_id")["y"], Value(1));
res = source->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("a"), Value(2));
- ASSERT_EQUALS(res->getField("b"), Value(1));
+ ASSERT_VALUE_EQ(res->getField("a"), Value(2));
+ ASSERT_VALUE_EQ(res->getField("b"), Value(1));
assertExhausted(source);
@@ -1089,13 +1091,13 @@ public:
auto res = group()->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("_id")["x"]["y"]["z"], Value(3));
+ ASSERT_VALUE_EQ(res->getField("_id")["x"]["y"]["z"], Value(3));
ASSERT_TRUE(group()->isStreaming());
res = source->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("a")["b"]["c"], Value(1));
+ ASSERT_VALUE_EQ(res->getField("a")["b"]["c"], Value(1));
assertExhausted(source);
@@ -1125,16 +1127,16 @@ public:
auto res = group()->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("_id")["sub"]["x"], Value(1));
- ASSERT_EQUALS(res->getField("_id")["sub"]["y"], Value(1));
- ASSERT_EQUALS(res->getField("_id")["sub"]["z"], Value(1));
+ ASSERT_VALUE_EQ(res->getField("_id")["sub"]["x"], Value(1));
+ ASSERT_VALUE_EQ(res->getField("_id")["sub"]["y"], Value(1));
+ ASSERT_VALUE_EQ(res->getField("_id")["sub"]["z"], Value(1));
ASSERT_TRUE(group()->isStreaming());
res = source->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("a"), Value(2));
- ASSERT_EQUALS(res->getField("b"), Value(3));
+ ASSERT_VALUE_EQ(res->getField("a"), Value(2));
+ ASSERT_VALUE_EQ(res->getField("b"), Value(3));
BSONObjSet outputSort = group()->getOutputSorts();
@@ -1160,16 +1162,16 @@ public:
auto res = group()->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("_id")["sub"]["x"], Value(5));
- ASSERT_EQUALS(res->getField("_id")["sub"]["y"], Value(1));
- ASSERT_EQUALS(res->getField("_id")["sub"]["z"], Value("c"));
+ ASSERT_VALUE_EQ(res->getField("_id")["sub"]["x"], Value(5));
+ ASSERT_VALUE_EQ(res->getField("_id")["sub"]["y"], Value(1));
+ ASSERT_VALUE_EQ(res->getField("_id")["sub"]["z"], Value("c"));
ASSERT_TRUE(group()->isStreaming());
res = source->getNext();
ASSERT_TRUE(bool(res));
- ASSERT_EQUALS(res->getField("a"), Value(3));
- ASSERT_EQUALS(res->getField("b"), Value(1));
+ ASSERT_VALUE_EQ(res->getField("a"), Value(3));
+ ASSERT_VALUE_EQ(res->getField("b"), Value(1));
BSONObjSet outputSort = group()->getOutputSorts();
ASSERT_EQUALS(outputSort.size(), 2U);
@@ -1903,7 +1905,7 @@ public:
vector<Value> arr;
sort()->serializeToArray(arr);
- ASSERT_EQUALS(
+ ASSERT_VALUE_EQ(
Value(arr),
DOC_ARRAY(DOC("$sort" << DOC("a" << 1)) << DOC("$limit" << sort()->getLimit())));
@@ -3043,8 +3045,8 @@ using std::unique_ptr;
// Helpers to make a DocumentSourceMatch from a query object or json string
intrusive_ptr<DocumentSourceMatch> makeMatch(const BSONObj& query) {
- intrusive_ptr<DocumentSource> uncasted =
- DocumentSourceMatch::createFromBson(BSON("$match" << query).firstElement(), NULL);
+ intrusive_ptr<DocumentSource> uncasted = DocumentSourceMatch::createFromBson(
+ BSON("$match" << query).firstElement(), new ExpressionContext());
return dynamic_cast<DocumentSourceMatch*>(uncasted.get());
}
intrusive_ptr<DocumentSourceMatch> makeMatch(const string& queryJson) {
@@ -3458,11 +3460,11 @@ public:
ASSERT_EQUALS(explainedStages.size(), 2UL);
auto groupExplain = explainedStages[0];
- ASSERT_EQ(groupExplain["$group"], expectedGroupExplain);
+ ASSERT_VALUE_EQ(groupExplain["$group"], expectedGroupExplain);
auto sortExplain = explainedStages[1];
auto expectedSortExplain = Value{Document{{"sortKey", Document{{"count", -1}}}}};
- ASSERT_EQ(sortExplain["$sort"], expectedSortExplain);
+ ASSERT_VALUE_EQ(sortExplain["$sort"], expectedSortExplain);
}
};
@@ -3553,11 +3555,11 @@ public:
Value{Document{{"_id", Document{{"$const", BSONNULL}}},
{countName, Document{{"$sum", Document{{"$const", 1}}}}}}};
auto groupExplain = explainedStages[0];
- ASSERT_EQ(groupExplain["$group"], expectedGroupExplain);
+ ASSERT_VALUE_EQ(groupExplain["$group"], expectedGroupExplain);
Value expectedProjectExplain = Value{Document{{"_id", false}, {countName, true}}};
auto projectExplain = explainedStages[1];
- ASSERT_EQ(projectExplain["$project"], expectedProjectExplain);
+ ASSERT_VALUE_EQ(projectExplain["$project"], expectedProjectExplain);
}
};
@@ -3644,12 +3646,12 @@ public:
ASSERT_EQUALS(explainedStages.size(), 2UL);
auto groupExplain = explainedStages[0];
- ASSERT_EQ(groupExplain["$group"], expectedGroupExplain);
+ ASSERT_VALUE_EQ(groupExplain["$group"], expectedGroupExplain);
auto sortExplain = explainedStages[1];
auto expectedSortExplain = Value{Document{{"sortKey", Document{{"_id", 1}}}}};
- ASSERT_EQ(sortExplain["$sort"], expectedSortExplain);
+ ASSERT_VALUE_EQ(sortExplain["$sort"], expectedSortExplain);
}
};
diff --git a/src/mongo/db/pipeline/document_source_unwind.cpp b/src/mongo/db/pipeline/document_source_unwind.cpp
index b02723c1c94..82bb8f3adbe 100644
--- a/src/mongo/db/pipeline/document_source_unwind.cpp
+++ b/src/mongo/db/pipeline/document_source_unwind.cpp
@@ -174,11 +174,13 @@ intrusive_ptr<DocumentSourceUnwind> DocumentSourceUnwind::create(
const string& unwindPath,
bool preserveNullAndEmptyArrays,
const boost::optional<string>& indexPath) {
- return new DocumentSourceUnwind(expCtx,
- FieldPath(unwindPath),
- preserveNullAndEmptyArrays,
- indexPath ? FieldPath(*indexPath)
- : boost::optional<FieldPath>());
+ intrusive_ptr<DocumentSourceUnwind> source(
+ new DocumentSourceUnwind(expCtx,
+ FieldPath(unwindPath),
+ preserveNullAndEmptyArrays,
+ indexPath ? FieldPath(*indexPath) : boost::optional<FieldPath>()));
+ source->injectExpressionContext(expCtx);
+ return source;
}
boost::optional<Document> DocumentSourceUnwind::getNext() {
diff --git a/src/mongo/db/pipeline/document_value_test.cpp b/src/mongo/db/pipeline/document_value_test.cpp
index 207039aa90c..b9e942f5659 100644
--- a/src/mongo/db/pipeline/document_value_test.cpp
+++ b/src/mongo/db/pipeline/document_value_test.cpp
@@ -31,6 +31,7 @@
#include "mongo/platform/basic.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/field_path.h"
#include "mongo/db/pipeline/value.h"
#include "mongo/dbtests/dbtests.h"
@@ -67,7 +68,7 @@ void assertRoundTrips(const Document& document1) {
Document document2 = fromBson(obj1);
BSONObj obj2 = toBson(document2);
ASSERT_EQUALS(obj1, obj2);
- ASSERT_EQUALS(document1, document2);
+ ASSERT_DOCUMENT_EQ(document1, document2);
}
TEST(DocumentConstruction, Default) {
@@ -185,26 +186,26 @@ public:
md.remove("c");
ASSERT(md.peek().empty());
ASSERT_EQUALS(0U, md.peek().size());
- ASSERT_EQUALS(md.peek(), Document());
+ ASSERT_DOCUMENT_EQ(md.peek(), Document());
ASSERT(!FieldIterator(md.peek()).more());
ASSERT(md.peek()["c"].missing());
assertRoundTrips(md.peek());
// Set a nested field using []
md["x"]["y"]["z"] = Value("nested");
- ASSERT_EQUALS(md.peek()["x"]["y"]["z"], Value("nested"));
+ ASSERT_VALUE_EQ(md.peek()["x"]["y"]["z"], Value("nested"));
// Set a nested field using setNestedField
FieldPath xxyyzz = string("xx.yy.zz");
md.setNestedField(xxyyzz, Value("nested"));
- ASSERT_EQUALS(md.peek().getNestedField(xxyyzz), Value("nested"));
+ ASSERT_VALUE_EQ(md.peek().getNestedField(xxyyzz), Value("nested"));
// Set a nested fields through an existing empty document
md["xxx"] = Value(Document());
md["xxx"]["yyy"] = Value(Document());
FieldPath xxxyyyzzz = string("xxx.yyy.zzz");
md.setNestedField(xxxyyyzzz, Value("nested"));
- ASSERT_EQUALS(md.peek().getNestedField(xxxyyyzzz), Value("nested"));
+ ASSERT_VALUE_EQ(md.peek().getNestedField(xxxyyyzzz), Value("nested"));
// Make sure nothing moved
ASSERT_EQUALS(apos, md.peek().positionOf("a"));
@@ -269,7 +270,7 @@ public:
MutableDocument cloneOnDemand(document);
// Check equality.
- ASSERT_EQUALS(document, cloneOnDemand.peek());
+ ASSERT_DOCUMENT_EQ(document, cloneOnDemand.peek());
// Check pointer equality of sub document.
ASSERT_EQUALS(document["a"].getDocument().getPtr(),
cloneOnDemand.peek()["a"].getDocument().getPtr());
@@ -277,21 +278,21 @@ public:
// Change field in clone and ensure the original document's field is unchanged.
cloneOnDemand.setField(StringData("a"), Value(2));
- ASSERT_EQUALS(Value(1), document.getNestedField(FieldPath("a.b")));
+ ASSERT_VALUE_EQ(Value(1), document.getNestedField(FieldPath("a.b")));
// setNestedField and ensure the original document is unchanged.
cloneOnDemand.reset(document);
vector<Position> path;
- ASSERT_EQUALS(Value(1), document.getNestedField(FieldPath("a.b"), &path));
+ ASSERT_VALUE_EQ(Value(1), document.getNestedField(FieldPath("a.b"), &path));
cloneOnDemand.setNestedField(path, Value(2));
- ASSERT_EQUALS(Value(1), document.getNestedField(FieldPath("a.b")));
- ASSERT_EQUALS(Value(2), cloneOnDemand.peek().getNestedField(FieldPath("a.b")));
- ASSERT_EQUALS(DOC("a" << DOC("b" << 1)), document);
- ASSERT_EQUALS(DOC("a" << DOC("b" << 2)), cloneOnDemand.freeze());
+ ASSERT_VALUE_EQ(Value(1), document.getNestedField(FieldPath("a.b")));
+ ASSERT_VALUE_EQ(Value(2), cloneOnDemand.peek().getNestedField(FieldPath("a.b")));
+ ASSERT_DOCUMENT_EQ(DOC("a" << DOC("b" << 1)), document);
+ ASSERT_DOCUMENT_EQ(DOC("a" << DOC("b" << 2)), cloneOnDemand.freeze());
}
};
@@ -301,7 +302,7 @@ public:
void run() {
Document document = fromBson(fromjson("{a:1,b:['ra',4],c:{z:1},d:'lal'}"));
Document clonedDocument = document.clone();
- ASSERT_EQUALS(document, clonedDocument);
+ ASSERT_DOCUMENT_EQ(document, clonedDocument);
}
};
@@ -404,7 +405,7 @@ public:
// logical equality
ASSERT_EQUALS(obj, obj2);
- ASSERT_EQUALS(doc, doc2);
+ ASSERT_DOCUMENT_EQ(doc, doc2);
// binary equality
ASSERT_EQUALS(obj.objsize(), obj2.objsize());
@@ -482,7 +483,7 @@ protected:
void assertRoundTrips(const Document& input) {
// Round trip to/from a buffer.
auto output = roundTrip(input);
- ASSERT_EQ(output, input);
+ ASSERT_DOCUMENT_EQ(output, input);
ASSERT_EQ(output.hasTextScore(), input.hasTextScore());
ASSERT_EQ(output.hasRandMetaField(), input.hasRandMetaField());
if (input.hasTextScore())
@@ -564,15 +565,15 @@ void assertRoundTrips(const Value& value1) {
Value value2 = fromBson(obj1);
BSONObj obj2 = toBson(value2);
ASSERT_EQUALS(obj1, obj2);
- ASSERT_EQUALS(value1, value2);
+ ASSERT_VALUE_EQ(value1, value2);
ASSERT_EQUALS(value1.getType(), value2.getType());
}
class BSONArrayTest {
public:
void run() {
- ASSERT_EQUALS(Value(BSON_ARRAY(1 << 2 << 3)), DOC_ARRAY(1 << 2 << 3));
- ASSERT_EQUALS(Value(BSONArray()), Value(vector<Value>()));
+ ASSERT_VALUE_EQ(Value(BSON_ARRAY(1 << 2 << 3)), DOC_ARRAY(1 << 2 << 3));
+ ASSERT_VALUE_EQ(Value(BSONArray()), Value(vector<Value>()));
}
};
@@ -1661,10 +1662,10 @@ public:
arrayOfMissing.serializeForSorter(bb);
BufReader reader(bb.buf(), bb.len());
- ASSERT_EQUALS(missing,
- Value::deserializeForSorter(reader, Value::SorterDeserializeSettings()));
- ASSERT_EQUALS(arrayOfMissing,
- Value::deserializeForSorter(reader, Value::SorterDeserializeSettings()));
+ ASSERT_VALUE_EQ(missing,
+ Value::deserializeForSorter(reader, Value::SorterDeserializeSettings()));
+ ASSERT_VALUE_EQ(arrayOfMissing,
+ Value::deserializeForSorter(reader, Value::SorterDeserializeSettings()));
}
};
} // namespace Value
diff --git a/src/mongo/db/pipeline/document_value_test_util.cpp b/src/mongo/db/pipeline/document_value_test_util.cpp
new file mode 100644
index 00000000000..223ea4c8f56
--- /dev/null
+++ b/src/mongo/db/pipeline/document_value_test_util.cpp
@@ -0,0 +1,67 @@
+/**
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/pipeline/document_value_test_util.h"
+
+namespace mongo {
+namespace unittest {
+
+#define _GENERATE_DOCVAL_CMP_FUNC(DOCVAL, NAME, COMPARATOR, OPERATOR) \
+ void assertComparison_##DOCVAL##NAME(const std::string& theFile, \
+ unsigned theLine, \
+ StringData aExpression, \
+ StringData bExpression, \
+ const DOCVAL& aValue, \
+ const DOCVAL& bValue) { \
+ if (!COMPARATOR().evaluate(aValue OPERATOR bValue)) { \
+ std::ostringstream os; \
+ os << "Expected [ " << aExpression << " " #OPERATOR " " << bExpression \
+ << " ] but found [ " << aValue << " " #OPERATOR " " << bValue << "]"; \
+ TestAssertionFailure(theFile, theLine, os.str()).stream(); \
+ } \
+ }
+
+_GENERATE_DOCVAL_CMP_FUNC(Value, EQ, ValueComparator, ==);
+_GENERATE_DOCVAL_CMP_FUNC(Value, LT, ValueComparator, <);
+_GENERATE_DOCVAL_CMP_FUNC(Value, LTE, ValueComparator, <=);
+_GENERATE_DOCVAL_CMP_FUNC(Value, GT, ValueComparator, >);
+_GENERATE_DOCVAL_CMP_FUNC(Value, GTE, ValueComparator, >=);
+_GENERATE_DOCVAL_CMP_FUNC(Value, NE, ValueComparator, !=);
+
+_GENERATE_DOCVAL_CMP_FUNC(Document, EQ, DocumentComparator, ==);
+_GENERATE_DOCVAL_CMP_FUNC(Document, LT, DocumentComparator, <);
+_GENERATE_DOCVAL_CMP_FUNC(Document, LTE, DocumentComparator, <=);
+_GENERATE_DOCVAL_CMP_FUNC(Document, GT, DocumentComparator, >);
+_GENERATE_DOCVAL_CMP_FUNC(Document, GTE, DocumentComparator, >=);
+_GENERATE_DOCVAL_CMP_FUNC(Document, NE, DocumentComparator, !=);
+#undef _GENERATE_DOCVAL_CMP_FUNC
+
+} // namespace unittest
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/document_value_test_util.h b/src/mongo/db/pipeline/document_value_test_util.h
new file mode 100644
index 00000000000..03b26e696af
--- /dev/null
+++ b/src/mongo/db/pipeline/document_value_test_util.h
@@ -0,0 +1,88 @@
+/**
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include "mongo/db/pipeline/document_comparator.h"
+#include "mongo/db/pipeline/value_comparator.h"
+#include "mongo/unittest/unittest.h"
+
+/**
+ * Use to compare two instances of type Value under the default ValueComparator in unit tests.
+ */
+#define ASSERT_VALUE_EQ(a, b) _ASSERT_DOCVAL_COMPARISON(ValueEQ, a, b)
+#define ASSERT_VALUE_LT(a, b) _ASSERT_DOCVAL_COMPARISON(ValueLT, a, b)
+#define ASSERT_VALUE_LTE(a, b) _ASSERT_DOCVAL_COMPARISON(ValueLTE, a, b)
+#define ASSERT_VALUE_GT(a, b) _ASSERT_DOCVAL_COMPARISON(ValueGT, a, b)
+#define ASSERT_VALUE_GTE(a, b) _ASSERT_DOCVAL_COMPARISON(ValueGTE, a, b)
+#define ASSERT_VALUE_NE(a, b) _ASSERT_DOCVAL_COMPARISON(ValueNE, a, b)
+
+/**
+ * Use to compare two instances of type Document under the default DocumentComparator in unit tests.
+ */
+#define ASSERT_DOCUMENT_EQ(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentEQ, a, b)
+#define ASSERT_DOCUMENT_LT(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentLT, a, b)
+#define ASSERT_DOCUMENT_LTE(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentLTE, a, b)
+#define ASSERT_DOCUMENT_GT(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentGT, a, b)
+#define ASSERT_DOCUMENT_GTE(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentGTE, a, b)
+#define ASSERT_DOCUMENT_NE(a, b) _ASSERT_DOCVAL_COMPARISON(DocumentNE, a, b)
+
+/**
+ * Document/Value comparison utility macro. Do not use directly.
+ */
+#define _ASSERT_DOCVAL_COMPARISON(NAME, a, b) \
+ ::mongo::unittest::assertComparison_##NAME(__FILE__, __LINE__, #a, #b, a, b)
+
+namespace mongo {
+namespace unittest {
+
+#define _DECLARE_DOCVAL_CMP_FUNC(DOCVAL, NAME) \
+ void assertComparison_##DOCVAL##NAME(const std::string& theFile, \
+ unsigned theLine, \
+ StringData aExpression, \
+ StringData bExpression, \
+ const DOCVAL& aValue, \
+ const DOCVAL& bValue);
+
+_DECLARE_DOCVAL_CMP_FUNC(Value, EQ);
+_DECLARE_DOCVAL_CMP_FUNC(Value, LT);
+_DECLARE_DOCVAL_CMP_FUNC(Value, LTE);
+_DECLARE_DOCVAL_CMP_FUNC(Value, GT);
+_DECLARE_DOCVAL_CMP_FUNC(Value, GTE);
+_DECLARE_DOCVAL_CMP_FUNC(Value, NE);
+
+_DECLARE_DOCVAL_CMP_FUNC(Document, EQ);
+_DECLARE_DOCVAL_CMP_FUNC(Document, LT);
+_DECLARE_DOCVAL_CMP_FUNC(Document, LTE);
+_DECLARE_DOCVAL_CMP_FUNC(Document, GT);
+_DECLARE_DOCVAL_CMP_FUNC(Document, GTE);
+_DECLARE_DOCVAL_CMP_FUNC(Document, NE);
+#undef _DECLARE_DOCVAL_CMP_FUNC
+
+} // namespace unittest
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/document_value_test_util_self_test.cpp b/src/mongo/db/pipeline/document_value_test_util_self_test.cpp
new file mode 100644
index 00000000000..cbe55346a29
--- /dev/null
+++ b/src/mongo/db/pipeline/document_value_test_util_self_test.cpp
@@ -0,0 +1,92 @@
+/**
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/pipeline/document_value_test_util.h"
+
+#include "mongo/db/query/collation/collator_interface_mock.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+namespace {
+
+TEST(DocumentValueTestUtilSelfTest, DocumentEQ) {
+ ASSERT_DOCUMENT_EQ(Document({{"foo", "bar"}}), Document({{"foo", "bar"}}));
+}
+
+TEST(DocumentValueTestUtilSelfTest, DocumentNE) {
+ ASSERT_DOCUMENT_NE(Document({{"foo", "bar"}}), Document({{"foo", "baz"}}));
+}
+
+TEST(DocumentValueTestUtilSelfTest, DocumentLT) {
+ ASSERT_DOCUMENT_LT(Document({{"foo", "bar"}}), Document({{"foo", "baz"}}));
+}
+
+TEST(DocumentValueTestUtilSelfTest, DocumentLTE) {
+ ASSERT_DOCUMENT_LTE(Document({{"foo", "bar"}}), Document({{"foo", "baz"}}));
+ ASSERT_DOCUMENT_LTE(Document({{"foo", "bar"}}), Document({{"foo", "bar"}}));
+}
+
+TEST(DocumentValueTestUtilSelfTest, DocumentGT) {
+ ASSERT_DOCUMENT_GT(Document({{"foo", "baz"}}), Document({{"foo", "bar"}}));
+}
+
+TEST(DocumentValueTestUtilSelfTest, DocumentGTE) {
+ ASSERT_DOCUMENT_GTE(Document({{"foo", "baz"}}), Document({{"foo", "bar"}}));
+ ASSERT_DOCUMENT_GTE(Document({{"foo", "bar"}}), Document({{"foo", "bar"}}));
+}
+
+TEST(DocumentValueTestUtilSelfTest, ValueEQ) {
+ ASSERT_VALUE_EQ(Value("bar"), Value("bar"));
+}
+
+TEST(DocumentValueTestUtilSelfTest, ValueNE) {
+ ASSERT_VALUE_NE(Value("bar"), Value("baz"));
+}
+
+TEST(DocumentValueTestUtilSelfTest, ValueLT) {
+ ASSERT_VALUE_LT(Value("bar"), Value("baz"));
+}
+
+TEST(DocumentValueTestUtilSelfTest, ValueLTE) {
+ ASSERT_VALUE_LTE(Value("bar"), Value("baz"));
+ ASSERT_VALUE_LTE(Value("bar"), Value("bar"));
+}
+
+TEST(DocumentValueTestUtilSelfTest, ValueGT) {
+ ASSERT_VALUE_GT(Value("baz"), Value("bar"));
+}
+
+TEST(DocumentValueTestUtilSelfTest, ValueGTE) {
+ ASSERT_VALUE_GTE(Value("baz"), Value("bar"));
+ ASSERT_VALUE_GTE(Value("bar"), Value("bar"));
+}
+
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index b8f81dd6329..40affa27bbe 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -456,7 +456,8 @@ intrusive_ptr<Expression> ExpressionAnd::optimize() {
*/
bool last = pConst->getValue().coerceToBool();
if (!last) {
- intrusive_ptr<ExpressionConstant> pFinal(ExpressionConstant::create(Value(false)));
+ intrusive_ptr<ExpressionConstant> pFinal(
+ ExpressionConstant::create(getExpressionContext(), Value(false)));
return pFinal;
}
@@ -467,7 +468,8 @@ intrusive_ptr<Expression> ExpressionAnd::optimize() {
the result will be a boolean.
*/
if (n == 2) {
- intrusive_ptr<Expression> pFinal(ExpressionCoerceToBool::create(pAnd->vpOperand[0]));
+ intrusive_ptr<Expression> pFinal(
+ ExpressionCoerceToBool::create(getExpressionContext(), pAnd->vpOperand[0]));
return pFinal;
}
@@ -612,8 +614,9 @@ const char* ExpressionCeil::getOpName() const {
/* -------------------- ExpressionCoerceToBool ------------------------- */
intrusive_ptr<ExpressionCoerceToBool> ExpressionCoerceToBool::create(
- const intrusive_ptr<Expression>& pExpression) {
+ const intrusive_ptr<ExpressionContext>& expCtx, const intrusive_ptr<Expression>& pExpression) {
intrusive_ptr<ExpressionCoerceToBool> pNew(new ExpressionCoerceToBool(pExpression));
+ pNew->injectExpressionContext(expCtx);
return pNew;
}
@@ -653,6 +656,11 @@ Value ExpressionCoerceToBool::serialize(bool explain) const {
return Value(DOC(name << DOC_ARRAY(pExpression->serialize(explain))));
}
+void ExpressionCoerceToBool::doInjectExpressionContext() {
+ // Inject our ExpressionContext into the operand.
+ pExpression->injectExpressionContext(getExpressionContext());
+}
+
/* ----------------------- ExpressionCompare --------------------------- */
REGISTER_EXPRESSION(cmp,
@@ -854,8 +862,10 @@ intrusive_ptr<Expression> ExpressionConstant::parse(BSONElement exprElement,
}
-intrusive_ptr<ExpressionConstant> ExpressionConstant::create(const Value& pValue) {
+intrusive_ptr<ExpressionConstant> ExpressionConstant::create(
+ const intrusive_ptr<ExpressionContext>& expCtx, const Value& pValue) {
intrusive_ptr<ExpressionConstant> pEC(new ExpressionConstant(pValue));
+ pEC->injectExpressionContext(expCtx);
return pEC;
}
@@ -1084,6 +1094,10 @@ void ExpressionDateToString::addDependencies(DepsTracker* deps) const {
_date->addDependencies(deps);
}
+void ExpressionDateToString::doInjectExpressionContext() {
+ _date->injectExpressionContext(getExpressionContext());
+}
+
/* ---------------------- ExpressionDayOfMonth ------------------------- */
Value ExpressionDayOfMonth::evaluateInternal(Variables* vars) const {
@@ -1235,6 +1249,12 @@ Value ExpressionObject::serialize(bool explain) const {
return outputDoc.freezeToValue();
}
+void ExpressionObject::doInjectExpressionContext() {
+ for (auto&& pair : _expressions) {
+ pair.second->injectExpressionContext(getExpressionContext());
+ }
+}
+
/* --------------------- ExpressionFieldPath --------------------------- */
// this is the old deprecated version only used by tests not using variables
@@ -1454,6 +1474,11 @@ void ExpressionFilter::addDependencies(DepsTracker* deps) const {
_filter->addDependencies(deps);
}
+void ExpressionFilter::doInjectExpressionContext() {
+ _input->injectExpressionContext(getExpressionContext());
+ _filter->injectExpressionContext(getExpressionContext());
+}
+
/* ------------------------- ExpressionFloor -------------------------- */
Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const {
@@ -1569,6 +1594,10 @@ void ExpressionLet::addDependencies(DepsTracker* deps) const {
_subExpression->addDependencies(deps);
}
+void ExpressionLet::doInjectExpressionContext() {
+ _subExpression->injectExpressionContext(getExpressionContext());
+}
+
/* ------------------------- ExpressionMap ----------------------------- */
@@ -1670,6 +1699,11 @@ void ExpressionMap::addDependencies(DepsTracker* deps) const {
_each->addDependencies(deps);
}
+void ExpressionMap::doInjectExpressionContext() {
+ _input->injectExpressionContext(getExpressionContext());
+ _each->injectExpressionContext(getExpressionContext());
+}
+
/* ------------------------- ExpressionMeta ----------------------------- */
REGISTER_EXPRESSION(meta, ExpressionMeta::parse);
@@ -1917,7 +1951,7 @@ Value ExpressionIn::evaluateInternal(Variables* vars) const {
<< typeName(arrayOfValues.getType()),
arrayOfValues.isArray());
for (auto&& value : arrayOfValues.getArray()) {
- if (argument == value) {
+ if (getExpressionContext()->getValueComparator().evaluate(argument == value)) {
return Value(true);
}
}
@@ -1984,7 +2018,7 @@ Value ExpressionIndexOfArray::evaluateInternal(Variables* vars) const {
}
for (size_t i = startIndex; i < endIndex; i++) {
- if (array[i] == searchItem) {
+ if (getExpressionContext()->getValueComparator().evaluate(array[i] == searchItem)) {
return Value(static_cast<int>(i));
}
}
@@ -2260,7 +2294,8 @@ intrusive_ptr<Expression> ExpressionNary::optimize() {
// expression.
if (constOperandCount == vpOperand.size()) {
Variables emptyVars;
- return intrusive_ptr<Expression>(ExpressionConstant::create(evaluateInternal(&emptyVars)));
+ return intrusive_ptr<Expression>(
+ ExpressionConstant::create(getExpressionContext(), evaluateInternal(&emptyVars)));
}
// If the expression is associative, we can collapse all the consecutive constant operands into
@@ -2304,8 +2339,8 @@ intrusive_ptr<Expression> ExpressionNary::optimize() {
ExpressionVector vpOperandSave = std::move(vpOperand);
vpOperand = std::move(constExpressions);
Variables emptyVars;
- optimizedOperands.emplace_back(
- ExpressionConstant::create(evaluateInternal(&emptyVars)));
+ optimizedOperands.emplace_back(ExpressionConstant::create(
+ getExpressionContext(), evaluateInternal(&emptyVars)));
vpOperand = std::move(vpOperandSave);
} else {
optimizedOperands.insert(
@@ -2321,7 +2356,7 @@ intrusive_ptr<Expression> ExpressionNary::optimize() {
vpOperand = std::move(constExpressions);
Variables emptyVars;
optimizedOperands.emplace_back(
- ExpressionConstant::create(evaluateInternal(&emptyVars)));
+ ExpressionConstant::create(getExpressionContext(), evaluateInternal(&emptyVars)));
} else {
optimizedOperands.insert(
optimizedOperands.end(), constExpressions.begin(), constExpressions.end());
@@ -2352,6 +2387,12 @@ Value ExpressionNary::serialize(bool explain) const {
return Value(DOC(getOpName() << array));
}
+void ExpressionNary::doInjectExpressionContext() {
+ for (auto&& operand : vpOperand) {
+ operand->injectExpressionContext(getExpressionContext());
+ }
+}
+
/* ------------------------- ExpressionNot ----------------------------- */
Value ExpressionNot::evaluateInternal(Variables* vars) const {
@@ -2407,7 +2448,8 @@ intrusive_ptr<Expression> ExpressionOr::optimize() {
*/
bool last = pConst->getValue().coerceToBool();
if (last) {
- intrusive_ptr<ExpressionConstant> pFinal(ExpressionConstant::create(Value(true)));
+ intrusive_ptr<ExpressionConstant> pFinal(
+ ExpressionConstant::create(getExpressionContext(), Value(true)));
return pFinal;
}
@@ -2418,7 +2460,8 @@ intrusive_ptr<Expression> ExpressionOr::optimize() {
the result will be a boolean.
*/
if (n == 2) {
- intrusive_ptr<Expression> pFinal(ExpressionCoerceToBool::create(pOr->vpOperand[0]));
+ intrusive_ptr<Expression> pFinal(
+ ExpressionCoerceToBool::create(getExpressionContext(), pOr->vpOperand[0]));
return pFinal;
}
@@ -2738,6 +2781,12 @@ Value ExpressionReduce::serialize(bool explain) const {
{"in", _in->serialize(explain)}}}});
}
+void ExpressionReduce::doInjectExpressionContext() {
+ _input->injectExpressionContext(getExpressionContext());
+ _initial->injectExpressionContext(getExpressionContext());
+ _in->injectExpressionContext(getExpressionContext());
+}
+
/* ------------------------ ExpressionReverseArray ------------------------ */
Value ExpressionReverseArray::evaluateInternal(Variables* vars) const {
@@ -2779,9 +2828,11 @@ const char* ExpressionSecond::getOpName() const {
}
namespace {
-ValueSet arrayToSet(const Value& val) {
+ValueSet arrayToSet(const Value& val, const ValueComparator& valueComparator) {
const vector<Value>& array = val.getArray();
- return ValueSet(array.begin(), array.end());
+ ValueSet valueSet = valueComparator.makeOrderedValueSet();
+ valueSet.insert(array.begin(), array.end());
+ return valueSet;
}
}
@@ -2806,7 +2857,7 @@ Value ExpressionSetDifference::evaluateInternal(Variables* vars) const {
<< typeName(rhs.getType()),
rhs.isArray());
- ValueSet rhsSet = arrayToSet(rhs);
+ ValueSet rhsSet = arrayToSet(rhs, getExpressionContext()->getValueComparator());
const vector<Value>& lhsArray = lhs.getArray();
vector<Value> returnVec;
@@ -2835,7 +2886,8 @@ void ExpressionSetEquals::validateArguments(const ExpressionVector& args) const
Value ExpressionSetEquals::evaluateInternal(Variables* vars) const {
const size_t n = vpOperand.size();
- std::set<Value> lhs;
+ const auto& valueComparator = getExpressionContext()->getValueComparator();
+ ValueSet lhs = valueComparator.makeOrderedValueSet();
for (size_t i = 0; i < n; i++) {
const Value nextEntry = vpOperand[i]->evaluateInternal(vars);
@@ -2848,8 +2900,13 @@ Value ExpressionSetEquals::evaluateInternal(Variables* vars) const {
if (i == 0) {
lhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
} else {
- const std::set<Value> rhs(nextEntry.getArray().begin(), nextEntry.getArray().end());
- if (lhs != rhs) {
+ ValueSet rhs = valueComparator.makeOrderedValueSet();
+ rhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
+ if (lhs.size() != rhs.size()) {
+ return Value(false);
+ }
+
+ if (!std::equal(lhs.begin(), lhs.end(), rhs.begin(), valueComparator.getEqualTo())) {
return Value(false);
}
}
@@ -2866,7 +2923,8 @@ const char* ExpressionSetEquals::getOpName() const {
Value ExpressionSetIntersection::evaluateInternal(Variables* vars) const {
const size_t n = vpOperand.size();
- ValueSet currentIntersection;
+ const auto& valueComparator = getExpressionContext()->getValueComparator();
+ ValueSet currentIntersection = valueComparator.makeOrderedValueSet();
for (size_t i = 0; i < n; i++) {
const Value nextEntry = vpOperand[i]->evaluateInternal(vars);
if (nextEntry.nullish()) {
@@ -2881,7 +2939,7 @@ Value ExpressionSetIntersection::evaluateInternal(Variables* vars) const {
if (i == 0) {
currentIntersection.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
} else {
- ValueSet nextSet = arrayToSet(nextEntry);
+ ValueSet nextSet = arrayToSet(nextEntry, valueComparator);
if (currentIntersection.size() > nextSet.size()) {
// to iterate over whichever is the smaller set
nextSet.swap(currentIntersection);
@@ -2939,7 +2997,8 @@ Value ExpressionSetIsSubset::evaluateInternal(Variables* vars) const {
<< typeName(rhs.getType()),
rhs.isArray());
- return setIsSubsetHelper(lhs.getArray(), arrayToSet(rhs));
+ return setIsSubsetHelper(lhs.getArray(),
+ arrayToSet(rhs, getExpressionContext()->getValueComparator()));
}
/**
@@ -2988,7 +3047,10 @@ intrusive_ptr<Expression> ExpressionSetIsSubset::optimize() {
<< typeName(rhs.getType()),
rhs.isArray());
- return new Optimized(arrayToSet(rhs), vpOperand);
+ intrusive_ptr<Expression> optimizedWithConstant(new Optimized(
+ arrayToSet(rhs, getExpressionContext()->getValueComparator()), vpOperand));
+ optimizedWithConstant->injectExpressionContext(getExpressionContext());
+ return optimizedWithConstant;
}
return optimized;
}
@@ -3001,7 +3063,7 @@ const char* ExpressionSetIsSubset::getOpName() const {
/* ----------------------- ExpressionSetUnion ---------------------------- */
Value ExpressionSetUnion::evaluateInternal(Variables* vars) const {
- ValueSet unionedSet;
+ ValueSet unionedSet = getExpressionContext()->getValueComparator().makeOrderedValueSet();
const size_t n = vpOperand.size();
for (size_t i = 0; i < n; i++) {
const Value newEntries = vpOperand[i]->evaluateInternal(vars);
@@ -3606,6 +3668,17 @@ Value ExpressionSwitch::serialize(bool explain) const {
return Value(Document{{"$switch", Document{{"branches", Value(serializedBranches)}}}});
}
+void ExpressionSwitch::doInjectExpressionContext() {
+ if (_default) {
+ _default->injectExpressionContext(getExpressionContext());
+ }
+
+ for (auto&& pair : _branches) {
+ pair.first->injectExpressionContext(getExpressionContext());
+ pair.second->injectExpressionContext(getExpressionContext());
+ }
+}
+
/* ------------------------- ExpressionToLower ----------------------------- */
Value ExpressionToLower::evaluateInternal(Variables* vars) const {
@@ -4032,6 +4105,16 @@ void ExpressionZip::addDependencies(DepsTracker* deps) const {
});
}
+void ExpressionZip::doInjectExpressionContext() {
+ for (auto&& expr : _inputs) {
+ expr->injectExpressionContext(getExpressionContext());
+ }
+
+ for (auto&& expr : _defaults) {
+ expr->injectExpressionContext(getExpressionContext());
+ }
+}
+
const char* ExpressionZip::getOpName() const {
return "$zip";
}
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index c4ac4b07704..4a849675300 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -38,6 +38,7 @@
#include "mongo/base/init.h"
#include "mongo/db/pipeline/dependencies.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/field_path.h"
#include "mongo/db/pipeline/value.h"
#include "mongo/stdx/functional.h"
@@ -281,8 +282,31 @@ public:
*/
static void registerExpression(std::string key, Parser parser);
+ /**
+ * Injects the ExpressionContext so that it may be used during evaluation of the Expression.
+ * Construction of expressions is done at parse time, but the ExpressionContext isn't finalized
+ * until later, at which point it is injected using this method.
+ */
+ void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) {
+ _expCtx = expCtx;
+ doInjectExpressionContext();
+ }
+
protected:
typedef std::vector<boost::intrusive_ptr<Expression>> ExpressionVector;
+
+ /**
+ * Expressions which need to update their internal state when attaching to a new
+ * ExpressionContext should override this method.
+ */
+ virtual void doInjectExpressionContext() {}
+
+ const boost::intrusive_ptr<ExpressionContext>& getExpressionContext() const {
+ return _expCtx;
+ }
+
+private:
+ boost::intrusive_ptr<ExpressionContext> _expCtx;
};
@@ -322,6 +346,11 @@ public:
static ExpressionVector parseArguments(BSONElement bsonExpr, const VariablesParseState& vps);
+ // TODO SERVER-23349: Currently there are subclasses which derive from this base class that
+ // require custom logic for expression context injection. Consider making those classes inherit
+ // directly from Expression so that this method can be marked 'final' rather than 'override'.
+ void doInjectExpressionContext() override;
+
protected:
ExpressionNary() {}
@@ -538,8 +567,11 @@ public:
Value serialize(bool explain) const final;
static boost::intrusive_ptr<ExpressionCoerceToBool> create(
+ const boost::intrusive_ptr<ExpressionContext>& expCtx,
const boost::intrusive_ptr<Expression>& pExpression);
+ void doInjectExpressionContext() final;
+
private:
explicit ExpressionCoerceToBool(const boost::intrusive_ptr<Expression>& pExpression);
@@ -619,7 +651,9 @@ public:
const char* getOpName() const;
- static boost::intrusive_ptr<ExpressionConstant> create(const Value& pValue);
+ static boost::intrusive_ptr<ExpressionConstant> create(
+ const boost::intrusive_ptr<ExpressionContext>& expCtx, const Value& pValue);
+
static boost::intrusive_ptr<Expression> parse(BSONElement bsonExpr,
const VariablesParseState& vps);
@@ -647,6 +681,8 @@ public:
static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps);
+ void doInjectExpressionContext() final;
+
private:
ExpressionDateToString(const std::string& format, // the format string
boost::intrusive_ptr<Expression> date); // the date to format
@@ -777,6 +813,8 @@ public:
static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps);
+ void doInjectExpressionContext() final;
+
private:
ExpressionFilter(std::string varName,
Variables::Id varId,
@@ -859,6 +897,8 @@ public:
static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps);
+ void doInjectExpressionContext() final;
+
struct NameAndExpression {
NameAndExpression() {}
NameAndExpression(std::string name, boost::intrusive_ptr<Expression> expression)
@@ -901,6 +941,8 @@ public:
static boost::intrusive_ptr<Expression> parse(BSONElement expr, const VariablesParseState& vps);
+ void doInjectExpressionContext() final;
+
private:
ExpressionMap(
const std::string& varName, // name of variable to set
@@ -1026,6 +1068,8 @@ public:
return _expressions;
}
+ void doInjectExpressionContext() final;
+
private:
ExpressionObject(
std::vector<std::pair<std::string, boost::intrusive_ptr<Expression>>>&& expressions);
@@ -1072,6 +1116,8 @@ public:
const VariablesParseState& vpsIn);
Value serialize(bool explain) const final;
+ void doInjectExpressionContext() final;
+
private:
boost::intrusive_ptr<Expression> _input;
boost::intrusive_ptr<Expression> _initial;
@@ -1242,6 +1288,8 @@ public:
Value serialize(bool explain) const final;
const char* getOpName() const final;
+ void doInjectExpressionContext() final;
+
private:
using ExpressionPair =
std::pair<boost::intrusive_ptr<Expression>, boost::intrusive_ptr<Expression>>;
@@ -1337,6 +1385,8 @@ public:
Value serialize(bool explain) const final;
const char* getOpName() const final;
+ void doInjectExpressionContext() final;
+
private:
bool _useLongestLength = false;
ExpressionVector _inputs;
diff --git a/src/mongo/db/pipeline/expression_context.cpp b/src/mongo/db/pipeline/expression_context.cpp
index d7c5a6708a0..5992b8aa708 100644
--- a/src/mongo/db/pipeline/expression_context.cpp
+++ b/src/mongo/db/pipeline/expression_context.cpp
@@ -45,7 +45,7 @@ ExpressionContext::ExpressionContext(OperationContext* opCtx, const AggregationR
auto statusWithCollator =
CollatorFactoryInterface::get(opCtx->getServiceContext())->makeFromBSON(collation);
uassertStatusOK(statusWithCollator.getStatus());
- collator = std::move(statusWithCollator.getValue());
+ setCollator(std::move(statusWithCollator.getValue()));
}
}
@@ -56,4 +56,13 @@ void ExpressionContext::checkForInterrupt() {
interruptCounter = kInterruptCheckPeriod;
}
}
+
+void ExpressionContext::setCollator(std::unique_ptr<CollatorInterface> coll) {
+ _collator = std::move(coll);
+
+ // Document/Value comparisons must be aware of the collation.
+ _documentComparator = DocumentComparator(_collator.get());
+ _valueComparator = ValueComparator(_collator.get());
+}
+
} // namespace mongo
diff --git a/src/mongo/db/pipeline/expression_context.h b/src/mongo/db/pipeline/expression_context.h
index 90d5bad3a90..285318f7f7a 100644
--- a/src/mongo/db/pipeline/expression_context.h
+++ b/src/mongo/db/pipeline/expression_context.h
@@ -35,6 +35,8 @@
#include "mongo/db/namespace_string.h"
#include "mongo/db/operation_context.h"
#include "mongo/db/pipeline/aggregation_request.h"
+#include "mongo/db/pipeline/document_comparator.h"
+#include "mongo/db/pipeline/value_comparator.h"
#include "mongo/db/query/collation/collator_interface.h"
#include "mongo/util/intrusive_counter.h"
@@ -42,6 +44,8 @@ namespace mongo {
struct ExpressionContext : public IntrusiveCounterUnsigned {
public:
+ ExpressionContext() = default;
+
ExpressionContext(OperationContext* opCtx, const AggregationRequest& request);
/**
@@ -50,11 +54,25 @@ public:
*/
void checkForInterrupt();
- bool isExplain;
- bool inShard;
+ void setCollator(std::unique_ptr<CollatorInterface> coll);
+
+ const CollatorInterface* getCollator() const {
+ return _collator.get();
+ }
+
+ const DocumentComparator& getDocumentComparator() const {
+ return _documentComparator;
+ }
+
+ const ValueComparator& getValueComparator() const {
+ return _valueComparator;
+ }
+
+ bool isExplain = false;
+ bool inShard = false;
bool inRouter = false;
- bool extSortAllowed;
- bool bypassDocumentValidation;
+ bool extSortAllowed = false;
+ bool bypassDocumentValidation = false;
NamespaceString ns;
std::string tempDir; // Defaults to empty to prevent external sorting in mongos.
@@ -65,11 +83,17 @@ public:
// collation.
const BSONObj collation;
+ static const int kInterruptCheckPeriod = 128;
+ int interruptCounter = kInterruptCheckPeriod; // when 0, check interruptStatus
+
+private:
// Collator used to compare elements. 'collator' is initialized from 'collation', except in the
// case where 'collation' is empty and there is a collection default collation.
- std::unique_ptr<CollatorInterface> collator;
+ std::unique_ptr<CollatorInterface> _collator;
- static const int kInterruptCheckPeriod = 128;
- int interruptCounter = kInterruptCheckPeriod; // when 0, check interruptStatus
+ // Used for all comparisons of Document/Value during execution of the aggregation operation.
+ DocumentComparator _documentComparator;
+ ValueComparator _valueComparator;
};
-}
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp
index 09c4685df4f..9828e719aab 100644
--- a/src/mongo/db/pipeline/expression_test.cpp
+++ b/src/mongo/db/pipeline/expression_test.cpp
@@ -32,7 +32,10 @@
#include "mongo/config.h"
#include "mongo/db/pipeline/accumulator.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/expression.h"
+#include "mongo/db/pipeline/expression_context.h"
+#include "mongo/db/pipeline/value_comparator.h"
#include "mongo/dbtests/dbtests.h"
#include "mongo/unittest/unittest.h"
@@ -57,11 +60,14 @@ static void assertExpectedResults(string expression,
initializer_list<pair<vector<Value>, Value>> operations) {
for (auto&& op : operations) {
try {
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
const BSONObj obj = BSON(expression << Value(op.first));
- Value result = Expression::parseExpression(obj, vps)->evaluate(Document());
- ASSERT_EQUALS(op.second, result);
+ auto expression = Expression::parseExpression(obj, vps);
+ expression->injectExpressionContext(expCtx);
+ Value result = expression->evaluate(Document());
+ ASSERT_VALUE_EQ(op.second, result);
ASSERT_EQUALS(op.second.getType(), result.getType());
} catch (...) {
log() << "failed with arguments: " << Value(op.first);
@@ -127,13 +133,13 @@ Value valueFromBson(BSONObj obj) {
template <typename T>
intrusive_ptr<ExpressionConstant> makeConstant(T&& val) {
- return ExpressionConstant::create(Value(std::forward<T>(val)));
+ return ExpressionConstant::create(nullptr, Value(std::forward<T>(val)));
}
class ExpressionBaseTest : public unittest::Test {
public:
void addOperand(intrusive_ptr<ExpressionNary> expr, Value arg) {
- expr->addOperand(ExpressionConstant::create(arg));
+ expr->addOperand(ExpressionConstant::create(nullptr, arg));
}
};
@@ -141,7 +147,7 @@ class ExpressionNaryTestOneArg : public ExpressionBaseTest {
public:
virtual void assertEvaluates(Value input, Value output) {
addOperand(_expr, input);
- ASSERT_EQUALS(output, _expr->evaluate(Document()));
+ ASSERT_VALUE_EQ(output, _expr->evaluate(Document()));
ASSERT_EQUALS(output.getType(), _expr->evaluate(Document()).getType());
}
@@ -231,7 +237,7 @@ protected:
};
TEST_F(ExpressionNaryTest, AddedConstantOperandIsSerialized) {
- _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(Value(9)));
+ _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(nullptr, Value(9)));
assertContents(_notAssociativeNorCommutative, BSON_ARRAY(9));
}
@@ -245,7 +251,7 @@ TEST_F(ExpressionNaryTest, ValidateEmptyDependencies) {
}
TEST_F(ExpressionNaryTest, ValidateConstantExpressionDependency) {
- _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(Value(1)));
+ _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(nullptr, Value(1)));
assertDependencies(_notAssociativeNorCommutative, BSONArray());
}
@@ -269,13 +275,13 @@ TEST_F(ExpressionNaryTest, ValidateObjectExpressionDependency) {
}
TEST_F(ExpressionNaryTest, SerializationToBsonObj) {
- _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(Value(5)));
+ _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(nullptr, Value(5)));
ASSERT_EQUALS(BSON("foo" << BSON("$testable" << BSON_ARRAY(BSON("$const" << 5)))),
BSON("foo" << _notAssociativeNorCommutative->serialize(false)));
}
TEST_F(ExpressionNaryTest, SerializationToBsonArr) {
- _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(Value(5)));
+ _notAssociativeNorCommutative->addOperand(ExpressionConstant::create(nullptr, Value(5)));
ASSERT_EQUALS(constify(BSON_ARRAY(BSON("$testable" << BSON_ARRAY(5)))),
BSON_ARRAY(_notAssociativeNorCommutative->serialize(false)));
}
@@ -925,7 +931,7 @@ class NullDocument {
public:
void run() {
intrusive_ptr<ExpressionNary> expression = new ExpressionAdd();
- expression->addOperand(ExpressionConstant::create(Value(2)));
+ expression->addOperand(ExpressionConstant::create(nullptr, Value(2)));
ASSERT_EQUALS(BSON("" << 2), toBson(expression->evaluate(Document())));
}
};
@@ -943,7 +949,7 @@ class String {
public:
void run() {
intrusive_ptr<ExpressionNary> expression = new ExpressionAdd();
- expression->addOperand(ExpressionConstant::create(Value("a")));
+ expression->addOperand(ExpressionConstant::create(nullptr, Value("a")));
ASSERT_THROWS(expression->evaluate(Document()), UserException);
}
};
@@ -953,14 +959,14 @@ class Bool {
public:
void run() {
intrusive_ptr<ExpressionNary> expression = new ExpressionAdd();
- expression->addOperand(ExpressionConstant::create(Value(true)));
+ expression->addOperand(ExpressionConstant::create(nullptr, Value(true)));
ASSERT_THROWS(expression->evaluate(Document()), UserException);
}
};
class SingleOperandBase : public ExpectedResultBase {
void populateOperands(intrusive_ptr<ExpressionNary>& expression) {
- expression->addOperand(ExpressionConstant::create(valueFromBson(operand())));
+ expression->addOperand(ExpressionConstant::create(nullptr, valueFromBson(operand())));
}
BSONObj expectedResult() {
return operand();
@@ -1031,9 +1037,9 @@ public:
protected:
void populateOperands(intrusive_ptr<ExpressionNary>& expression) {
expression->addOperand(
- ExpressionConstant::create(valueFromBson(_reverse ? operand2() : operand1())));
+ ExpressionConstant::create(nullptr, valueFromBson(_reverse ? operand2() : operand1())));
expression->addOperand(
- ExpressionConstant::create(valueFromBson(_reverse ? operand1() : operand2())));
+ ExpressionConstant::create(nullptr, valueFromBson(_reverse ? operand1() : operand2())));
}
virtual BSONObj operand1() = 0;
virtual BSONObj operand2() = 0;
@@ -1185,11 +1191,13 @@ class ExpectedResultBase {
public:
virtual ~ExpectedResultBase() {}
void run() {
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
BSONObj specObject = BSON("" << spec());
BSONElement specElement = specObject.firstElement();
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps);
+ expression->injectExpressionContext(expCtx);
ASSERT_EQUALS(constify(spec()), expressionToBson(expression));
ASSERT_EQUALS(BSON("" << expectedResult()),
toBson(expression->evaluate(fromBson(BSON("a" << 1)))));
@@ -1207,11 +1215,13 @@ class OptimizeBase {
public:
virtual ~OptimizeBase() {}
void run() {
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
BSONObj specObject = BSON("" << spec());
BSONElement specElement = specObject.firstElement();
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps);
+ expression->injectExpressionContext(expCtx);
ASSERT_EQUALS(constify(spec()), expressionToBson(expression));
intrusive_ptr<Expression> optimized = expression->optimize();
ASSERT_EQUALS(expectedOptimized(), expressionToBson(optimized));
@@ -1483,8 +1493,8 @@ namespace CoerceToBool {
class EvaluateTrue {
public:
void run() {
- intrusive_ptr<Expression> nested = ExpressionConstant::create(Value(5));
- intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nested);
+ intrusive_ptr<Expression> nested = ExpressionConstant::create(nullptr, Value(5));
+ intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nullptr, nested);
ASSERT(expression->evaluate(Document()).getBool());
}
};
@@ -1493,8 +1503,8 @@ public:
class EvaluateFalse {
public:
void run() {
- intrusive_ptr<Expression> nested = ExpressionConstant::create(Value(0));
- intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nested);
+ intrusive_ptr<Expression> nested = ExpressionConstant::create(nullptr, Value(0));
+ intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nullptr, nested);
ASSERT(!expression->evaluate(Document()).getBool());
}
};
@@ -1504,7 +1514,7 @@ class Dependencies {
public:
void run() {
intrusive_ptr<Expression> nested = ExpressionFieldPath::create("a.b");
- intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nested);
+ intrusive_ptr<Expression> expression = ExpressionCoerceToBool::create(nullptr, nested);
DepsTracker dependencies;
expression->addDependencies(&dependencies);
ASSERT_EQUALS(1U, dependencies.fields.size());
@@ -1519,7 +1529,7 @@ class AddToBsonObj {
public:
void run() {
intrusive_ptr<Expression> expression =
- ExpressionCoerceToBool::create(ExpressionFieldPath::create("foo"));
+ ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create("foo"));
// serialized as $and because CoerceToBool isn't an ExpressionNary
assertBinaryEqual(fromjson("{field:{$and:['$foo']}}"), toBsonObj(expression));
@@ -1536,7 +1546,7 @@ class AddToBsonArray {
public:
void run() {
intrusive_ptr<Expression> expression =
- ExpressionCoerceToBool::create(ExpressionFieldPath::create("foo"));
+ ExpressionCoerceToBool::create(nullptr, ExpressionFieldPath::create("foo"));
// serialized as $and because CoerceToBool isn't an ExpressionNary
assertBinaryEqual(BSON_ARRAY(fromjson("{$and:['$foo']}")), toBsonArray(expression));
@@ -1840,7 +1850,7 @@ public:
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps);
- ASSERT_EQUALS(expression->evaluate(Document()), Value(true));
+ ASSERT_VALUE_EQ(expression->evaluate(Document()), Value(true));
}
};
@@ -1971,7 +1981,7 @@ namespace Constant {
class Create {
public:
void run() {
- intrusive_ptr<Expression> expression = ExpressionConstant::create(Value(5));
+ intrusive_ptr<Expression> expression = ExpressionConstant::create(nullptr, Value(5));
assertBinaryEqual(BSON("" << 5), toBson(expression->evaluate(Document())));
}
};
@@ -1996,7 +2006,7 @@ public:
class Optimize {
public:
void run() {
- intrusive_ptr<Expression> expression = ExpressionConstant::create(Value(5));
+ intrusive_ptr<Expression> expression = ExpressionConstant::create(nullptr, Value(5));
// An attempt to optimize returns the Expression itself.
ASSERT_EQUALS(expression, expression->optimize());
}
@@ -2006,7 +2016,7 @@ public:
class Dependencies {
public:
void run() {
- intrusive_ptr<Expression> expression = ExpressionConstant::create(Value(5));
+ intrusive_ptr<Expression> expression = ExpressionConstant::create(nullptr, Value(5));
DepsTracker dependencies;
expression->addDependencies(&dependencies);
ASSERT_EQUALS(0U, dependencies.fields.size());
@@ -2019,7 +2029,7 @@ public:
class AddToBsonObj {
public:
void run() {
- intrusive_ptr<Expression> expression = ExpressionConstant::create(Value(5));
+ intrusive_ptr<Expression> expression = ExpressionConstant::create(nullptr, Value(5));
// The constant is replaced with a $ expression.
assertBinaryEqual(BSON("field" << BSON("$const" << 5)), toBsonObj(expression));
}
@@ -2034,7 +2044,7 @@ private:
class AddToBsonArray {
public:
void run() {
- intrusive_ptr<Expression> expression = ExpressionConstant::create(Value(5));
+ intrusive_ptr<Expression> expression = ExpressionConstant::create(nullptr, Value(5));
// The constant is copied out as is.
assertBinaryEqual(constify(BSON_ARRAY(5)), toBsonArray(expression));
}
@@ -2342,7 +2352,7 @@ TEST(ExpressionObjectParse, ShouldAcceptEmptyObject) {
VariablesIdGenerator idGen;
VariablesParseState vps(&idGen);
auto object = ExpressionObject::parse(BSONObj(), vps);
- ASSERT_EQUALS(Value(Document{}), object->serialize(false));
+ ASSERT_VALUE_EQ(Value(Document{}), object->serialize(false));
}
TEST(ExpressionObjectParse, ShouldAcceptLiteralsAsValues) {
@@ -2355,7 +2365,7 @@ TEST(ExpressionObjectParse, ShouldAcceptLiteralsAsValues) {
vps);
auto expectedResult =
Value(Document{{"a", literal(5)}, {"b", literal("string")}, {"c", literal(BSONNULL)}});
- ASSERT_EQUALS(expectedResult, object->serialize(false));
+ ASSERT_VALUE_EQ(expectedResult, object->serialize(false));
}
TEST(ExpressionObjectParse, ShouldAccept_idAsFieldName) {
@@ -2363,7 +2373,7 @@ TEST(ExpressionObjectParse, ShouldAccept_idAsFieldName) {
VariablesParseState vps(&idGen);
auto object = ExpressionObject::parse(BSON("_id" << 5), vps);
auto expectedResult = Value(Document{{"_id", literal(5)}});
- ASSERT_EQUALS(expectedResult, object->serialize(false));
+ ASSERT_VALUE_EQ(expectedResult, object->serialize(false));
}
TEST(ExpressionObjectParse, ShouldAcceptFieldNameContainingDollar) {
@@ -2371,7 +2381,7 @@ TEST(ExpressionObjectParse, ShouldAcceptFieldNameContainingDollar) {
VariablesParseState vps(&idGen);
auto object = ExpressionObject::parse(BSON("a$b" << 5), vps);
auto expectedResult = Value(Document{{"a$b", literal(5)}});
- ASSERT_EQUALS(expectedResult, object->serialize(false));
+ ASSERT_VALUE_EQ(expectedResult, object->serialize(false));
}
TEST(ExpressionObjectParse, ShouldAcceptNestedObjects) {
@@ -2381,7 +2391,7 @@ TEST(ExpressionObjectParse, ShouldAcceptNestedObjects) {
auto expectedResult =
Value(Document{{"a", Document{{"b", literal(1)}}},
{"c", Document{{"d", Document{{"e", literal(1)}, {"f", literal(1)}}}}}});
- ASSERT_EQUALS(expectedResult, object->serialize(false));
+ ASSERT_VALUE_EQ(expectedResult, object->serialize(false));
}
TEST(ExpressionObjectParse, ShouldAcceptArrays) {
@@ -2390,14 +2400,15 @@ TEST(ExpressionObjectParse, ShouldAcceptArrays) {
auto object = ExpressionObject::parse(fromjson("{a: [1, 2]}"), vps);
auto expectedResult =
Value(Document{{"a", vector<Value>{Value(literal(1)), Value(literal(2))}}});
- ASSERT_EQUALS(expectedResult, object->serialize(false));
+ ASSERT_VALUE_EQ(expectedResult, object->serialize(false));
}
TEST(ObjectParsing, ShouldAcceptExpressionAsValue) {
VariablesIdGenerator idGen;
VariablesParseState vps(&idGen);
auto object = ExpressionObject::parse(BSON("a" << BSON("$and" << BSONArray())), vps);
- ASSERT_EQ(object->serialize(false), Value(Document{{"a", Document{{"$and", BSONArray()}}}}));
+ ASSERT_VALUE_EQ(object->serialize(false),
+ Value(Document{{"a", Document{{"$and", BSONArray()}}}}));
}
//
@@ -2455,31 +2466,31 @@ TEST(ParseObject, ShouldRejectExpressionAsTheSecondField) {
TEST(ExpressionObjectEvaluate, EmptyObjectShouldEvaluateToEmptyDocument) {
auto object = ExpressionObject::create({});
- ASSERT_EQUALS(Value(Document()), object->evaluate(Document()));
- ASSERT_EQUALS(Value(Document()), object->evaluate(Document{{"a", 1}}));
- ASSERT_EQUALS(Value(Document()), object->evaluate(Document{{"_id", "ID"}}));
+ ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document()));
+ ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document{{"a", 1}}));
+ ASSERT_VALUE_EQ(Value(Document()), object->evaluate(Document{{"_id", "ID"}}));
}
TEST(ExpressionObjectEvaluate, ShouldEvaluateEachField) {
auto object = ExpressionObject::create({{"a", makeConstant(1)}, {"b", makeConstant(5)}});
- ASSERT_EQUALS(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document()));
- ASSERT_EQUALS(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"a", 1}}));
- ASSERT_EQUALS(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"_id", "ID"}}));
+ ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document()));
+ ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"a", 1}}));
+ ASSERT_VALUE_EQ(Value(Document{{"a", 1}, {"b", 5}}), object->evaluate(Document{{"_id", "ID"}}));
}
TEST(ExpressionObjectEvaluate, OrderOfFieldsInOutputShouldMatchOrderInSpecification) {
auto object = ExpressionObject::create({{"a", ExpressionFieldPath::create("a")},
{"b", ExpressionFieldPath::create("b")},
{"c", ExpressionFieldPath::create("c")}});
- ASSERT_EQUALS(Value(Document{{"a", "A"}, {"b", "B"}, {"c", "C"}}),
- object->evaluate(Document{{"c", "C"}, {"a", "A"}, {"b", "B"}, {"_id", "ID"}}));
+ ASSERT_VALUE_EQ(Value(Document{{"a", "A"}, {"b", "B"}, {"c", "C"}}),
+ object->evaluate(Document{{"c", "C"}, {"a", "A"}, {"b", "B"}, {"_id", "ID"}}));
}
TEST(ExpressionObjectEvaluate, ShouldRemoveFieldsThatHaveMissingValues) {
auto object = ExpressionObject::create(
{{"a", ExpressionFieldPath::create("a.b")}, {"b", ExpressionFieldPath::create("missing")}});
- ASSERT_EQUALS(Value(Document{}), object->evaluate(Document()));
- ASSERT_EQUALS(Value(Document{}), object->evaluate(Document{{"a", 1}}));
+ ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document()));
+ ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document{{"a", 1}}));
}
TEST(ExpressionObjectEvaluate, ShouldEvaluateFieldsWithinNestedObject) {
@@ -2487,18 +2498,18 @@ TEST(ExpressionObjectEvaluate, ShouldEvaluateFieldsWithinNestedObject) {
{{"a",
ExpressionObject::create(
{{"b", makeConstant(1)}, {"c", ExpressionFieldPath::create("_id")}})}});
- ASSERT_EQUALS(Value(Document{{"a", Document{{"b", 1}}}}), object->evaluate(Document()));
- ASSERT_EQUALS(Value(Document{{"a", Document{{"b", 1}, {"c", "ID"}}}}),
- object->evaluate(Document{{"_id", "ID"}}));
+ ASSERT_VALUE_EQ(Value(Document{{"a", Document{{"b", 1}}}}), object->evaluate(Document()));
+ ASSERT_VALUE_EQ(Value(Document{{"a", Document{{"b", 1}, {"c", "ID"}}}}),
+ object->evaluate(Document{{"_id", "ID"}}));
}
TEST(ExpressionObjectEvaluate, ShouldEvaluateToEmptyDocumentIfAllFieldsAreMissing) {
auto object = ExpressionObject::create({{"a", ExpressionFieldPath::create("missing")}});
- ASSERT_EQUALS(Value(Document{}), object->evaluate(Document()));
+ ASSERT_VALUE_EQ(Value(Document{}), object->evaluate(Document()));
auto objectWithNestedObject = ExpressionObject::create({{"nested", object}});
- ASSERT_EQUALS(Value(Document{{"nested", Document{}}}),
- objectWithNestedObject->evaluate(Document()));
+ ASSERT_VALUE_EQ(Value(Document{{"nested", Document{}}}),
+ objectWithNestedObject->evaluate(Document()));
}
//
@@ -2542,7 +2553,7 @@ TEST(ExpressionObjectOptimizations, OptimizingAnObjectShouldOptimizeSubExpressio
auto expConstant =
dynamic_cast<ExpressionConstant*>(optimizedObject->getChildExpressions()[0].second.get());
ASSERT_TRUE(expConstant);
- ASSERT_EQ(expConstant->evaluate(Document()), Value(3));
+ ASSERT_VALUE_EQ(expConstant->evaluate(Document()), Value(3));
};
} // namespace Object
@@ -2897,7 +2908,7 @@ TEST(ParseExpression, ShouldRecognizeConstExpression) {
auto resultExpression = parseExpression(BSON("$const" << 5));
auto constExpression = dynamic_cast<ExpressionConstant*>(resultExpression.get());
ASSERT_TRUE(constExpression);
- ASSERT_EQUALS(constExpression->serialize(false), Value(Document{{"$const", 5}}));
+ ASSERT_VALUE_EQ(constExpression->serialize(false), Value(Document{{"$const", 5}}));
}
TEST(ParseExpression, ShouldRejectUnknownExpression) {
@@ -2931,15 +2942,15 @@ TEST(ParseExpression, ShouldParseExpressionWithMultipleArguments) {
ASSERT_TRUE(strCaseCmpExpression);
vector<Value> arguments = {Value(Document{{"$const", "foo"}}),
Value(Document{{"$const", "FOO"}})};
- ASSERT_EQUALS(strCaseCmpExpression->serialize(false),
- Value(Document{{"$strcasecmp", arguments}}));
+ ASSERT_VALUE_EQ(strCaseCmpExpression->serialize(false),
+ Value(Document{{"$strcasecmp", arguments}}));
}
TEST(ParseExpression, ShouldParseExpressionWithNoArguments) {
auto resultExpression = parseExpression(BSON("$and" << BSONArray()));
auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get());
ASSERT_TRUE(andExpression);
- ASSERT_EQUALS(andExpression->serialize(false), Value(Document{{"$and", vector<Value>{}}}));
+ ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", vector<Value>{}}}));
}
TEST(ParseExpression, ShouldParseExpressionWithOneArgument) {
@@ -2947,7 +2958,7 @@ TEST(ParseExpression, ShouldParseExpressionWithOneArgument) {
auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get());
ASSERT_TRUE(andExpression);
vector<Value> arguments = {Value(Document{{"$const", 1}})};
- ASSERT_EQUALS(andExpression->serialize(false), Value(Document{{"$and", arguments}}));
+ ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", arguments}}));
}
TEST(ParseExpression, ShouldAcceptArgumentWithoutArrayForVariadicExpressions) {
@@ -2955,7 +2966,7 @@ TEST(ParseExpression, ShouldAcceptArgumentWithoutArrayForVariadicExpressions) {
auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get());
ASSERT_TRUE(andExpression);
vector<Value> arguments = {Value(Document{{"$const", 1}})};
- ASSERT_EQUALS(andExpression->serialize(false), Value(Document{{"$and", arguments}}));
+ ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", arguments}}));
}
TEST(ParseExpression, ShouldAcceptArgumentWithoutArrayAsSingleArgument) {
@@ -2963,7 +2974,7 @@ TEST(ParseExpression, ShouldAcceptArgumentWithoutArrayAsSingleArgument) {
auto notExpression = dynamic_cast<ExpressionNot*>(resultExpression.get());
ASSERT_TRUE(notExpression);
vector<Value> arguments = {Value(Document{{"$const", 1}})};
- ASSERT_EQUALS(notExpression->serialize(false), Value(Document{{"$not", arguments}}));
+ ASSERT_VALUE_EQ(notExpression->serialize(false), Value(Document{{"$not", arguments}}));
}
TEST(ParseExpression, ShouldAcceptObjectAsSingleArgument) {
@@ -2971,7 +2982,7 @@ TEST(ParseExpression, ShouldAcceptObjectAsSingleArgument) {
auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get());
ASSERT_TRUE(andExpression);
vector<Value> arguments = {Value(Document{{"$const", 1}})};
- ASSERT_EQUALS(andExpression->serialize(false), Value(Document{{"$and", arguments}}));
+ ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", arguments}}));
}
TEST(ParseExpression, ShouldAcceptObjectInsideArrayAsSingleArgument) {
@@ -2979,7 +2990,7 @@ TEST(ParseExpression, ShouldAcceptObjectInsideArrayAsSingleArgument) {
auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get());
ASSERT_TRUE(andExpression);
vector<Value> arguments = {Value(Document{{"$const", 1}})};
- ASSERT_EQUALS(andExpression->serialize(false), Value(Document{{"$and", arguments}}));
+ ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", arguments}}));
}
} // namespace Expression
@@ -3005,7 +3016,7 @@ TEST(ParseOperand, ShouldRecognizeFieldPath) {
<< "$field"));
auto fieldPathExpression = dynamic_cast<ExpressionFieldPath*>(resultExpression.get());
ASSERT_TRUE(fieldPathExpression);
- ASSERT_EQ(fieldPathExpression->serialize(false), Value("$field"));
+ ASSERT_VALUE_EQ(fieldPathExpression->serialize(false), Value("$field"));
}
TEST(ParseOperand, ShouldRecognizeStringLiteral) {
@@ -3013,7 +3024,7 @@ TEST(ParseOperand, ShouldRecognizeStringLiteral) {
<< "foo"));
auto constantExpression = dynamic_cast<ExpressionConstant*>(resultExpression.get());
ASSERT_TRUE(constantExpression);
- ASSERT_EQ(constantExpression->serialize(false), Value(Document{{"$const", "foo"}}));
+ ASSERT_VALUE_EQ(constantExpression->serialize(false), Value(Document{{"$const", "foo"}}));
}
TEST(ParseOperand, ShouldRecognizeNestedArray) {
@@ -3022,21 +3033,21 @@ TEST(ParseOperand, ShouldRecognizeNestedArray) {
auto arrayExpression = dynamic_cast<ExpressionArray*>(resultExpression.get());
ASSERT_TRUE(arrayExpression);
vector<Value> expectedSerializedArray = {Value(Document{{"$const", "foo"}}), Value("$field")};
- ASSERT_EQ(arrayExpression->serialize(false), Value(expectedSerializedArray));
+ ASSERT_VALUE_EQ(arrayExpression->serialize(false), Value(expectedSerializedArray));
}
TEST(ParseOperand, ShouldRecognizeNumberLiteral) {
auto resultExpression = parseOperand(BSON("" << 5));
auto constantExpression = dynamic_cast<ExpressionConstant*>(resultExpression.get());
ASSERT_TRUE(constantExpression);
- ASSERT_EQ(constantExpression->serialize(false), Value(Document{{"$const", 5}}));
+ ASSERT_VALUE_EQ(constantExpression->serialize(false), Value(Document{{"$const", 5}}));
}
TEST(ParseOperand, ShouldRecognizeNestedExpression) {
auto resultExpression = parseOperand(BSON("" << BSON("$and" << BSONArray())));
auto andExpression = dynamic_cast<ExpressionAnd*>(resultExpression.get());
ASSERT_TRUE(andExpression);
- ASSERT_EQ(andExpression->serialize(false), Value(Document{{"$and", vector<Value>{}}}));
+ ASSERT_VALUE_EQ(andExpression->serialize(false), Value(Document{{"$and", vector<Value>{}}}));
}
} // namespace Operand
@@ -3049,7 +3060,8 @@ Value sortSet(Value set) {
return Value(BSONNULL);
}
vector<Value> sortedSet = set.getArray();
- sort(sortedSet.begin(), sortedSet.end());
+ ValueComparator valueComparator;
+ sort(sortedSet.begin(), sortedSet.end(), valueComparator.getLessThan());
return Value(sortedSet);
}
@@ -3057,6 +3069,7 @@ class ExpectedResultBase {
public:
virtual ~ExpectedResultBase() {}
void run() {
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
const Document spec = getSpec();
const Value args = spec["input"];
if (!spec["expected"].missing()) {
@@ -3068,11 +3081,12 @@ public:
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
const intrusive_ptr<Expression> expr = Expression::parseExpression(obj, vps);
+ expr->injectExpressionContext(expCtx);
Value result = expr->evaluate(Document());
if (result.getType() == Array) {
result = sortSet(result);
}
- if (result != expected) {
+ if (ValueComparator().evaluate(result != expected)) {
string errMsg = str::stream()
<< "for expression " << field.first.toString() << " with argument "
<< args.toString() << " full tree: " << expr->serialize(false).toString()
@@ -3096,6 +3110,7 @@ public:
// same
const intrusive_ptr<Expression> expr =
Expression::parseExpression(obj, vps);
+ expr->injectExpressionContext(expCtx);
expr->evaluate(Document());
},
UserException);
@@ -3372,11 +3387,13 @@ private:
return BSON("$strcasecmp" << BSON_ARRAY(b() << a()));
}
void assertResult(int expectedResult, const BSONObj& spec) {
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
BSONObj specObj = BSON("" << spec);
BSONElement specElement = specObj.firstElement();
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps);
+ expression->injectExpressionContext(expCtx);
ASSERT_EQUALS(constify(spec), expressionToBson(expression));
ASSERT_EQUALS(BSON("" << expectedResult), toBson(expression->evaluate(Document())));
}
@@ -3498,11 +3515,13 @@ class ExpectedResultBase {
public:
virtual ~ExpectedResultBase() {}
void run() {
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
BSONObj specObj = BSON("" << spec());
BSONElement specElement = specObj.firstElement();
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps);
+ expression->injectExpressionContext(expCtx);
ASSERT_EQUALS(constify(spec()), expressionToBson(expression));
ASSERT_EQUALS(BSON("" << expectedResult()), toBson(expression->evaluate(Document())));
}
@@ -3753,11 +3772,13 @@ class ExpectedResultBase {
public:
virtual ~ExpectedResultBase() {}
void run() {
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
BSONObj specObj = BSON("" << spec());
BSONElement specElement = specObj.firstElement();
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps);
+ expression->injectExpressionContext(expCtx);
ASSERT_EQUALS(constify(spec()), expressionToBson(expression));
ASSERT_EQUALS(BSON("" << expectedResult()), toBson(expression->evaluate(Document())));
}
@@ -3810,11 +3831,13 @@ class ExpectedResultBase {
public:
virtual ~ExpectedResultBase() {}
void run() {
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
BSONObj specObj = BSON("" << spec());
BSONElement specElement = specObj.firstElement();
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps);
+ expression->injectExpressionContext(expCtx);
ASSERT_EQUALS(constify(spec()), expressionToBson(expression));
ASSERT_EQUALS(BSON("" << expectedResult()), toBson(expression->evaluate(Document())));
}
@@ -3866,6 +3889,7 @@ class ExpectedResultBase {
public:
virtual ~ExpectedResultBase() {}
void run() {
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
const Document spec = getSpec();
const Value args = spec["input"];
if (!spec["expected"].missing()) {
@@ -3877,8 +3901,9 @@ public:
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
const intrusive_ptr<Expression> expr = Expression::parseExpression(obj, vps);
+ expr->injectExpressionContext(expCtx);
const Value result = expr->evaluate(Document());
- if (result != expected) {
+ if (ValueComparator().evaluate(result != expected)) {
string errMsg = str::stream()
<< "for expression " << field.first.toString() << " with argument "
<< args.toString() << " full tree: " << expr->serialize(false).toString()
@@ -3902,6 +3927,7 @@ public:
// same
const intrusive_ptr<Expression> expr =
Expression::parseExpression(obj, vps);
+ expr->injectExpressionContext(expCtx);
expr->evaluate(Document());
},
UserException);
diff --git a/src/mongo/db/pipeline/lookup_set_cache.h b/src/mongo/db/pipeline/lookup_set_cache.h
index 3150d7bc1af..7ca74f90964 100644
--- a/src/mongo/db/pipeline/lookup_set_cache.h
+++ b/src/mongo/db/pipeline/lookup_set_cache.h
@@ -41,6 +41,7 @@
#include "mongo/bson/bsonobj.h"
#include "mongo/db/pipeline/value.h"
+#include "mongo/db/pipeline/value_comparator.h"
#include "mongo/stdx/functional.h"
namespace mongo {
@@ -56,6 +57,9 @@ using boost::multi_index::indexed_by;
* limit, but includes the ability to evict down to both a specific number of elements, and down to
* a specific amount of memory. Memory usage includes only the size of the elements in the cache at
* the time of insertion, not the overhead incurred by the data structures in use.
+ *
+ * TODO SERVER-23349: This class must make all comparisons of user data using the aggregation
+ * operation's collation.
*/
class LookupSetCache {
public:
@@ -167,9 +171,12 @@ private:
// a container of std::pair<Value, BSONObjSet>, that is both sequenced, and has a unique
// index on the Value. From this, we are able to evict the least-recently-used member, and
// maintain key uniqueness.
- using IndexedContainer = multi_index_container<
- Cached,
- indexed_by<sequenced<>, hashed_unique<member<Cached, Value, &Cached::first>, Value::Hash>>>;
+ using IndexedContainer =
+ multi_index_container<Cached,
+ indexed_by<sequenced<>,
+ hashed_unique<member<Cached, Value, &Cached::first>,
+ Value::Hash,
+ ValueComparator::EqualTo>>>;
IndexedContainer _container;
diff --git a/src/mongo/db/pipeline/parsed_aggregation_projection.h b/src/mongo/db/pipeline/parsed_aggregation_projection.h
index 9149ff7a53c..a3420fdb954 100644
--- a/src/mongo/db/pipeline/parsed_aggregation_projection.h
+++ b/src/mongo/db/pipeline/parsed_aggregation_projection.h
@@ -28,13 +28,15 @@
#pragma once
+#include <boost/intrusive_ptr.hpp>
#include <memory>
namespace mongo {
class BSONObj;
-struct DepsTracker;
class Document;
+struct DepsTracker;
+struct ExpressionContext;
namespace parsed_aggregation_projection {
@@ -80,6 +82,11 @@ public:
virtual void optimize() {}
/**
+ * Inject the ExpressionContext into any expressions contained within this projection.
+ */
+ virtual void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) {}
+
+ /**
* Add any dependencies needed by this projection or any sub-expressions to 'deps'.
*/
virtual void addDependencies(DepsTracker* deps) const {}
diff --git a/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp b/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp
index 7703ac3960b..5768b46882a 100644
--- a/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp
+++ b/src/mongo/db/pipeline/parsed_exclusion_projection_test.cpp
@@ -39,6 +39,7 @@
#include "mongo/bson/json.h"
#include "mongo/db/pipeline/dependencies.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/value.h"
#include "mongo/unittest/death_test.h"
#include "mongo/unittest/unittest.h"
@@ -83,17 +84,17 @@ TEST(ExclusionProjection, ShouldSerializeToEquivalentProjection) {
// fields is subject to change.
auto serialization = exclusion.serialize();
ASSERT_EQ(serialization.size(), 4UL);
- ASSERT_EQ(serialization["a"], Value(false));
- ASSERT_EQ(serialization["_id"], Value(false));
+ ASSERT_VALUE_EQ(serialization["a"], Value(false));
+ ASSERT_VALUE_EQ(serialization["_id"], Value(false));
ASSERT_EQ(serialization["b"].getType(), BSONType::Object);
ASSERT_EQ(serialization["b"].getDocument().size(), 2UL);
- ASSERT_EQ(serialization["b"].getDocument()["c"], Value(false));
- ASSERT_EQ(serialization["b"].getDocument()["d"], Value(false));
+ ASSERT_VALUE_EQ(serialization["b"].getDocument()["c"], Value(false));
+ ASSERT_VALUE_EQ(serialization["b"].getDocument()["d"], Value(false));
ASSERT_EQ(serialization["x"].getType(), BSONType::Object);
ASSERT_EQ(serialization["x"].getDocument().size(), 1UL);
- ASSERT_EQ(serialization["x"].getDocument()["y"], Value(false));
+ ASSERT_VALUE_EQ(serialization["x"].getDocument()["y"], Value(false));
}
TEST(ExclusionProjection, ShouldNotAddAnyDependencies) {
@@ -127,22 +128,22 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeTopLevelField) {
// More than one field in document.
auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}});
auto expectedResult = Document{{"b", 2}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Specified field is the only field in the document.
result = exclusion.applyProjection(Document{{"a", 1}});
expectedResult = Document{};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Specified field is not present in the document.
result = exclusion.applyProjection(Document{{"c", 1}});
expectedResult = Document{{"c", 1}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// There are no fields in the document.
result = exclusion.applyProjection(Document{});
expectedResult = Document{};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(ExclusionProjectionExecutionTest, ShouldCoerceNumericsToBools) {
@@ -152,7 +153,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldCoerceNumericsToBools) {
auto result = exclusion.applyProjection(Document{{"_id", "ID"}, {"a", 1}, {"b", 2}, {"c", 3}});
auto expectedResult = Document{{"_id", "ID"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(ExclusionProjectionExecutionTest, ShouldPreserveOrderOfExistingFields) {
@@ -160,7 +161,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldPreserveOrderOfExistingFields) {
exclusion.parse(BSON("second" << false));
auto result = exclusion.applyProjection(Document{{"first", 0}, {"second", 1}, {"third", 2}});
auto expectedResult = Document{{"first", 0}, {"third", 2}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(ExclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) {
@@ -168,7 +169,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) {
exclusion.parse(BSON("a" << false));
auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"}});
auto expectedResult = Document{{"b", 2}, {"_id", "ID"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) {
@@ -176,7 +177,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) {
exclusion.parse(BSON("a" << false << "_id" << false));
auto result = exclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"}});
auto expectedResult = Document{{"b", 2}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
//
@@ -189,7 +190,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeSubFieldsOfId) {
auto result = exclusion.applyProjection(
Document{{"_id", Document{{"x", 1}, {"y", 2}, {"z", 3}}}, {"a", 1}});
auto expectedResult = Document{{"_id", Document{{"z", 3}}}, {"a", 1}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(ExclusionProjectionExecutionTest, ShouldExcludeSimpleDottedFieldFromSubDoc) {
@@ -199,22 +200,22 @@ TEST(ExclusionProjectionExecutionTest, ShouldExcludeSimpleDottedFieldFromSubDoc)
// More than one field in sub document.
auto result = exclusion.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}});
auto expectedResult = Document{{"a", Document{{"c", 2}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Specified field is the only field in the sub document.
result = exclusion.applyProjection(Document{{"a", Document{{"b", 1}}}});
expectedResult = Document{{"a", Document{}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Specified field is not present in the sub document.
result = exclusion.applyProjection(Document{{"a", Document{{"c", 1}}}});
expectedResult = Document{{"a", Document{{"c", 1}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// There are no fields in sub document.
result = exclusion.applyProjection(Document{{"a", Document{}}});
expectedResult = Document{{"a", Document{}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(ExclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedExcludedFieldDoesNotExist) {
@@ -224,12 +225,12 @@ TEST(ExclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedExcludedFiel
// Should not add the path if it doesn't exist.
auto result = exclusion.applyProjection(Document{});
auto expectedResult = Document{};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Should not replace non-documents with documents.
result = exclusion.applyProjection(Document{{"sub", "notADocument"}});
expectedResult = Document{{"sub", "notADocument"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(ExclusionProjectionExecutionTest, ShouldApplyDottedExclusionToEachElementInArray) {
@@ -252,7 +253,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldApplyDottedExclusionToEachElementIn
Value(vector<Value>{Value(1), Value(Document{{"c", 1}})})};
auto result = exclusion.applyProjection(Document{{"a", nestedValues}});
auto expectedResult = Document{{"a", expectedNestedValues}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(ExclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) {
@@ -263,7 +264,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) {
auto result = exclusion.applyProjection(
Document{{"a", Document{{"b", 1}, {"c", 2}, {"d", 3}, {"e", 4}, {"f", 5}}}});
auto expectedResult = Document{{"a", Document{{"f", 5}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(ExclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc) {
@@ -279,7 +280,7 @@ TEST(ExclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc)
MutableDocument expectedDoc(Document{{"_id", "ID"}});
expectedDoc.copyMetaDataFrom(inputDoc);
- ASSERT_EQ(result, expectedDoc.freeze());
+ ASSERT_DOCUMENT_EQ(result, expectedDoc.freeze());
}
} // namespace
diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection.cpp b/src/mongo/db/pipeline/parsed_inclusion_projection.cpp
index 69d600f5551..4842500cd90 100644
--- a/src/mongo/db/pipeline/parsed_inclusion_projection.cpp
+++ b/src/mongo/db/pipeline/parsed_inclusion_projection.cpp
@@ -54,6 +54,16 @@ void InclusionNode::optimize() {
}
}
+void InclusionNode::injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) {
+ for (auto&& expressionIt : _expressions) {
+ expressionIt.second->injectExpressionContext(expCtx);
+ }
+
+ for (auto&& childPair : _children) {
+ childPair.second->injectExpressionContext(expCtx);
+ }
+}
+
void InclusionNode::serialize(MutableDocument* output, bool explain) const {
// Always put "_id" first if it was included (implicitly or explicitly).
if (_inclusions.find("_id") != _inclusions.end()) {
diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection.h b/src/mongo/db/pipeline/parsed_inclusion_projection.h
index 7236d8f85eb..d762d6d7ede 100644
--- a/src/mongo/db/pipeline/parsed_inclusion_projection.h
+++ b/src/mongo/db/pipeline/parsed_inclusion_projection.h
@@ -33,6 +33,7 @@
#include <unordered_set>
#include "mongo/db/pipeline/expression.h"
+#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/parsed_aggregation_projection.h"
#include "mongo/stdx/memory.h"
@@ -118,6 +119,8 @@ public:
return _pathToNode;
}
+ void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx);
+
private:
// Helpers for the Document versions above. These will apply the transformation recursively to
// each element of any arrays, and ensure non-documents are handled appropriately.
@@ -202,6 +205,10 @@ public:
_root->optimize();
}
+ void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx) final {
+ _root->injectExpressionContext(expCtx);
+ }
+
void addDependencies(DepsTracker* deps) const final {
_root->addDependencies(deps);
}
diff --git a/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp b/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp
index f4c1d6a4a0e..b2833638543 100644
--- a/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp
+++ b/src/mongo/db/pipeline/parsed_inclusion_projection_test.cpp
@@ -37,6 +37,7 @@
#include "mongo/bson/json.h"
#include "mongo/db/pipeline/dependencies.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/value.h"
#include "mongo/unittest/unittest.h"
@@ -129,8 +130,8 @@ TEST(InclusionProjection, ShouldSerializeToEquivalentProjection) {
"{_id: true, a: {$add: [\"$a\", {$const: 2}]}, b: {d: true}, x: {y: {$const: 4}}}"));
// Should be the same if we're serializing for explain or for internal use.
- ASSERT_EQ(expectedSerialization, inclusion.serialize(false));
- ASSERT_EQ(expectedSerialization, inclusion.serialize(true));
+ ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(false));
+ ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(true));
}
TEST(InclusionProjection, ShouldSerializeExplicitExclusionOfId) {
@@ -141,8 +142,8 @@ TEST(InclusionProjection, ShouldSerializeExplicitExclusionOfId) {
auto expectedSerialization = Document{{"_id", false}, {"a", true}};
// Should be the same if we're serializing for explain or for internal use.
- ASSERT_EQ(expectedSerialization, inclusion.serialize(false));
- ASSERT_EQ(expectedSerialization, inclusion.serialize(true));
+ ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(false));
+ ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(true));
}
@@ -155,8 +156,8 @@ TEST(InclusionProjection, ShouldOptimizeTopLevelExpressions) {
auto expectedSerialization = Document{{"_id", true}, {"a", Document{{"$const", 3}}}};
// Should be the same if we're serializing for explain or for internal use.
- ASSERT_EQ(expectedSerialization, inclusion.serialize(false));
- ASSERT_EQ(expectedSerialization, inclusion.serialize(true));
+ ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(false));
+ ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(true));
}
TEST(InclusionProjection, ShouldOptimizeNestedExpressions) {
@@ -169,8 +170,8 @@ TEST(InclusionProjection, ShouldOptimizeNestedExpressions) {
Document{{"_id", true}, {"a", Document{{"b", Document{{"$const", 3}}}}}};
// Should be the same if we're serializing for explain or for internal use.
- ASSERT_EQ(expectedSerialization, inclusion.serialize(false));
- ASSERT_EQ(expectedSerialization, inclusion.serialize(true));
+ ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(false));
+ ASSERT_DOCUMENT_EQ(expectedSerialization, inclusion.serialize(true));
}
//
@@ -184,22 +185,22 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeTopLevelField) {
// More than one field in document.
auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}});
auto expectedResult = Document{{"a", 1}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Specified field is the only field in the document.
result = inclusion.applyProjection(Document{{"a", 1}});
expectedResult = Document{{"a", 1}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Specified field is not present in the document.
result = inclusion.applyProjection(Document{{"c", 1}});
expectedResult = Document{};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// There are no fields in the document.
result = inclusion.applyProjection(Document{});
expectedResult = Document{};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldAddComputedTopLevelField) {
@@ -207,12 +208,12 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedTopLevelField) {
inclusion.parse(BSON("newField" << wrapInLiteral("computedVal")));
auto result = inclusion.applyProjection(Document{});
auto expectedResult = Document{{"newField", "computedVal"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Computed field should replace existing field.
result = inclusion.applyProjection(Document{{"newField", "preExisting"}});
expectedResult = Document{{"newField", "computedVal"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldApplyBothInclusionsAndComputedFields) {
@@ -220,7 +221,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyBothInclusionsAndComputedField
inclusion.parse(BSON("a" << true << "newField" << wrapInLiteral("computedVal")));
auto result = inclusion.applyProjection(Document{{"a", 1}});
auto expectedResult = Document{{"a", 1}, {"newField", "computedVal"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldIncludeFieldsInOrderOfInputDoc) {
@@ -228,7 +229,7 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeFieldsInOrderOfInputDoc) {
inclusion.parse(BSON("first" << true << "second" << true << "third" << true));
auto inputDoc = Document{{"second", 1}, {"first", 0}, {"third", 2}};
auto result = inclusion.applyProjection(inputDoc);
- ASSERT_EQ(result, inputDoc);
+ ASSERT_DOCUMENT_EQ(result, inputDoc);
}
TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsInOrderSpecified) {
@@ -237,7 +238,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsInOrderSpecified
<< wrapInLiteral("SECOND")));
auto result = inclusion.applyProjection(Document{{"first", 0}, {"second", 1}, {"third", 2}});
auto expectedResult = Document{{"firstComputed", "FIRST"}, {"secondComputed", "SECOND"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) {
@@ -245,12 +246,12 @@ TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeId) {
inclusion.parse(BSON("a" << true));
auto result = inclusion.applyProjection(Document{{"_id", "ID"}, {"a", 1}, {"b", 2}});
auto expectedResult = Document{{"_id", "ID"}, {"a", 1}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Should leave the "_id" in the same place as in the original document.
result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"}});
expectedResult = Document{{"a", 1}, {"_id", "ID"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeIdWithComputedFields) {
@@ -258,7 +259,7 @@ TEST(InclusionProjectionExecutionTest, ShouldImplicitlyIncludeIdWithComputedFiel
inclusion.parse(BSON("newField" << wrapInLiteral("computedVal")));
auto result = inclusion.applyProjection(Document{{"_id", "ID"}, {"a", 1}});
auto expectedResult = Document{{"_id", "ID"}, {"newField", "computedVal"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldIncludeIdIfExplicitlyIncluded) {
@@ -266,7 +267,7 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeIdIfExplicitlyIncluded) {
inclusion.parse(BSON("a" << true << "_id" << true << "b" << true));
auto result = inclusion.applyProjection(Document{{"_id", "ID"}, {"a", 1}, {"b", 2}, {"c", 3}});
auto expectedResult = Document{{"_id", "ID"}, {"a", 1}, {"b", 2}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) {
@@ -274,7 +275,7 @@ TEST(InclusionProjectionExecutionTest, ShouldExcludeIdIfExplicitlyExcluded) {
inclusion.parse(BSON("a" << true << "_id" << false));
auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"}});
auto expectedResult = Document{{"a", 1}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldReplaceIdWithComputedId) {
@@ -282,7 +283,7 @@ TEST(InclusionProjectionExecutionTest, ShouldReplaceIdWithComputedId) {
inclusion.parse(BSON("_id" << wrapInLiteral("newId")));
auto result = inclusion.applyProjection(Document{{"a", 1}, {"b", 2}, {"_id", "ID"}});
auto expectedResult = Document{{"_id", "newId"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
//
@@ -296,22 +297,22 @@ TEST(InclusionProjectionExecutionTest, ShouldIncludeSimpleDottedFieldFromSubDoc)
// More than one field in sub document.
auto result = inclusion.applyProjection(Document{{"a", Document{{"b", 1}, {"c", 2}}}});
auto expectedResult = Document{{"a", Document{{"b", 1}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Specified field is the only field in the sub document.
result = inclusion.applyProjection(Document{{"a", Document{{"b", 1}}}});
expectedResult = Document{{"a", Document{{"b", 1}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Specified field is not present in the sub document.
result = inclusion.applyProjection(Document{{"a", Document{{"c", 1}}}});
expectedResult = Document{{"a", Document{}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// There are no fields in sub document.
result = inclusion.applyProjection(Document{{"a", Document{}}});
expectedResult = Document{{"a", Document{}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedIncludedFieldDoesNotExist) {
@@ -321,12 +322,12 @@ TEST(InclusionProjectionExecutionTest, ShouldNotCreateSubDocIfDottedIncludedFiel
// Should not add the path if it doesn't exist.
auto result = inclusion.applyProjection(Document{});
auto expectedResult = Document{};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Should not replace the first part of the path if that part exists.
result = inclusion.applyProjection(Document{{"sub", "notADocument"}});
expectedResult = Document{};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldApplyDottedInclusionToEachElementInArray) {
@@ -350,7 +351,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyDottedInclusionToEachElementIn
Value(vector<Value>{Value(), Value(Document{})})};
auto result = inclusion.applyProjection(Document{{"a", nestedValues}});
auto expectedResult = Document{{"a", expectedNestedValues}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToSubDocument) {
@@ -360,17 +361,17 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToSubDocument
// Other fields exist in sub document, one of which is the specified field.
auto result = inclusion.applyProjection(Document{{"sub", Document{{"target", 1}, {"c", 2}}}});
auto expectedResult = Document{{"sub", Document{{"target", "computedVal"}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Specified field is not present in the sub document.
result = inclusion.applyProjection(Document{{"sub", Document{{"c", 1}}}});
expectedResult = Document{{"sub", Document{{"target", "computedVal"}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// There are no fields in sub document.
result = inclusion.applyProjection(Document{{"sub", Document{}}});
expectedResult = Document{{"sub", Document{{"target", "computedVal"}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldCreateSubDocIfDottedComputedFieldDoesntExist) {
@@ -380,11 +381,11 @@ TEST(InclusionProjectionExecutionTest, ShouldCreateSubDocIfDottedComputedFieldDo
// Should add the path if it doesn't exist.
auto result = inclusion.applyProjection(Document{});
auto expectedResult = Document{{"sub", Document{{"target", "computedVal"}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Should replace non-documents with documents.
result = inclusion.applyProjection(Document{{"sub", "notADocument"}});
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldCreateNestedSubDocumentsAllTheWayToComputedField) {
@@ -395,11 +396,11 @@ TEST(InclusionProjectionExecutionTest, ShouldCreateNestedSubDocumentsAllTheWayTo
auto result = inclusion.applyProjection(Document{});
auto expectedResult =
Document{{"a", Document{{"b", Document{{"c", Document{{"d", "computedVal"}}}}}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// Should replace non-documents with documents.
result = inclusion.applyProjection(Document{{"a", Document{{"b", "other"}}}});
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToEachElementInArray) {
@@ -421,7 +422,7 @@ TEST(InclusionProjectionExecutionTest, ShouldAddComputedDottedFieldToEachElement
Value(Document{{"b", "COMPUTED"}})})};
auto result = inclusion.applyProjection(Document{{"a", nestedValues}});
auto expectedResult = Document{{"a", expectedNestedValues}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldApplyInclusionsAndAdditionsToEachElementInArray) {
@@ -448,7 +449,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyInclusionsAndAdditionsToEachEl
Value(Document{{"inc", 1}, {"comp", "COMPUTED"}})})};
auto result = inclusion.applyProjection(Document{{"a", nestedValues}});
auto expectedResult = Document{{"a", expectedNestedValues}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldAddOrIncludeSubFieldsOfId) {
@@ -456,7 +457,7 @@ TEST(InclusionProjectionExecutionTest, ShouldAddOrIncludeSubFieldsOfId) {
inclusion.parse(BSON("_id.X" << true << "_id.Z" << wrapInLiteral("NEW")));
auto result = inclusion.applyProjection(Document{{"_id", Document{{"X", 1}, {"Y", 2}}}});
auto expectedResult = Document{{"_id", Document{{"X", 1}, {"Z", "NEW"}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) {
@@ -479,7 +480,7 @@ TEST(InclusionProjectionExecutionTest, ShouldAllowMixedNestedAndDottedFields) {
{"X", "X"},
{"Y", "Y"},
{"Z", "Z"}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldApplyNestedComputedFieldsInOrderSpecified) {
@@ -487,7 +488,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyNestedComputedFieldsInOrderSpe
inclusion.parse(BSON("a" << wrapInLiteral("FIRST") << "b.c" << wrapInLiteral("SECOND")));
auto result = inclusion.applyProjection(Document{});
auto expectedResult = Document{{"a", "FIRST"}, {"b", Document{{"c", "SECOND"}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusions) {
@@ -495,10 +496,10 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusio
inclusion.parse(BSON("b.c" << wrapInLiteral("NEW") << "a" << true));
auto result = inclusion.applyProjection(Document{{"a", 1}});
auto expectedResult = Document{{"a", 1}, {"b", Document{{"c", "NEW"}}}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
result = inclusion.applyProjection(Document{{"a", 1}, {"b", 4}});
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
// In this case, the field 'b' shows up first and has a nested inclusion or computed field. Even
// though it is a computed field, it will appear first in the output document. This is
@@ -506,7 +507,7 @@ TEST(InclusionProjectionExecutionTest, ShouldApplyComputedFieldsAfterAllInclusio
// recursively to each sub-document.
result = inclusion.applyProjection(Document{{"b", 4}, {"a", 1}});
expectedResult = Document{{"b", Document{{"c", "NEW"}}}, {"a", 1}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
TEST(InclusionProjectionExecutionTest, ComputedFieldReplacingExistingShouldAppearAfterInclusions) {
@@ -514,10 +515,10 @@ TEST(InclusionProjectionExecutionTest, ComputedFieldReplacingExistingShouldAppea
inclusion.parse(BSON("b" << wrapInLiteral("NEW") << "a" << true));
auto result = inclusion.applyProjection(Document{{"b", 1}, {"a", 1}});
auto expectedResult = Document{{"a", 1}, {"b", "NEW"}};
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
result = inclusion.applyProjection(Document{{"a", 1}, {"b", 4}});
- ASSERT_EQ(result, expectedResult);
+ ASSERT_DOCUMENT_EQ(result, expectedResult);
}
//
@@ -537,7 +538,7 @@ TEST(InclusionProjectionExecutionTest, ShouldAlwaysKeepMetadataFromOriginalDoc)
MutableDocument expectedDoc(inputDoc);
expectedDoc.copyMetaDataFrom(inputDoc);
- ASSERT_EQ(result, expectedDoc.freeze());
+ ASSERT_DOCUMENT_EQ(result, expectedDoc.freeze());
}
} // namespace
diff --git a/src/mongo/db/pipeline/pipeline.cpp b/src/mongo/db/pipeline/pipeline.cpp
index d954991ce22..45b7ee58e51 100644
--- a/src/mongo/db/pipeline/pipeline.cpp
+++ b/src/mongo/db/pipeline/pipeline.cpp
@@ -221,11 +221,11 @@ void Pipeline::reattachToOperationContext(OperationContext* opCtx) {
}
}
-void Pipeline::setCollator(std::unique_ptr<CollatorInterface> collator) {
- pCtx->collator = std::move(collator);
-
- // TODO SERVER-23349: If the pipeline has any DocumentSourceMatch sources, ask them to
- // re-parse their predicates.
+void Pipeline::injectExpressionContext(const intrusive_ptr<ExpressionContext>& expCtx) {
+ pCtx = expCtx;
+ for (auto&& stage : _sources) {
+ stage->injectExpressionContext(pCtx);
+ }
}
intrusive_ptr<Pipeline> Pipeline::splitForSharded() {
@@ -461,4 +461,5 @@ DepsTracker Pipeline::getDependencies(DepsTracker::MetadataAvailable metadataAva
return deps;
}
+
} // namespace mongo
diff --git a/src/mongo/db/pipeline/pipeline.h b/src/mongo/db/pipeline/pipeline.h
index d577b5fc151..531411fae60 100644
--- a/src/mongo/db/pipeline/pipeline.h
+++ b/src/mongo/db/pipeline/pipeline.h
@@ -107,15 +107,6 @@ public:
void reattachToOperationContext(OperationContext* opCtx);
/**
- * Sets the collator on this Pipeline. parseCommand() will return a Pipeline with its collator
- * already set from the parsed request (if applicable), but this setter method can be used to
- * later override the Pipeline's original collator.
- *
- * The Pipeline's collator can be retrieved with getContext()->collator.
- */
- void setCollator(std::unique_ptr<CollatorInterface> collator);
-
- /**
Split the current Pipeline into a Pipeline for each shard, and
a Pipeline that combines the results within mongos.
@@ -142,6 +133,12 @@ public:
void optimizePipeline();
/**
+ * Propagates a reference to the ExpressionContext to all of the pipeline's contained stages and
+ * expressions.
+ */
+ void injectExpressionContext(const boost::intrusive_ptr<ExpressionContext>& expCtx);
+
+ /**
* Returns any other collections involved in the pipeline in addition to the collection the
* aggregation is run on.
*/
diff --git a/src/mongo/db/pipeline/pipeline_d.cpp b/src/mongo/db/pipeline/pipeline_d.cpp
index b8e7c1bcd1a..3ebc683676f 100644
--- a/src/mongo/db/pipeline/pipeline_d.cpp
+++ b/src/mongo/db/pipeline/pipeline_d.cpp
@@ -226,8 +226,8 @@ StatusWith<std::unique_ptr<PlanExecutor>> attemptToGetExecutor(
//
// If pipeline has a null collator (representing the "simple" collation), we simply set the
// collation option to the original user BSON.
- qr->setCollation(pExpCtx->collator ? pExpCtx->collator->getSpec().toBSON()
- : pExpCtx->collation);
+ qr->setCollation(pExpCtx->getCollator() ? pExpCtx->getCollator()->getSpec().toBSON()
+ : pExpCtx->collation);
const ExtensionsCallbackReal extensionsCallback(pExpCtx->opCtx, &pExpCtx->ns);
diff --git a/src/mongo/db/pipeline/pipeline_test.cpp b/src/mongo/db/pipeline/pipeline_test.cpp
index 4cdf064ace6..864680bcbb9 100644
--- a/src/mongo/db/pipeline/pipeline_test.cpp
+++ b/src/mongo/db/pipeline/pipeline_test.cpp
@@ -37,6 +37,7 @@
#include "mongo/db/pipeline/dependencies.h"
#include "mongo/db/pipeline/document.h"
#include "mongo/db/pipeline/document_source.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/field_path.h"
#include "mongo/db/pipeline/pipeline.h"
@@ -94,8 +95,9 @@ public:
auto outputPipe = uassertStatusOK(Pipeline::parse(request.getPipeline(), ctx));
outputPipe->optimizePipeline();
- ASSERT_EQUALS(Value(outputPipe->writeExplainOps()), Value(outputPipeExpected["pipeline"]));
- ASSERT_EQUALS(Value(outputPipe->serialize()), Value(serializePipeExpected["pipeline"]));
+ ASSERT_VALUE_EQ(Value(outputPipe->writeExplainOps()),
+ Value(outputPipeExpected["pipeline"]));
+ ASSERT_VALUE_EQ(Value(outputPipe->serialize()), Value(serializePipeExpected["pipeline"]));
}
virtual ~Base() {}
@@ -744,8 +746,8 @@ public:
shardPipe = mergePipe->splitForSharded();
ASSERT(shardPipe != nullptr);
- ASSERT_EQUALS(Value(shardPipe->writeExplainOps()), Value(shardPipeExpected["pipeline"]));
- ASSERT_EQUALS(Value(mergePipe->writeExplainOps()), Value(mergePipeExpected["pipeline"]));
+ ASSERT_VALUE_EQ(Value(shardPipe->writeExplainOps()), Value(shardPipeExpected["pipeline"]));
+ ASSERT_VALUE_EQ(Value(mergePipe->writeExplainOps()), Value(mergePipeExpected["pipeline"]));
}
virtual ~Base() {}
@@ -1092,9 +1094,9 @@ TEST(PipelineInitialSource, ParseCollation) {
ASSERT_OK(request.getStatus());
intrusive_ptr<ExpressionContext> ctx = new ExpressionContext(opCtx.get(), request.getValue());
- ASSERT(ctx->collator.get());
+ ASSERT(ctx->getCollator());
CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
- ASSERT_TRUE(CollatorInterface::collatorsMatch(ctx->collator.get(), &collator));
+ ASSERT_TRUE(CollatorInterface::collatorsMatch(ctx->getCollator(), &collator));
}
namespace Dependencies {
diff --git a/src/mongo/db/pipeline/tee_buffer_test.cpp b/src/mongo/db/pipeline/tee_buffer_test.cpp
index 9fb2b56f1fc..1f218ea1d89 100644
--- a/src/mongo/db/pipeline/tee_buffer_test.cpp
+++ b/src/mongo/db/pipeline/tee_buffer_test.cpp
@@ -31,6 +31,7 @@
#include "mongo/db/pipeline/tee_buffer.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/assert_util.h"
@@ -63,7 +64,7 @@ TEST(TeeBufferTest, ShouldProvideIteratorOverSingleDocument) {
// Should be able to establish an iterator and get the document back.
auto it = teeBuffer->begin();
ASSERT(it != teeBuffer->end());
- ASSERT_EQ(*it, inputDoc);
+ ASSERT_DOCUMENT_EQ(*it, inputDoc);
++it;
ASSERT(it == teeBuffer->end());
}
@@ -77,10 +78,10 @@ TEST(TeeBufferTest, ShouldProvideIteratorOverTwoDocuments) {
auto it = teeBuffer->begin();
ASSERT(it != teeBuffer->end());
- ASSERT_EQ(*it, inputDocs.front());
+ ASSERT_DOCUMENT_EQ(*it, inputDocs.front());
++it;
ASSERT(it != teeBuffer->end());
- ASSERT_EQ(*it, inputDocs.back());
+ ASSERT_DOCUMENT_EQ(*it, inputDocs.back());
++it;
ASSERT(it == teeBuffer->end());
}
@@ -97,18 +98,18 @@ TEST(TeeBufferTest, ShouldBeAbleToProvideMultipleIteratorsOverTheSameInputs) {
// Advance both once.
ASSERT(firstIt != teeBuffer->end());
- ASSERT_EQ(*firstIt, inputDocs.front());
+ ASSERT_DOCUMENT_EQ(*firstIt, inputDocs.front());
++firstIt;
ASSERT(secondIt != teeBuffer->end());
- ASSERT_EQ(*secondIt, inputDocs.front());
+ ASSERT_DOCUMENT_EQ(*secondIt, inputDocs.front());
++secondIt;
// Advance them both again.
ASSERT(firstIt != teeBuffer->end());
- ASSERT_EQ(*firstIt, inputDocs.back());
+ ASSERT_DOCUMENT_EQ(*firstIt, inputDocs.back());
++firstIt;
ASSERT(secondIt != teeBuffer->end());
- ASSERT_EQ(*secondIt, inputDocs.back());
+ ASSERT_DOCUMENT_EQ(*secondIt, inputDocs.back());
++secondIt;
// Assert they've both reached the end.
diff --git a/src/mongo/db/pipeline/value.cpp b/src/mongo/db/pipeline/value.cpp
index 4334af8260c..2d76ebe2bb6 100644
--- a/src/mongo/db/pipeline/value.cpp
+++ b/src/mongo/db/pipeline/value.cpp
@@ -638,7 +638,9 @@ inline static int cmp(const T& left, const T& right) {
}
}
-int Value::compare(const Value& rL, const Value& rR) {
+int Value::compare(const Value& rL,
+ const Value& rR,
+ const StringData::ComparatorInterface* stringComparator) {
// Note, this function needs to behave identically to BSON's compareElementValues().
// Additionally, any changes here must be replicated in hash_combine().
BSONType lType = rL.getType();
@@ -739,13 +741,20 @@ int Value::compare(const Value& rL, const Value& rR) {
case jstOID:
return memcmp(rL._storage.oid, rR._storage.oid, OID::kOIDSize);
+ case String: {
+ if (!stringComparator) {
+ return rL.getStringData().compare(rR.getStringData());
+ }
+
+ return stringComparator->compare(rL.getStringData(), rR.getStringData());
+ }
+
case Code:
case Symbol:
- case String:
return rL.getStringData().compare(rR.getStringData());
case Object:
- return Document::compare(rL.getDocument(), rR.getDocument());
+ return Document::compare(rL.getDocument(), rR.getDocument(), stringComparator);
case Array: {
const vector<Value>& lArr = rL.getArray();
@@ -754,7 +763,7 @@ int Value::compare(const Value& rL, const Value& rR) {
const size_t elems = std::min(lArr.size(), rArr.size());
for (size_t i = 0; i < elems; i++) {
// compare the two corresponding elements
- ret = Value::compare(lArr[i], rArr[i]);
+ ret = Value::compare(lArr[i], rArr[i], stringComparator);
if (ret)
return ret; // values are unequal
}
diff --git a/src/mongo/db/pipeline/value.h b/src/mongo/db/pipeline/value.h
index bda0723d471..e151e58c813 100644
--- a/src/mongo/db/pipeline/value.h
+++ b/src/mongo/db/pipeline/value.h
@@ -28,7 +28,7 @@
#pragma once
-
+#include "mongo/base/string_data.h"
#include "mongo/db/pipeline/value_internal.h"
#include "mongo/platform/unordered_set.h"
@@ -57,6 +57,28 @@ class BSONElement;
*/
class Value {
public:
+ /**
+ * Operator overloads for relops return a DeferredComparison which can subsequently be evaluated
+ * by a ValueComparator.
+ */
+ struct DeferredComparison {
+ enum class Type {
+ kLT,
+ kLTE,
+ kEQ,
+ kGT,
+ kGTE,
+ kNE,
+ };
+
+ DeferredComparison(Type type, const Value& lhs, const Value& rhs)
+ : type(type), lhs(lhs), rhs(rhs) {}
+
+ Type type;
+ const Value& lhs;
+ const Value& rhs;
+ };
+
/** Construct a Value
*
* All types not listed will be rejected rather than converted (see private for why)
@@ -200,28 +222,54 @@ public:
time_t coerceToTimeT() const;
tm coerceToTm() const; // broken-out time struct (see man gmtime)
+ //
+ // Comparison API.
+ //
+ // Value instances can be compared either using Value::compare() or via operator overloads.
+ // Most callers should prefer operator overloads. Note that the operator overloads return a
+ // DeferredComparison, which must be subsequently evaluated by a ValueComparator. See
+ // value_comparator.h for details.
+ //
- /** Compare two Values.
+ /**
+ * Compare two Values. Most Values should prefer to use ValueComparator instead. See
+ * value_comparator.h for details.
+ *
+ * Pass a non-null StringData::ComparatorInterface if special string comparison semantics are
+ * required. If the comparator is null, then a simple binary compare is used for strings. This
+ * comparator is only used for string *values*; field names are always compared using simple
+ * binary compare.
+ *
* @returns an integer less than zero, zero, or an integer greater than
* zero, depending on whether lhs < rhs, lhs == rhs, or lhs > rhs
* Warning: may return values other than -1, 0, or 1
*/
- static int compare(const Value& lhs, const Value& rhs);
-
- friend bool operator==(const Value& v1, const Value& v2) {
- if (v1._storage.identical(v2._storage)) {
- // Simple case
- return true;
- }
- return (Value::compare(v1, v2) == 0);
+ static int compare(const Value& lhs,
+ const Value& rhs,
+ const StringData::ComparatorInterface* stringComparator = nullptr);
+
+ friend DeferredComparison operator==(const Value& lhs, const Value& rhs) {
+ return DeferredComparison(DeferredComparison::Type::kEQ, lhs, rhs);
+ }
+
+ friend DeferredComparison operator!=(const Value& lhs, const Value& rhs) {
+ return DeferredComparison(DeferredComparison::Type::kNE, lhs, rhs);
+ }
+
+ friend DeferredComparison operator<(const Value& lhs, const Value& rhs) {
+ return DeferredComparison(DeferredComparison::Type::kLT, lhs, rhs);
+ }
+
+ friend DeferredComparison operator<=(const Value& lhs, const Value& rhs) {
+ return DeferredComparison(DeferredComparison::Type::kLTE, lhs, rhs);
}
- friend bool operator!=(const Value& v1, const Value& v2) {
- return !(v1 == v2);
+ friend DeferredComparison operator>(const Value& lhs, const Value& rhs) {
+ return DeferredComparison(DeferredComparison::Type::kGT, lhs, rhs);
}
- friend bool operator<(const Value& lhs, const Value& rhs) {
- return (Value::compare(lhs, rhs) < 0);
+ friend DeferredComparison operator>=(const Value& lhs, const Value& rhs) {
+ return DeferredComparison(DeferredComparison::Type::kGTE, lhs, rhs);
}
/// This is for debugging, logging, etc. See getString() for how to extract a string.
@@ -289,8 +337,6 @@ private:
};
static_assert(sizeof(Value) == 16, "sizeof(Value) == 16");
-typedef unordered_set<Value, Value::Hash> ValueSet;
-
inline void swap(mongo::Value& lhs, mongo::Value& rhs) {
lhs.swap(rhs);
}
diff --git a/src/mongo/db/pipeline/value_comparator.cpp b/src/mongo/db/pipeline/value_comparator.cpp
new file mode 100644
index 00000000000..a20a9460b6d
--- /dev/null
+++ b/src/mongo/db/pipeline/value_comparator.cpp
@@ -0,0 +1,57 @@
+/**
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/pipeline/value_comparator.h"
+
+#include "mongo/util/assert_util.h"
+
+namespace mongo {
+
+bool ValueComparator::evaluate(Value::DeferredComparison deferredComparison) const {
+ int cmp = Value::compare(deferredComparison.lhs, deferredComparison.rhs, _stringComparator);
+ switch (deferredComparison.type) {
+ case Value::DeferredComparison::Type::kLT:
+ return cmp < 0;
+ case Value::DeferredComparison::Type::kLTE:
+ return cmp <= 0;
+ case Value::DeferredComparison::Type::kEQ:
+ return cmp == 0;
+ case Value::DeferredComparison::Type::kGTE:
+ return cmp >= 0;
+ case Value::DeferredComparison::Type::kGT:
+ return cmp > 0;
+ case Value::DeferredComparison::Type::kNE:
+ return cmp != 0;
+ }
+
+ MONGO_UNREACHABLE;
+}
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/value_comparator.h b/src/mongo/db/pipeline/value_comparator.h
new file mode 100644
index 00000000000..cbf423ccf92
--- /dev/null
+++ b/src/mongo/db/pipeline/value_comparator.h
@@ -0,0 +1,177 @@
+/**
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#pragma once
+
+#include <map>
+#include <set>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "mongo/base/string_data.h"
+#include "mongo/db/pipeline/value.h"
+
+namespace mongo {
+
+class ValueComparator {
+public:
+ /**
+ * Functor compatible for use with unordered STL containers.
+ *
+ * TODO SERVER-23349: Remove the no-arguments constructor.
+ */
+ class EqualTo {
+ public:
+ EqualTo() = default;
+
+ explicit EqualTo(const ValueComparator* comparator) : _comparator(comparator) {}
+
+ bool operator()(const Value& lhs, const Value& rhs) const {
+ return _comparator ? _comparator->compare(lhs, rhs) == 0
+ : ValueComparator().compare(lhs, rhs) == 0;
+ }
+
+ private:
+ const ValueComparator* _comparator = nullptr;
+ };
+
+ /**
+ * Functor compatible for use with ordered STL containers.
+ */
+ class LessThan {
+ public:
+ explicit LessThan(const ValueComparator* comparator) : _comparator(comparator) {}
+
+ bool operator()(const Value& lhs, const Value& rhs) const {
+ return _comparator->compare(lhs, rhs) < 0;
+ }
+
+ private:
+ const ValueComparator* _comparator;
+ };
+
+ /**
+ * Constructs a value comparator with simple comparison semantics.
+ */
+ ValueComparator() = default;
+
+ /**
+ * Constructs a value comparator with special string comparison semantics.
+ */
+ ValueComparator(const StringData::ComparatorInterface* stringComparator)
+ : _stringComparator(stringComparator) {}
+
+ /**
+ * Returns <0 if 'lhs' is less than 'rhs', 0 if 'lhs' is equal to 'rhs', and >0 if 'lhs' is
+ * greater than 'rhs'.
+ */
+ int compare(const Value& lhs, const Value& rhs) const {
+ return Value::compare(lhs, rhs, _stringComparator);
+ }
+
+ /**
+ * Evaluates a deferred comparison object that was generated by invoking one of the comparison
+ * operators on the Value class.
+ */
+ bool evaluate(Value::DeferredComparison deferredComparison) const;
+
+ /**
+ * Returns a function object which computes whether one Value is equal to another under this
+ * comparator. This comparator must outlive the returned function object.
+ */
+ EqualTo getEqualTo() const {
+ return EqualTo(this);
+ }
+
+ /**
+ * Returns a function object which computes whether one Value is less than another under this
+ * comparator. This comparator must outlive the returned function object.
+ */
+ LessThan getLessThan() const {
+ return LessThan(this);
+ }
+
+ /**
+ * Construct an empty ordered set of Value whose ordering and equivalence classes are given by
+ * this comparator. This comparator must outlive the returned set.
+ */
+ std::set<Value, LessThan> makeOrderedValueSet() const {
+ return std::set<Value, LessThan>(LessThan(this));
+ }
+
+ /**
+ * Construct an empty unordered set of Value whose equivalence classes are given by this
+ * comparator. This comparator must outlive the returned set.
+ *
+ * TODO SERVER-23990: Make Value::Hash use the collation. The returned set won't be correctly
+ * collation-aware until this work is done.
+ */
+ std::unordered_set<Value, Value::Hash, EqualTo> makeUnorderedValueSet() const {
+ return std::unordered_set<Value, Value::Hash, EqualTo>(0, Value::Hash(), EqualTo(this));
+ }
+
+ /**
+ * Construct an empty ordered map from Value to type T whose ordering and equivalence classes
+ * are given by this comparator. This comparator must outlive the returned set.
+ */
+ template <typename T>
+ std::map<Value, T, LessThan> makeOrderedValueMap() const {
+ return std::map<Value, T, LessThan>(LessThan(this));
+ }
+
+ /**
+ * Construct an empty unordered map from Value to type T whose equivalence classes are given by
+ * this comparator. This comparator must outlive the returned set.
+ *
+ * TODO SERVER-23990: Make Value::Hash use the collation. The returned map won't be correctly
+ * collation-aware until this work is done.
+ */
+ template <typename T>
+ std::unordered_map<Value, T, Value::Hash, EqualTo> makeUnorderedValueMap() const {
+ return std::unordered_map<Value, T, Value::Hash, EqualTo>(0, Value::Hash(), EqualTo(this));
+ }
+
+private:
+ const StringData::ComparatorInterface* _stringComparator = nullptr;
+};
+
+//
+// Type aliases for sets and maps of Value for use by clients of the Document/Value library.
+//
+
+using ValueSet = std::set<Value, ValueComparator::LessThan>;
+
+using ValueUnorderedSet = std::unordered_set<Value, Value::Hash, ValueComparator::EqualTo>;
+
+template <typename T>
+using ValueMap = std::map<Value, T, ValueComparator::LessThan>;
+
+template <typename T>
+using ValueUnorderedMap = std::unordered_map<Value, T, Value::Hash, ValueComparator::EqualTo>;
+
+} // namespace mongo
diff --git a/src/mongo/db/pipeline/value_comparator_test.cpp b/src/mongo/db/pipeline/value_comparator_test.cpp
new file mode 100644
index 00000000000..05386061680
--- /dev/null
+++ b/src/mongo/db/pipeline/value_comparator_test.cpp
@@ -0,0 +1,223 @@
+/**
+ * Copyright (C) 2016 MongoDB Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License, version 3,
+ * as published by the Free Software Foundation.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the GNU Affero General Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/platform/basic.h"
+
+#include "mongo/db/pipeline/value_comparator.h"
+
+#include "mongo/db/pipeline/document_value_test_util.h"
+#include "mongo/db/query/collation/collator_interface_mock.h"
+#include "mongo/unittest/unittest.h"
+
+namespace mongo {
+namespace {
+
+TEST(ValueComparatorTest, EqualToEvaluatesCorrectly) {
+ Value val1("bar");
+ Value val2("bar");
+ Value val3("baz");
+ ASSERT_TRUE(ValueComparator().evaluate(val1 == val2));
+ ASSERT_FALSE(ValueComparator().evaluate(val1 == val3));
+}
+
+TEST(ValueComparatorTest, EqualToEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual);
+ Value val1("abc");
+ Value val2("def");
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 == val2));
+}
+
+TEST(ValueComparatorTest, EqualToFunctorEvaluatesCorrectly) {
+ ValueComparator valueComparator;
+ auto equalFunc = valueComparator.getEqualTo();
+ Value val1("bar");
+ Value val2("bar");
+ Value val3("baz");
+ ASSERT_TRUE(equalFunc(val1, val2));
+ ASSERT_FALSE(equalFunc(val1, val3));
+}
+
+TEST(ValueComparatorTest, EqualToFunctorEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual);
+ ValueComparator valueComparator(&collator);
+ auto equalFunc = valueComparator.getEqualTo();
+ Value val1("abc");
+ Value val2("def");
+ ASSERT_TRUE(equalFunc(val1, val2));
+}
+
+TEST(ValueComparatorTest, NotEqualEvaluatesCorrectly) {
+ Value val1("bar");
+ Value val2("bar");
+ Value val3("baz");
+ ASSERT_FALSE(ValueComparator().evaluate(val1 != val2));
+ ASSERT_TRUE(ValueComparator().evaluate(val1 != val3));
+}
+
+TEST(ValueComparatorTest, NotEqualEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual);
+ Value val1("abc");
+ Value val2("def");
+ ASSERT_FALSE(ValueComparator(&collator).evaluate(val1 != val2));
+}
+
+TEST(ValueComparatorTest, LessThanEvaluatesCorrectly) {
+ Value val1("a");
+ Value val2("b");
+ ASSERT_TRUE(ValueComparator().evaluate(val1 < val2));
+ ASSERT_FALSE(ValueComparator().evaluate(val2 < val1));
+}
+
+TEST(ValueComparatorTest, LessThanEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
+ Value val1("za");
+ Value val2("yb");
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 < val2));
+ ASSERT_FALSE(ValueComparator(&collator).evaluate(val2 < val1));
+}
+
+TEST(ValueComparatorTest, LessThanFunctorEvaluatesCorrectly) {
+ ValueComparator valueComparator;
+ auto lessThanFunc = valueComparator.getLessThan();
+ Value val1("a");
+ Value val2("b");
+ ASSERT_TRUE(lessThanFunc(val1, val2));
+ ASSERT_FALSE(lessThanFunc(val2, val1));
+}
+
+TEST(ValueComparatorTest, LessThanFunctorEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
+ ValueComparator valueComparator(&collator);
+ auto lessThanFunc = valueComparator.getLessThan();
+ Value val1("za");
+ Value val2("yb");
+ ASSERT_TRUE(lessThanFunc(val1, val2));
+ ASSERT_FALSE(lessThanFunc(val2, val1));
+}
+
+TEST(ValueComparatorTest, LessThanOrEqualEvaluatesCorrectly) {
+ Value val1("a");
+ Value val2("a");
+ Value val3("b");
+ ASSERT_TRUE(ValueComparator().evaluate(val1 <= val2));
+ ASSERT_TRUE(ValueComparator().evaluate(val2 <= val1));
+ ASSERT_TRUE(ValueComparator().evaluate(val1 <= val3));
+ ASSERT_FALSE(ValueComparator().evaluate(val3 <= val1));
+}
+
+TEST(ValueComparatorTest, LessThanOrEqualEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
+ Value val1("za");
+ Value val2("za");
+ Value val3("yb");
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 <= val2));
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val2 <= val1));
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 <= val3));
+ ASSERT_FALSE(ValueComparator(&collator).evaluate(val3 <= val1));
+}
+
+TEST(ValueComparatorTest, GreaterThanEvaluatesCorrectly) {
+ Value val1("b");
+ Value val2("a");
+ ASSERT_TRUE(ValueComparator().evaluate(val1 > val2));
+ ASSERT_FALSE(ValueComparator().evaluate(val2 > val1));
+}
+
+TEST(ValueComparatorTest, GreaterThanEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
+ Value val1("yb");
+ Value val2("za");
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 > val2));
+ ASSERT_FALSE(ValueComparator(&collator).evaluate(val2 > val1));
+}
+
+TEST(ValueComparatorTest, GreaterThanOrEqualEvaluatesCorrectly) {
+ Value val1("b");
+ Value val2("b");
+ Value val3("a");
+ ASSERT_TRUE(ValueComparator().evaluate(val1 >= val2));
+ ASSERT_TRUE(ValueComparator().evaluate(val2 >= val1));
+ ASSERT_TRUE(ValueComparator().evaluate(val1 >= val3));
+ ASSERT_FALSE(ValueComparator().evaluate(val3 >= val1));
+}
+
+TEST(ValueComparatorTest, GreaterThanOrEqualEvaluatesCorrectlyWithNonSimpleCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
+ Value val1("yb");
+ Value val2("yb");
+ Value val3("za");
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 >= val2));
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val2 >= val1));
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 >= val3));
+ ASSERT_FALSE(ValueComparator(&collator).evaluate(val3 >= val1));
+}
+
+TEST(ValueComparatorTest, OrderedValueSetRespectsTheComparator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kReverseString);
+ ValueComparator valueComparator(&collator);
+ ValueSet set = valueComparator.makeOrderedValueSet();
+ set.insert(Value("yb"));
+ set.insert(Value("za"));
+
+ auto it = set.begin();
+ ASSERT_VALUE_EQ(*it, Value("za"));
+ ++it;
+ ASSERT_VALUE_EQ(*it, Value("yb"));
+ ++it;
+ ASSERT(it == set.end());
+}
+
+TEST(ValueComparatorTest, EqualToEvaluatesCorrectlyWithNumbers) {
+ Value val1(88);
+ Value val2(88);
+ Value val3(99);
+ ASSERT_TRUE(ValueComparator().evaluate(val1 == val2));
+ ASSERT_FALSE(ValueComparator().evaluate(val1 == val3));
+}
+
+TEST(ValueComparatorTest, NestedObjectEqualityRespectsCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual);
+ Value val1(Document{{"foo", "abc"}});
+ Value val2(Document{{"foo", "def"}});
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 == val2));
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val2 == val1));
+}
+
+TEST(ValueComparatorTest, NestedArrayEqualityRespectsCollator) {
+ CollatorInterfaceMock collator(CollatorInterfaceMock::MockType::kAlwaysEqual);
+ Value val1(std::vector<Value>{Value("a"), Value("b")});
+ Value val2(std::vector<Value>{Value("c"), Value("d")});
+ Value val3(std::vector<Value>{Value("c"), Value("d"), Value("e")});
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val1 == val2));
+ ASSERT_TRUE(ValueComparator(&collator).evaluate(val2 == val1));
+ ASSERT_FALSE(ValueComparator(&collator).evaluate(val1 == val3));
+ ASSERT_FALSE(ValueComparator(&collator).evaluate(val3 == val1));
+}
+
+} // namespace
+} // namespace mongo
diff --git a/src/mongo/dbtests/SConscript b/src/mongo/dbtests/SConscript
index 6b60b1372d8..54134152792 100644
--- a/src/mongo/dbtests/SConscript
+++ b/src/mongo/dbtests/SConscript
@@ -22,8 +22,8 @@ env.Library(
'framework_options_init.cpp',
],
LIBDEPS=[
- '$BUILD_DIR/mongo/s/coreshard',
'$BUILD_DIR/mongo/db/s/sharding',
+ '$BUILD_DIR/mongo/s/coreshard',
'$BUILD_DIR/mongo/unittest/unittest',
'framework_options',
],
@@ -112,6 +112,7 @@ dbtest = env.Program(
LIBDEPS=[
"$BUILD_DIR/mongo/db/auth/authmocks",
"$BUILD_DIR/mongo/db/bson/dotted_path_support",
+ '$BUILD_DIR/mongo/db/pipeline/document_value_test_util',
"$BUILD_DIR/mongo/db/query/collation/collator_interface_mock",
"$BUILD_DIR/mongo/db/query/query",
"$BUILD_DIR/mongo/db/query/query_test_service_context",
diff --git a/src/mongo/dbtests/documentsourcetests.cpp b/src/mongo/dbtests/documentsourcetests.cpp
index 973629cdabe..c98dcdd7a53 100644
--- a/src/mongo/dbtests/documentsourcetests.cpp
+++ b/src/mongo/dbtests/documentsourcetests.cpp
@@ -39,6 +39,7 @@
#include "mongo/db/matcher/extensions_callback_disallow_extensions.h"
#include "mongo/db/pipeline/dependencies.h"
#include "mongo/db/pipeline/document_source.h"
+#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/pipeline.h"
#include "mongo/db/query/get_executor.h"
@@ -152,7 +153,7 @@ public:
// The cursor will produce the expected result.
boost::optional<Document> next = source()->getNext();
ASSERT(bool(next));
- ASSERT_EQUALS(Value(1), next->getField("a"));
+ ASSERT_VALUE_EQ(Value(1), next->getField("a"));
// There are no more results.
ASSERT(!source()->getNext());
// Exhausting the source releases the read lock.
@@ -186,11 +187,11 @@ public:
// The result is as expected.
boost::optional<Document> next = source()->getNext();
ASSERT(bool(next));
- ASSERT_EQUALS(Value(1), next->getField("a"));
+ ASSERT_VALUE_EQ(Value(1), next->getField("a"));
// The next result is as expected.
next = source()->getNext();
ASSERT(bool(next));
- ASSERT_EQUALS(Value(2), next->getField("a"));
+ ASSERT_VALUE_EQ(Value(2), next->getField("a"));
// The DocumentSourceCursor doesn't hold a read lock.
ASSERT(!_opCtx.lockState()->isReadLocked());
source()->dispose();