summaryrefslogtreecommitdiff
path: root/src/mongo
diff options
context:
space:
mode:
authorDavid Storch <david.storch@10gen.com>2016-07-18 21:32:00 -0400
committerDavid Storch <david.storch@10gen.com>2016-07-19 22:39:30 -0400
commit015f840701e270e42514e53a3def41d796d414a6 (patch)
tree0e63aaf091645ad1f3d460e279ef47fa7679621b /src/mongo
parent7e986cc77f121e3af9a5f1217e89913745fc07f9 (diff)
downloadmongo-015f840701e270e42514e53a3def41d796d414a6.tar.gz
SERVER-23349 require a collator for Document::compare() and Value::compare()
Includes making aggregation $sort respect the collation.
Diffstat (limited to 'src/mongo')
-rw-r--r--src/mongo/db/pipeline/accumulator_min_max.cpp2
-rw-r--r--src/mongo/db/pipeline/accumulator_test.cpp6
-rw-r--r--src/mongo/db/pipeline/document.h2
-rw-r--r--src/mongo/db/pipeline/document_comparator.h8
-rw-r--r--src/mongo/db/pipeline/document_source_group.cpp22
-rw-r--r--src/mongo/db/pipeline/document_source_sort.cpp6
-rw-r--r--src/mongo/db/pipeline/document_source_test.cpp9
-rw-r--r--src/mongo/db/pipeline/document_value_test.cpp6
-rw-r--r--src/mongo/db/pipeline/expression.cpp2
-rw-r--r--src/mongo/db/pipeline/expression.h1
-rw-r--r--src/mongo/db/pipeline/expression_test.cpp6
-rw-r--r--src/mongo/db/pipeline/value.h2
12 files changed, 56 insertions, 16 deletions
diff --git a/src/mongo/db/pipeline/accumulator_min_max.cpp b/src/mongo/db/pipeline/accumulator_min_max.cpp
index fc854ca1ddb..7eb77f74920 100644
--- a/src/mongo/db/pipeline/accumulator_min_max.cpp
+++ b/src/mongo/db/pipeline/accumulator_min_max.cpp
@@ -51,7 +51,7 @@ void AccumulatorMinMax::processInternal(const Value& input, bool merging) {
// nullish values should have no impact on result
if (!input.nullish()) {
/* compare with the current value; swap if appropriate */
- int cmp = Value::compare(_val, input) * _sense;
+ int cmp = getExpressionContext()->getValueComparator().compare(_val, input) * _sense;
if (cmp > 0 || _val.missing()) { // missing is lower than all other values
_val = input;
_memUsageBytes = sizeof(*this) + input.getApproximateSize() - sizeof(Value);
diff --git a/src/mongo/db/pipeline/accumulator_test.cpp b/src/mongo/db/pipeline/accumulator_test.cpp
index 8fad2b0095d..f903e897edc 100644
--- a/src/mongo/db/pipeline/accumulator_test.cpp
+++ b/src/mongo/db/pipeline/accumulator_test.cpp
@@ -49,11 +49,13 @@ static void assertExpectedResults(
std::string accumulator,
std::initializer_list<std::pair<std::vector<Value>, Value>> operations) {
auto factory = Accumulator::getFactory(accumulator);
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
for (auto&& op : operations) {
try {
// Asserts that result equals expected result when not sharded.
{
boost::intrusive_ptr<Accumulator> accum = factory();
+ accum->injectExpressionContext(expCtx);
for (auto&& val : op.first) {
accum->process(val, false);
}
@@ -65,7 +67,9 @@ static void assertExpectedResults(
// Asserts that result equals expected result when all input is on one shard.
{
boost::intrusive_ptr<Accumulator> accum = factory();
+ accum->injectExpressionContext(expCtx);
boost::intrusive_ptr<Accumulator> shard = factory();
+ shard->injectExpressionContext(expCtx);
for (auto&& val : op.first) {
shard->process(val, false);
}
@@ -78,8 +82,10 @@ static void assertExpectedResults(
// Asserts that result equals expected result when each input is on a separate shard.
{
boost::intrusive_ptr<Accumulator> accum = factory();
+ accum->injectExpressionContext(expCtx);
for (auto&& val : op.first) {
boost::intrusive_ptr<Accumulator> shard = factory();
+ shard->injectExpressionContext(expCtx);
shard->process(val, false);
accum->process(shard->getValue(true), true);
}
diff --git a/src/mongo/db/pipeline/document.h b/src/mongo/db/pipeline/document.h
index 288036a1ba4..5dc14ab7aa5 100644
--- a/src/mongo/db/pipeline/document.h
+++ b/src/mongo/db/pipeline/document.h
@@ -176,7 +176,7 @@ public:
*/
static int compare(const Document& lhs,
const Document& rhs,
- const StringData::ComparatorInterface* stringComparator = nullptr);
+ const StringData::ComparatorInterface* stringComparator);
std::string toString() const;
diff --git a/src/mongo/db/pipeline/document_comparator.h b/src/mongo/db/pipeline/document_comparator.h
index 91c1e08da58..0aeedfb20e6 100644
--- a/src/mongo/db/pipeline/document_comparator.h
+++ b/src/mongo/db/pipeline/document_comparator.h
@@ -52,6 +52,14 @@ public:
*/
bool evaluate(Document::DeferredComparison deferredComparison) const;
+ /**
+ * Returns <0 if 'lhs' is less than 'rhs', 0 if 'lhs' is equal to 'rhs', and >0 if 'lhs' is
+ * greater than 'rhs'.
+ */
+ int compare(const Document& lhs, const Document& rhs) const {
+ return Document::compare(lhs, rhs, _stringComparator);
+ }
+
private:
const StringData::ComparatorInterface* _stringComparator = nullptr;
};
diff --git a/src/mongo/db/pipeline/document_source_group.cpp b/src/mongo/db/pipeline/document_source_group.cpp
index 5b97e76d370..0a37706b455 100644
--- a/src/mongo/db/pipeline/document_source_group.cpp
+++ b/src/mongo/db/pipeline/document_source_group.cpp
@@ -35,6 +35,7 @@
#include "mongo/db/pipeline/expression.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/value.h"
+#include "mongo/db/pipeline/value_comparator.h"
namespace mongo {
@@ -365,16 +366,27 @@ using GroupsMap = DocumentSourceGroup::GroupsMap;
class SorterComparator {
public:
typedef pair<Value, Value> Data;
+
+ SorterComparator(ValueComparator valueComparator) : _valueComparator(valueComparator) {}
+
int operator()(const Data& lhs, const Data& rhs) const {
- return Value::compare(lhs.first, rhs.first);
+ return _valueComparator.compare(lhs.first, rhs.first);
}
+
+private:
+ ValueComparator _valueComparator;
};
class SpillSTLComparator {
public:
+ SpillSTLComparator(ValueComparator valueComparator) : _valueComparator(valueComparator) {}
+
bool operator()(const GroupsMap::value_type* lhs, const GroupsMap::value_type* rhs) const {
- return Value::compare(lhs->first, rhs->first) < 0;
+ return _valueComparator.evaluate(lhs->first < rhs->first);
}
+
+private:
+ ValueComparator _valueComparator;
};
bool containsOnlyFieldPathsAndConstants(ExpressionObject* expressionObj) {
@@ -552,8 +564,8 @@ void DocumentSourceGroup::initialize() {
// We won't be using groups again so free its memory.
_groups = pExpCtx->getValueComparator().makeUnorderedValueMap<Accumulators>();
- _sorterIterator.reset(
- Sorter<Value, Value>::Iterator::merge(sortedFiles, SortOptions(), SorterComparator()));
+ _sorterIterator.reset(Sorter<Value, Value>::Iterator::merge(
+ sortedFiles, SortOptions(), SorterComparator(pExpCtx->getValueComparator())));
// prepare current to accumulate data
_currentAccumulators.reserve(numAccumulators);
@@ -577,7 +589,7 @@ shared_ptr<Sorter<Value, Value>::Iterator> DocumentSourceGroup::spill() {
ptrs.push_back(&*it);
}
- stable_sort(ptrs.begin(), ptrs.end(), SpillSTLComparator());
+ stable_sort(ptrs.begin(), ptrs.end(), SpillSTLComparator(pExpCtx->getValueComparator()));
SortedFileWriter<Value, Value> writer(SortOptions().TempDir(pExpCtx->tempDir));
switch (vpAccumulatorFactory.size()) { // same as ptrs[i]->second.size() for all i.
diff --git a/src/mongo/db/pipeline/document_source_sort.cpp b/src/mongo/db/pipeline/document_source_sort.cpp
index 345383750cc..c4382ad27da 100644
--- a/src/mongo/db/pipeline/document_source_sort.cpp
+++ b/src/mongo/db/pipeline/document_source_sort.cpp
@@ -325,14 +325,14 @@ int DocumentSourceSort::compare(const Value& lhs, const Value& rhs) const {
const size_t n = vSortKey.size();
if (n == 1) { // simple fast case
if (vAscending[0])
- return Value::compare(lhs, rhs);
+ return pExpCtx->getValueComparator().compare(lhs, rhs);
else
- return -Value::compare(lhs, rhs);
+ return -pExpCtx->getValueComparator().compare(lhs, rhs);
}
// compound sort
for (size_t i = 0; i < n; i++) {
- int cmp = Value::compare(lhs[i], rhs[i]);
+ int cmp = pExpCtx->getValueComparator().compare(lhs[i], rhs[i]);
if (cmp) {
/* if necessary, adjust the return value by the key ordering */
if (!vAscending[i])
diff --git a/src/mongo/db/pipeline/document_source_test.cpp b/src/mongo/db/pipeline/document_source_test.cpp
index cf13e1589b1..eb38ddc7436 100644
--- a/src/mongo/db/pipeline/document_source_test.cpp
+++ b/src/mongo/db/pipeline/document_source_test.cpp
@@ -36,6 +36,7 @@
#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/pipeline.h"
+#include "mongo/db/pipeline/value_comparator.h"
#include "mongo/db/service_context.h"
#include "mongo/db/service_context_noop.h"
#include "mongo/db/storage/storage_options.h"
@@ -726,7 +727,7 @@ class AggregateOperatorExpression : public ExpressionBase {
struct ValueCmp {
bool operator()(const Value& a, const Value& b) const {
- return Value::compare(a, b) < 0;
+ return ValueComparator().evaluate(a < b);
}
};
typedef map<Value, Document, ValueCmp> IdMap;
@@ -3728,7 +3729,11 @@ TEST_F(BucketReturnsGroupAndSort, BucketSucceedsWithMultipleBoundaryValues) {
class InvalidBucketSpec : public Mock::Base, public unittest::Test {
public:
vector<intrusive_ptr<DocumentSource>> createBucket(BSONObj bucketSpec) {
- return DocumentSourceBucket::createFromBson(bucketSpec.firstElement(), ctx());
+ auto sources = DocumentSourceBucket::createFromBson(bucketSpec.firstElement(), ctx());
+ for (auto&& source : sources) {
+ source->injectExpressionContext(ctx());
+ }
+ return sources;
}
};
diff --git a/src/mongo/db/pipeline/document_value_test.cpp b/src/mongo/db/pipeline/document_value_test.cpp
index 369f988eb6d..8c5f9a755c0 100644
--- a/src/mongo/db/pipeline/document_value_test.cpp
+++ b/src/mongo/db/pipeline/document_value_test.cpp
@@ -31,9 +31,11 @@
#include "mongo/platform/basic.h"
#include "mongo/db/pipeline/document.h"
+#include "mongo/db/pipeline/document_comparator.h"
#include "mongo/db/pipeline/document_value_test_util.h"
#include "mongo/db/pipeline/field_path.h"
#include "mongo/db/pipeline/value.h"
+#include "mongo/db/pipeline/value_comparator.h"
#include "mongo/dbtests/dbtests.h"
#include "mongo/util/print.h"
@@ -244,7 +246,7 @@ public:
public:
int cmp(const BSONObj& a, const BSONObj& b) {
- int result = Document::compare(fromBson(a), fromBson(b));
+ int result = DocumentComparator().compare(fromBson(a), fromBson(b));
return // sign
result < 0 ? -1 : result > 0 ? 1 : 0;
}
@@ -1584,7 +1586,7 @@ private:
return 1;
}
int cmp(const Value& a, const Value& b) {
- return sign(Value::compare(a, b));
+ return sign(ValueComparator().compare(a, b));
}
void assertComparison(int expectedResult, const BSONObj& a, const BSONObj& b) {
assertComparison(expectedResult, fromBson(a), fromBson(b));
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 40affa27bbe..61852cb6f2f 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -735,7 +735,7 @@ Value ExpressionCompare::evaluateInternal(Variables* vars) const {
Value pLeft(vpOperand[0]->evaluateInternal(vars));
Value pRight(vpOperand[1]->evaluateInternal(vars));
- int cmp = Value::compare(pLeft, pRight);
+ int cmp = getExpressionContext()->getValueComparator().compare(pLeft, pRight);
// Make cmp one of 1, 0, or -1.
if (cmp == 0) {
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index 4a849675300..a324cd43b5a 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -421,6 +421,7 @@ class ExpressionFromAccumulator
public:
Value evaluateInternal(Variables* vars) const final {
Accumulator accum;
+ accum.injectExpressionContext(this->getExpressionContext());
const size_t n = this->vpOperand.size();
// If a single array arg is given, loop through it passing each member to the accumulator.
// If a single, non-array arg is given, pass it directly to the accumulator.
diff --git a/src/mongo/db/pipeline/expression_test.cpp b/src/mongo/db/pipeline/expression_test.cpp
index 9828e719aab..144f2538723 100644
--- a/src/mongo/db/pipeline/expression_test.cpp
+++ b/src/mongo/db/pipeline/expression_test.cpp
@@ -1575,7 +1575,9 @@ public:
BSONElement specElement = specObject.firstElement();
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps);
+ expression->injectExpressionContext(expCtx);
intrusive_ptr<Expression> optimized = expression->optimize();
ASSERT_EQUALS(constify(expectedOptimized()), expressionToBson(optimized));
}
@@ -1606,7 +1608,9 @@ public:
BSONElement specElement = specObject.firstElement();
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps);
+ expression->injectExpressionContext(expCtx);
// Check expression spec round trip.
ASSERT_EQUALS(constify(spec()), expressionToBson(expression));
// Check evaluation result.
@@ -1849,7 +1853,9 @@ public:
BSONElement specElement = specObject.firstElement();
VariablesIdGenerator idGenerator;
VariablesParseState vps(&idGenerator);
+ intrusive_ptr<ExpressionContext> expCtx(new ExpressionContext());
intrusive_ptr<Expression> expression = Expression::parseOperand(specElement, vps);
+ expression->injectExpressionContext(expCtx);
ASSERT_VALUE_EQ(expression->evaluate(Document()), Value(true));
}
};
diff --git a/src/mongo/db/pipeline/value.h b/src/mongo/db/pipeline/value.h
index 6914ff10014..b4dfc6a4943 100644
--- a/src/mongo/db/pipeline/value.h
+++ b/src/mongo/db/pipeline/value.h
@@ -246,7 +246,7 @@ public:
*/
static int compare(const Value& lhs,
const Value& rhs,
- const StringData::ComparatorInterface* stringComparator = nullptr);
+ const StringData::ComparatorInterface* stringComparator);
friend DeferredComparison operator==(const Value& lhs, const Value& rhs) {
return DeferredComparison(DeferredComparison::Type::kEQ, lhs, rhs);