summaryrefslogtreecommitdiff
path: root/src/mongo/db/matcher
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/matcher')
-rw-r--r--src/mongo/db/matcher/expression_leaf.cpp4
-rw-r--r--src/mongo/db/matcher/expression_leaf.h17
-rw-r--r--src/mongo/db/matcher/expression_parser.cpp3
-rw-r--r--src/mongo/db/matcher/expression_parser_leaf_test.cpp59
4 files changed, 75 insertions, 8 deletions
diff --git a/src/mongo/db/matcher/expression_leaf.cpp b/src/mongo/db/matcher/expression_leaf.cpp
index fd9d082c518..eb1e77d88d9 100644
--- a/src/mongo/db/matcher/expression_leaf.cpp
+++ b/src/mongo/db/matcher/expression_leaf.cpp
@@ -289,7 +289,7 @@ void RegexMatchExpression::shortDebugString(StringBuilder& debug) const {
// ---------
-Status ModMatchExpression::init(StringData path, int divisor, int remainder) {
+Status ModMatchExpression::init(StringData path, long long divisor, long long remainder) {
if (divisor == 0)
return Status(ErrorCodes::BadValue, "divisor cannot be 0");
_divisor = divisor;
@@ -300,7 +300,7 @@ Status ModMatchExpression::init(StringData path, int divisor, int remainder) {
bool ModMatchExpression::matchesSingleElement(const BSONElement& e, MatchDetails* details) const {
if (!e.isNumber())
return false;
- return mongoSafeMod(e.numberLong(), static_cast<long long>(_divisor)) == _remainder;
+ return mongoSafeMod(truncateToLong(e), _divisor) == _remainder;
}
void ModMatchExpression::debugString(StringBuilder& debug, int level) const {
diff --git a/src/mongo/db/matcher/expression_leaf.h b/src/mongo/db/matcher/expression_leaf.h
index f7e1ea078dc..90e8463d196 100644
--- a/src/mongo/db/matcher/expression_leaf.h
+++ b/src/mongo/db/matcher/expression_leaf.h
@@ -335,7 +335,7 @@ class ModMatchExpression : public LeafMatchExpression {
public:
ModMatchExpression() : LeafMatchExpression(MOD) {}
- Status init(StringData path, int divisor, int remainder);
+ Status init(StringData path, long long divisor, long long remainder);
virtual std::unique_ptr<MatchExpression> shallowClone() const {
std::unique_ptr<ModMatchExpression> m = stdx::make_unique<ModMatchExpression>();
@@ -354,20 +354,27 @@ public:
virtual bool equivalent(const MatchExpression* other) const;
- int getDivisor() const {
+ long long getDivisor() const {
return _divisor;
}
- int getRemainder() const {
+ long long getRemainder() const {
return _remainder;
}
+ static long long truncateToLong(const BSONElement& element) {
+ if (element.type() == BSONType::NumberDecimal) {
+ return element.numberDecimal().toLong(Decimal128::kRoundTowardZero);
+ }
+ return element.numberLong();
+ }
+
private:
ExpressionOptimizerFunc getOptimizer() const final {
return [](std::unique_ptr<MatchExpression> expression) { return expression; };
}
- int _divisor;
- int _remainder;
+ long long _divisor;
+ long long _remainder;
};
class ExistsMatchExpression : public LeafMatchExpression {
diff --git a/src/mongo/db/matcher/expression_parser.cpp b/src/mongo/db/matcher/expression_parser.cpp
index e1068520e2e..68b1702c7c5 100644
--- a/src/mongo/db/matcher/expression_parser.cpp
+++ b/src/mongo/db/matcher/expression_parser.cpp
@@ -565,7 +565,8 @@ StatusWithMatchExpression parseMOD(StringData name, BSONElement e) {
return {Status(ErrorCodes::BadValue, "malformed mod, too many elements")};
auto temp = stdx::make_unique<ModMatchExpression>();
- auto s = temp->init(name, d.numberInt(), r.numberInt());
+ auto s = temp->init(
+ name, ModMatchExpression::truncateToLong(d), ModMatchExpression::truncateToLong(r));
if (!s.isOK())
return s;
return {std::move(temp)};
diff --git a/src/mongo/db/matcher/expression_parser_leaf_test.cpp b/src/mongo/db/matcher/expression_parser_leaf_test.cpp
index a1a79f8f450..18a25a410ae 100644
--- a/src/mongo/db/matcher/expression_parser_leaf_test.cpp
+++ b/src/mongo/db/matcher/expression_parser_leaf_test.cpp
@@ -349,6 +349,65 @@ TEST(MatchExpressionParserLeafTest, SimpleModNotNumber) {
<< "a")));
}
+TEST(MatchExpressionParserLeafTest, ModFloatTruncate) {
+ struct TestCase {
+ BSONObj _query;
+ long long _divider;
+ long long _remainder;
+ };
+
+ const auto positiveLargerThanInt = 3 * static_cast<int64_t>(std::numeric_limits<int>::max());
+ const auto negativeSmallerThanInt = 3 * static_cast<int64_t>(std::numeric_limits<int>::min());
+ std::vector<TestCase> testCases = {
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(3 << 2))), 3, 2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(3LL << 2LL))), 3, 2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(3.2 << 2.2))), 3, 2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(3.7 << 2.7))), 3, 2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("3") << Decimal128("2")))), 3, 2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("3.2") << Decimal128("2.2")))), 3, 2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("3.7") << Decimal128("2.7")))), 3, 2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(positiveLargerThanInt << positiveLargerThanInt))),
+ positiveLargerThanInt,
+ positiveLargerThanInt},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(static_cast<double>(positiveLargerThanInt)
+ << static_cast<double>(positiveLargerThanInt)))),
+ positiveLargerThanInt,
+ positiveLargerThanInt},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128(positiveLargerThanInt)
+ << Decimal128(positiveLargerThanInt)))),
+ positiveLargerThanInt,
+ positiveLargerThanInt},
+
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(-3 << -2))), -3, -2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(-3LL << -2LL))), -3, -2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(-3.2 << -2.2))), -3, -2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(-3.7 << -2.7))), -3, -2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("-3") << Decimal128("-2")))), -3, -2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("-3.2") << Decimal128("-2.2")))), -3, -2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("-3.7") << Decimal128("-2.7")))), -3, -2},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(negativeSmallerThanInt << negativeSmallerThanInt))),
+ negativeSmallerThanInt,
+ negativeSmallerThanInt},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(static_cast<double>(negativeSmallerThanInt)
+ << static_cast<double>(negativeSmallerThanInt)))),
+ negativeSmallerThanInt,
+ negativeSmallerThanInt},
+ {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128(negativeSmallerThanInt)
+ << Decimal128(negativeSmallerThanInt)))),
+ negativeSmallerThanInt,
+ negativeSmallerThanInt},
+ };
+
+ for (const auto& testCase : testCases) {
+ boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
+ StatusWithMatchExpression result = MatchExpressionParser::parse(testCase._query, expCtx);
+ ASSERT_OK(result.getStatus());
+ auto modExpr = static_cast<ModMatchExpression*>(result.getValue().get());
+ ASSERT_EQ(modExpr->getDivisor(), testCase._divider);
+ ASSERT_EQ(modExpr->getRemainder(), testCase._remainder);
+ }
+}
+
TEST(MatchExpressionParserLeafTest, IdCollation) {
BSONObj query = BSON("$id"
<< "string");