From 1b4a551a6b8c85611e26857217ce1a1e1363e716 Mon Sep 17 00:00:00 2001 From: Rushan Chen Date: Wed, 17 May 2023 09:22:28 +0000 Subject: SERVER-75287: fix rounding and overflow detection in Classic for subtract when one operand is Date and another numeric --- jstests/aggregation/expressions/add.js | 47 ++++++++++++++++++++++++ jstests/aggregation/expressions/subtract.js | 44 +++++++++++++++++++++++ src/mongo/db/pipeline/expression.cpp | 55 ++++++++++++++++++++++------- 3 files changed, 134 insertions(+), 12 deletions(-) diff --git a/jstests/aggregation/expressions/add.js b/jstests/aggregation/expressions/add.js index cc074a9ae08..c411c3f213b 100644 --- a/jstests/aggregation/expressions/add.js +++ b/jstests/aggregation/expressions/add.js @@ -1,5 +1,7 @@ (function() { "use strict"; +load("jstests/aggregation/extras/utils.js"); // For assertErrorCode and assertErrMsgContains. +load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs. // In SERVER-63012, translation of $add expression into sbe now defaults the translation of $add // with no operands to a zero integer constant. @@ -43,4 +45,49 @@ let addResult = coll.aggregate([{$project: {add: {$add: queryArr}}}]).toArray(); let sumResult = coll.aggregate([{$project: {sum: {$sum: queryArr}}}]).toArray(); assert.neq(addResult[0]["add"], sumResult[0]["sum"]); assert.eq(addResult[0]["add"], arr.reduce((a, b) => a + b)); + +assert.eq(true, coll.drop()); +// Doubles are rounded to int64 when added to Date +assert.commandWorked(coll.insert({_id: 0, lhs: new Date(1683794065002), rhs: 0.5})); +assert.commandWorked(coll.insert({_id: 1, lhs: new Date(1683794065002), rhs: 1.4})); +assert.commandWorked(coll.insert({_id: 2, lhs: new Date(1683794065002), rhs: 1.5})); +assert.commandWorked(coll.insert({_id: 3, lhs: new Date(1683794065002), rhs: 1.7})); +// Decimals are rounded to int64, when tie rounded to even, when added to Date +assert.commandWorked( + coll.insert({_id: 4, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.4")})); +assert.commandWorked( + coll.insert({_id: 5, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.5")})); +assert.commandWorked( + coll.insert({_id: 6, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.7")})); +assert.commandWorked( + coll.insert({_id: 7, lhs: new Date(1683794065002), rhs: new NumberDecimal("2.5")})); + +let result1 = + coll.aggregate([{$project: {sum: {$add: ["$lhs", "$rhs"]}}}, {$sort: {_id: 1}}]).toArray(); +assert.eq(result1[0].sum, new Date(1683794065003)); +assert.eq(result1[1].sum, new Date(1683794065003)); +assert.eq(result1[2].sum, new Date(1683794065004)); +assert.eq(result1[3].sum, new Date(1683794065004)); +assert.eq(result1[4].sum, new Date(1683794065003)); +assert.eq(result1[5].sum, new Date(1683794065004)); +assert.eq(result1[6].sum, new Date(1683794065004)); +assert.eq(result1[7].sum, new Date(1683794065004)); + +coll.drop(); + +assert.commandWorked(coll.insert([{ + _id: 0, + veryBigPositiveLong: NumberLong("9223372036854775806"), + veryBigPositiveDouble: 9223372036854775806, + veryBigPositiveDecimal: NumberDecimal("9223372036854775806") +}])); + +let pipeline = [{$project: {res: {$add: [new Date(10), "$veryBigPositiveLong"]}}}]; +assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow"); + +pipeline = [{$project: {res: {$add: [new Date(10), "$veryBigPositiveDouble"]}}}]; +assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow"); + +pipeline = [{$project: {res: {$add: [new Date(10), "$veryBigPositiveDecimal"]}}}]; +assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow"); }()); diff --git a/jstests/aggregation/expressions/subtract.js b/jstests/aggregation/expressions/subtract.js index 113635c7b53..34dd2615ade 100644 --- a/jstests/aggregation/expressions/subtract.js +++ b/jstests/aggregation/expressions/subtract.js @@ -1,3 +1,6 @@ +load("jstests/aggregation/extras/utils.js"); // For assertErrorCode and assertErrMsgContains. +load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs. + // Tests for $subtract aggregation expression (function() { "use strict"; @@ -15,6 +18,20 @@ assert.commandWorked( assert.commandWorked(coll.insert({_id: 5, lhs: new Date(1912392670000), rhs: 70000})); assert.commandWorked( coll.insert({_id: 6, lhs: new Date(1912392670000), rhs: new Date(1912392600000)})); +// Doubles are rounded to int64 when subtracted from Date +assert.commandWorked(coll.insert({_id: 7, lhs: new Date(1683794065002), rhs: 0.5})); +assert.commandWorked(coll.insert({_id: 8, lhs: new Date(1683794065002), rhs: 1.4})); +assert.commandWorked(coll.insert({_id: 9, lhs: new Date(1683794065002), rhs: 1.5})); +assert.commandWorked(coll.insert({_id: 10, lhs: new Date(1683794065002), rhs: 1.7})); +// Decimals are rounded to int64, when tie rounded to even, when subtracted from Date +assert.commandWorked( + coll.insert({_id: 11, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.4")})); +assert.commandWorked( + coll.insert({_id: 12, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.5")})); +assert.commandWorked( + coll.insert({_id: 13, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.7")})); +assert.commandWorked( + coll.insert({_id: 14, lhs: new Date(1683794065002), rhs: new NumberDecimal("2.5")})); const result = coll.aggregate([{$project: {diff: {$subtract: ["$lhs", "$rhs"]}}}, {$sort: {_id: 1}}]) @@ -26,4 +43,31 @@ assert.eq(result[3].diff, 10.0); assert.eq(result[4].diff, NumberDecimal("9990.00005")); assert.eq(result[5].diff, new Date(1912392600000)); assert.eq(result[6].diff, 70000); +assert.eq(result[7].diff, new Date(1683794065001)); +assert.eq(result[8].diff, new Date(1683794065001)); +assert.eq(result[9].diff, new Date(1683794065000)); +assert.eq(result[10].diff, new Date(1683794065000)); +assert.eq(result[11].diff, new Date(1683794065001)); +assert.eq(result[12].diff, new Date(1683794065000)); +assert.eq(result[13].diff, new Date(1683794065000)); +assert.eq(result[14].diff, new Date(1683794065000)); + +// Following cases will report overflow error +coll.drop(); + +assert.commandWorked(coll.insert([{ + _id: 0, + veryBigNegativeLong: NumberLong("-9223372036854775808"), + veryBigNegativeDouble: -9223372036854775808, + veryBigNegativeDecimal: NumberDecimal("-9223372036854775808") +}])); + +let pipeline = [{$project: {res: {$subtract: [new Date(10), "$veryBigNegativeLong"]}}}]; +assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow"); + +pipeline = [{$project: {res: {$subtract: [new Date(10), "$veryBigNegativeDouble"]}}}]; +assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow"); + +pipeline = [{$project: {res: {$subtract: [new Date(10), "$veryBigNegativeDecimal"]}}}]; +assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow"); }()); diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 83fe6400e20..7176e8a0367 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -429,21 +429,21 @@ private: case NumberInt: case NumberLong: if (overflow::add(longTotal, valToAdd.coerceToLong(), &longTotal)) { - uasserted(ErrorCodes::Overflow, "date overflow in $add"); + uasserted(ErrorCodes::Overflow, "date overflow"); } break; case NumberDouble: { using limits = std::numeric_limits; double doubleToAdd = valToAdd.coerceToDouble(); uassert(ErrorCodes::Overflow, - "date overflow in $add", + "date overflow", // The upper bound is exclusive because it rounds up when it is cast to // a double. doubleToAdd >= static_cast(limits::min()) && doubleToAdd < static_cast(limits::max())); if (overflow::add(longTotal, llround(doubleToAdd), &longTotal)) { - uasserted(ErrorCodes::Overflow, "date overflow in $add"); + uasserted(ErrorCodes::Overflow, "date overflow"); } break; } @@ -454,7 +454,7 @@ private: std::int64_t longToAdd = decimalToAdd.toLong(&signalingFlags); if (signalingFlags != Decimal128::SignalingFlag::kNoFlag || overflow::add(longTotal, longToAdd, &longTotal)) { - uasserted(ErrorCodes::Overflow, "date overflow in $add"); + uasserted(ErrorCodes::Overflow, "date overflow"); } break; } @@ -5691,14 +5691,45 @@ StatusWith ExpressionSubtract::apply(Value lhs, Value rhs) { } else if (lhs.nullish() || rhs.nullish()) { return Value(BSONNULL); } else if (lhs.getType() == Date) { - if (rhs.getType() == Date) { - return Value(durationCount(lhs.getDate() - rhs.getDate())); - } else if (rhs.numeric()) { - return Value(lhs.getDate() - Milliseconds(rhs.coerceToLong())); - } else { - return Status(ErrorCodes::TypeMismatch, - str::stream() - << "can't $subtract " << typeName(rhs.getType()) << " from Date"); + BSONType rhsType = rhs.getType(); + switch (rhsType) { + case Date: + return Value(durationCount(lhs.getDate() - rhs.getDate())); + case NumberInt: + case NumberLong: { + long long longDiff = lhs.getDate().toMillisSinceEpoch(); + if (overflow::sub(longDiff, rhs.coerceToLong(), &longDiff)) { + return Status(ErrorCodes::Overflow, str::stream() << "date overflow"); + } + return Value(Date_t::fromMillisSinceEpoch(longDiff)); + } + case NumberDouble: { + using limits = std::numeric_limits; + long long longDiff = lhs.getDate().toMillisSinceEpoch(); + double doubleRhs = rhs.coerceToDouble(); + // check the doubleRhs should not exceed int64 limit and result will not overflow + if (doubleRhs < static_cast(limits::min()) || + doubleRhs >= static_cast(limits::max()) || + overflow::sub(longDiff, llround(doubleRhs), &longDiff)) { + return Status(ErrorCodes::Overflow, str::stream() << "date overflow"); + } + return Value(Date_t::fromMillisSinceEpoch(longDiff)); + } + case NumberDecimal: { + long long longDiff = lhs.getDate().toMillisSinceEpoch(); + Decimal128 decimalRhs = rhs.coerceToDecimal(); + std::uint32_t signalingFlags = Decimal128::SignalingFlag::kNoFlag; + std::int64_t longRhs = decimalRhs.toLong(&signalingFlags); + if (signalingFlags != Decimal128::SignalingFlag::kNoFlag || + overflow::sub(longDiff, longRhs, &longDiff)) { + return Status(ErrorCodes::Overflow, str::stream() << "date overflow"); + } + return Value(Date_t::fromMillisSinceEpoch(longDiff)); + } + default: + return Status(ErrorCodes::TypeMismatch, + str::stream() + << "can't $subtract " << typeName(rhs.getType()) << " from Date"); } } else { return Status(ErrorCodes::TypeMismatch, -- cgit v1.2.1