diff options
Diffstat (limited to 'src/mongo/db/matcher')
-rw-r--r-- | src/mongo/db/matcher/expression_leaf.cpp | 4 | ||||
-rw-r--r-- | src/mongo/db/matcher/expression_leaf.h | 17 | ||||
-rw-r--r-- | src/mongo/db/matcher/expression_parser.cpp | 3 | ||||
-rw-r--r-- | src/mongo/db/matcher/expression_parser_leaf_test.cpp | 59 |
4 files changed, 75 insertions, 8 deletions
diff --git a/src/mongo/db/matcher/expression_leaf.cpp b/src/mongo/db/matcher/expression_leaf.cpp index fd9d082c518..eb1e77d88d9 100644 --- a/src/mongo/db/matcher/expression_leaf.cpp +++ b/src/mongo/db/matcher/expression_leaf.cpp @@ -289,7 +289,7 @@ void RegexMatchExpression::shortDebugString(StringBuilder& debug) const { // --------- -Status ModMatchExpression::init(StringData path, int divisor, int remainder) { +Status ModMatchExpression::init(StringData path, long long divisor, long long remainder) { if (divisor == 0) return Status(ErrorCodes::BadValue, "divisor cannot be 0"); _divisor = divisor; @@ -300,7 +300,7 @@ Status ModMatchExpression::init(StringData path, int divisor, int remainder) { bool ModMatchExpression::matchesSingleElement(const BSONElement& e, MatchDetails* details) const { if (!e.isNumber()) return false; - return mongoSafeMod(e.numberLong(), static_cast<long long>(_divisor)) == _remainder; + return mongoSafeMod(truncateToLong(e), _divisor) == _remainder; } void ModMatchExpression::debugString(StringBuilder& debug, int level) const { diff --git a/src/mongo/db/matcher/expression_leaf.h b/src/mongo/db/matcher/expression_leaf.h index f7e1ea078dc..90e8463d196 100644 --- a/src/mongo/db/matcher/expression_leaf.h +++ b/src/mongo/db/matcher/expression_leaf.h @@ -335,7 +335,7 @@ class ModMatchExpression : public LeafMatchExpression { public: ModMatchExpression() : LeafMatchExpression(MOD) {} - Status init(StringData path, int divisor, int remainder); + Status init(StringData path, long long divisor, long long remainder); virtual std::unique_ptr<MatchExpression> shallowClone() const { std::unique_ptr<ModMatchExpression> m = stdx::make_unique<ModMatchExpression>(); @@ -354,20 +354,27 @@ public: virtual bool equivalent(const MatchExpression* other) const; - int getDivisor() const { + long long getDivisor() const { return _divisor; } - int getRemainder() const { + long long getRemainder() const { return _remainder; } + static long long truncateToLong(const BSONElement& element) { + if (element.type() == BSONType::NumberDecimal) { + return element.numberDecimal().toLong(Decimal128::kRoundTowardZero); + } + return element.numberLong(); + } + private: ExpressionOptimizerFunc getOptimizer() const final { return [](std::unique_ptr<MatchExpression> expression) { return expression; }; } - int _divisor; - int _remainder; + long long _divisor; + long long _remainder; }; class ExistsMatchExpression : public LeafMatchExpression { diff --git a/src/mongo/db/matcher/expression_parser.cpp b/src/mongo/db/matcher/expression_parser.cpp index e1068520e2e..68b1702c7c5 100644 --- a/src/mongo/db/matcher/expression_parser.cpp +++ b/src/mongo/db/matcher/expression_parser.cpp @@ -565,7 +565,8 @@ StatusWithMatchExpression parseMOD(StringData name, BSONElement e) { return {Status(ErrorCodes::BadValue, "malformed mod, too many elements")}; auto temp = stdx::make_unique<ModMatchExpression>(); - auto s = temp->init(name, d.numberInt(), r.numberInt()); + auto s = temp->init( + name, ModMatchExpression::truncateToLong(d), ModMatchExpression::truncateToLong(r)); if (!s.isOK()) return s; return {std::move(temp)}; diff --git a/src/mongo/db/matcher/expression_parser_leaf_test.cpp b/src/mongo/db/matcher/expression_parser_leaf_test.cpp index a1a79f8f450..18a25a410ae 100644 --- a/src/mongo/db/matcher/expression_parser_leaf_test.cpp +++ b/src/mongo/db/matcher/expression_parser_leaf_test.cpp @@ -349,6 +349,65 @@ TEST(MatchExpressionParserLeafTest, SimpleModNotNumber) { << "a"))); } +TEST(MatchExpressionParserLeafTest, ModFloatTruncate) { + struct TestCase { + BSONObj _query; + long long _divider; + long long _remainder; + }; + + const auto positiveLargerThanInt = 3 * static_cast<int64_t>(std::numeric_limits<int>::max()); + const auto negativeSmallerThanInt = 3 * static_cast<int64_t>(std::numeric_limits<int>::min()); + std::vector<TestCase> testCases = { + {BSON("x" << BSON("$mod" << BSON_ARRAY(3 << 2))), 3, 2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(3LL << 2LL))), 3, 2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(3.2 << 2.2))), 3, 2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(3.7 << 2.7))), 3, 2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("3") << Decimal128("2")))), 3, 2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("3.2") << Decimal128("2.2")))), 3, 2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("3.7") << Decimal128("2.7")))), 3, 2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(positiveLargerThanInt << positiveLargerThanInt))), + positiveLargerThanInt, + positiveLargerThanInt}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(static_cast<double>(positiveLargerThanInt) + << static_cast<double>(positiveLargerThanInt)))), + positiveLargerThanInt, + positiveLargerThanInt}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128(positiveLargerThanInt) + << Decimal128(positiveLargerThanInt)))), + positiveLargerThanInt, + positiveLargerThanInt}, + + {BSON("x" << BSON("$mod" << BSON_ARRAY(-3 << -2))), -3, -2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(-3LL << -2LL))), -3, -2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(-3.2 << -2.2))), -3, -2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(-3.7 << -2.7))), -3, -2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("-3") << Decimal128("-2")))), -3, -2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("-3.2") << Decimal128("-2.2")))), -3, -2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128("-3.7") << Decimal128("-2.7")))), -3, -2}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(negativeSmallerThanInt << negativeSmallerThanInt))), + negativeSmallerThanInt, + negativeSmallerThanInt}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(static_cast<double>(negativeSmallerThanInt) + << static_cast<double>(negativeSmallerThanInt)))), + negativeSmallerThanInt, + negativeSmallerThanInt}, + {BSON("x" << BSON("$mod" << BSON_ARRAY(Decimal128(negativeSmallerThanInt) + << Decimal128(negativeSmallerThanInt)))), + negativeSmallerThanInt, + negativeSmallerThanInt}, + }; + + for (const auto& testCase : testCases) { + boost::intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest()); + StatusWithMatchExpression result = MatchExpressionParser::parse(testCase._query, expCtx); + ASSERT_OK(result.getStatus()); + auto modExpr = static_cast<ModMatchExpression*>(result.getValue().get()); + ASSERT_EQ(modExpr->getDivisor(), testCase._divider); + ASSERT_EQ(modExpr->getRemainder(), testCase._remainder); + } +} + TEST(MatchExpressionParserLeafTest, IdCollation) { BSONObj query = BSON("$id" << "string"); |