summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRushan Chen <rushan.chen@mongodb.com>2023-05-17 09:22:28 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2023-05-17 10:22:40 +0000
commit1b4a551a6b8c85611e26857217ce1a1e1363e716 (patch)
tree3a6159a18a979749843592e34e06343b96d35050
parentf332018330ee82ec6e240a81896d8432e37609da (diff)
downloadmongo-1b4a551a6b8c85611e26857217ce1a1e1363e716.tar.gz
SERVER-75287: fix rounding and overflow detection in Classic for subtract when one operand is Date and another numeric
-rw-r--r--jstests/aggregation/expressions/add.js47
-rw-r--r--jstests/aggregation/expressions/subtract.js44
-rw-r--r--src/mongo/db/pipeline/expression.cpp55
3 files changed, 134 insertions, 12 deletions
diff --git a/jstests/aggregation/expressions/add.js b/jstests/aggregation/expressions/add.js
index cc074a9ae08..c411c3f213b 100644
--- a/jstests/aggregation/expressions/add.js
+++ b/jstests/aggregation/expressions/add.js
@@ -1,5 +1,7 @@
(function() {
"use strict";
+load("jstests/aggregation/extras/utils.js"); // For assertErrorCode and assertErrMsgContains.
+load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs.
// In SERVER-63012, translation of $add expression into sbe now defaults the translation of $add
// with no operands to a zero integer constant.
@@ -43,4 +45,49 @@ let addResult = coll.aggregate([{$project: {add: {$add: queryArr}}}]).toArray();
let sumResult = coll.aggregate([{$project: {sum: {$sum: queryArr}}}]).toArray();
assert.neq(addResult[0]["add"], sumResult[0]["sum"]);
assert.eq(addResult[0]["add"], arr.reduce((a, b) => a + b));
+
+assert.eq(true, coll.drop());
+// Doubles are rounded to int64 when added to Date
+assert.commandWorked(coll.insert({_id: 0, lhs: new Date(1683794065002), rhs: 0.5}));
+assert.commandWorked(coll.insert({_id: 1, lhs: new Date(1683794065002), rhs: 1.4}));
+assert.commandWorked(coll.insert({_id: 2, lhs: new Date(1683794065002), rhs: 1.5}));
+assert.commandWorked(coll.insert({_id: 3, lhs: new Date(1683794065002), rhs: 1.7}));
+// Decimals are rounded to int64, when tie rounded to even, when added to Date
+assert.commandWorked(
+ coll.insert({_id: 4, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.4")}));
+assert.commandWorked(
+ coll.insert({_id: 5, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.5")}));
+assert.commandWorked(
+ coll.insert({_id: 6, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.7")}));
+assert.commandWorked(
+ coll.insert({_id: 7, lhs: new Date(1683794065002), rhs: new NumberDecimal("2.5")}));
+
+let result1 =
+ coll.aggregate([{$project: {sum: {$add: ["$lhs", "$rhs"]}}}, {$sort: {_id: 1}}]).toArray();
+assert.eq(result1[0].sum, new Date(1683794065003));
+assert.eq(result1[1].sum, new Date(1683794065003));
+assert.eq(result1[2].sum, new Date(1683794065004));
+assert.eq(result1[3].sum, new Date(1683794065004));
+assert.eq(result1[4].sum, new Date(1683794065003));
+assert.eq(result1[5].sum, new Date(1683794065004));
+assert.eq(result1[6].sum, new Date(1683794065004));
+assert.eq(result1[7].sum, new Date(1683794065004));
+
+coll.drop();
+
+assert.commandWorked(coll.insert([{
+ _id: 0,
+ veryBigPositiveLong: NumberLong("9223372036854775806"),
+ veryBigPositiveDouble: 9223372036854775806,
+ veryBigPositiveDecimal: NumberDecimal("9223372036854775806")
+}]));
+
+let pipeline = [{$project: {res: {$add: [new Date(10), "$veryBigPositiveLong"]}}}];
+assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow");
+
+pipeline = [{$project: {res: {$add: [new Date(10), "$veryBigPositiveDouble"]}}}];
+assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow");
+
+pipeline = [{$project: {res: {$add: [new Date(10), "$veryBigPositiveDecimal"]}}}];
+assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow");
}());
diff --git a/jstests/aggregation/expressions/subtract.js b/jstests/aggregation/expressions/subtract.js
index 113635c7b53..34dd2615ade 100644
--- a/jstests/aggregation/expressions/subtract.js
+++ b/jstests/aggregation/expressions/subtract.js
@@ -1,3 +1,6 @@
+load("jstests/aggregation/extras/utils.js"); // For assertErrorCode and assertErrMsgContains.
+load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs.
+
// Tests for $subtract aggregation expression
(function() {
"use strict";
@@ -15,6 +18,20 @@ assert.commandWorked(
assert.commandWorked(coll.insert({_id: 5, lhs: new Date(1912392670000), rhs: 70000}));
assert.commandWorked(
coll.insert({_id: 6, lhs: new Date(1912392670000), rhs: new Date(1912392600000)}));
+// Doubles are rounded to int64 when subtracted from Date
+assert.commandWorked(coll.insert({_id: 7, lhs: new Date(1683794065002), rhs: 0.5}));
+assert.commandWorked(coll.insert({_id: 8, lhs: new Date(1683794065002), rhs: 1.4}));
+assert.commandWorked(coll.insert({_id: 9, lhs: new Date(1683794065002), rhs: 1.5}));
+assert.commandWorked(coll.insert({_id: 10, lhs: new Date(1683794065002), rhs: 1.7}));
+// Decimals are rounded to int64, when tie rounded to even, when subtracted from Date
+assert.commandWorked(
+ coll.insert({_id: 11, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.4")}));
+assert.commandWorked(
+ coll.insert({_id: 12, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.5")}));
+assert.commandWorked(
+ coll.insert({_id: 13, lhs: new Date(1683794065002), rhs: new NumberDecimal("1.7")}));
+assert.commandWorked(
+ coll.insert({_id: 14, lhs: new Date(1683794065002), rhs: new NumberDecimal("2.5")}));
const result =
coll.aggregate([{$project: {diff: {$subtract: ["$lhs", "$rhs"]}}}, {$sort: {_id: 1}}])
@@ -26,4 +43,31 @@ assert.eq(result[3].diff, 10.0);
assert.eq(result[4].diff, NumberDecimal("9990.00005"));
assert.eq(result[5].diff, new Date(1912392600000));
assert.eq(result[6].diff, 70000);
+assert.eq(result[7].diff, new Date(1683794065001));
+assert.eq(result[8].diff, new Date(1683794065001));
+assert.eq(result[9].diff, new Date(1683794065000));
+assert.eq(result[10].diff, new Date(1683794065000));
+assert.eq(result[11].diff, new Date(1683794065001));
+assert.eq(result[12].diff, new Date(1683794065000));
+assert.eq(result[13].diff, new Date(1683794065000));
+assert.eq(result[14].diff, new Date(1683794065000));
+
+// Following cases will report overflow error
+coll.drop();
+
+assert.commandWorked(coll.insert([{
+ _id: 0,
+ veryBigNegativeLong: NumberLong("-9223372036854775808"),
+ veryBigNegativeDouble: -9223372036854775808,
+ veryBigNegativeDecimal: NumberDecimal("-9223372036854775808")
+}]));
+
+let pipeline = [{$project: {res: {$subtract: [new Date(10), "$veryBigNegativeLong"]}}}];
+assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow");
+
+pipeline = [{$project: {res: {$subtract: [new Date(10), "$veryBigNegativeDouble"]}}}];
+assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow");
+
+pipeline = [{$project: {res: {$subtract: [new Date(10), "$veryBigNegativeDecimal"]}}}];
+assertErrCodeAndErrMsgContains(coll, pipeline, ErrorCodes.Overflow, "date overflow");
}());
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 83fe6400e20..7176e8a0367 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -429,21 +429,21 @@ private:
case NumberInt:
case NumberLong:
if (overflow::add(longTotal, valToAdd.coerceToLong(), &longTotal)) {
- uasserted(ErrorCodes::Overflow, "date overflow in $add");
+ uasserted(ErrorCodes::Overflow, "date overflow");
}
break;
case NumberDouble: {
using limits = std::numeric_limits<long long>;
double doubleToAdd = valToAdd.coerceToDouble();
uassert(ErrorCodes::Overflow,
- "date overflow in $add",
+ "date overflow",
// The upper bound is exclusive because it rounds up when it is cast to
// a double.
doubleToAdd >= static_cast<double>(limits::min()) &&
doubleToAdd < static_cast<double>(limits::max()));
if (overflow::add(longTotal, llround(doubleToAdd), &longTotal)) {
- uasserted(ErrorCodes::Overflow, "date overflow in $add");
+ uasserted(ErrorCodes::Overflow, "date overflow");
}
break;
}
@@ -454,7 +454,7 @@ private:
std::int64_t longToAdd = decimalToAdd.toLong(&signalingFlags);
if (signalingFlags != Decimal128::SignalingFlag::kNoFlag ||
overflow::add(longTotal, longToAdd, &longTotal)) {
- uasserted(ErrorCodes::Overflow, "date overflow in $add");
+ uasserted(ErrorCodes::Overflow, "date overflow");
}
break;
}
@@ -5691,14 +5691,45 @@ StatusWith<Value> ExpressionSubtract::apply(Value lhs, Value rhs) {
} else if (lhs.nullish() || rhs.nullish()) {
return Value(BSONNULL);
} else if (lhs.getType() == Date) {
- if (rhs.getType() == Date) {
- return Value(durationCount<Milliseconds>(lhs.getDate() - rhs.getDate()));
- } else if (rhs.numeric()) {
- return Value(lhs.getDate() - Milliseconds(rhs.coerceToLong()));
- } else {
- return Status(ErrorCodes::TypeMismatch,
- str::stream()
- << "can't $subtract " << typeName(rhs.getType()) << " from Date");
+ BSONType rhsType = rhs.getType();
+ switch (rhsType) {
+ case Date:
+ return Value(durationCount<Milliseconds>(lhs.getDate() - rhs.getDate()));
+ case NumberInt:
+ case NumberLong: {
+ long long longDiff = lhs.getDate().toMillisSinceEpoch();
+ if (overflow::sub(longDiff, rhs.coerceToLong(), &longDiff)) {
+ return Status(ErrorCodes::Overflow, str::stream() << "date overflow");
+ }
+ return Value(Date_t::fromMillisSinceEpoch(longDiff));
+ }
+ case NumberDouble: {
+ using limits = std::numeric_limits<long long>;
+ long long longDiff = lhs.getDate().toMillisSinceEpoch();
+ double doubleRhs = rhs.coerceToDouble();
+ // check the doubleRhs should not exceed int64 limit and result will not overflow
+ if (doubleRhs < static_cast<double>(limits::min()) ||
+ doubleRhs >= static_cast<double>(limits::max()) ||
+ overflow::sub(longDiff, llround(doubleRhs), &longDiff)) {
+ return Status(ErrorCodes::Overflow, str::stream() << "date overflow");
+ }
+ return Value(Date_t::fromMillisSinceEpoch(longDiff));
+ }
+ case NumberDecimal: {
+ long long longDiff = lhs.getDate().toMillisSinceEpoch();
+ Decimal128 decimalRhs = rhs.coerceToDecimal();
+ std::uint32_t signalingFlags = Decimal128::SignalingFlag::kNoFlag;
+ std::int64_t longRhs = decimalRhs.toLong(&signalingFlags);
+ if (signalingFlags != Decimal128::SignalingFlag::kNoFlag ||
+ overflow::sub(longDiff, longRhs, &longDiff)) {
+ return Status(ErrorCodes::Overflow, str::stream() << "date overflow");
+ }
+ return Value(Date_t::fromMillisSinceEpoch(longDiff));
+ }
+ default:
+ return Status(ErrorCodes::TypeMismatch,
+ str::stream()
+ << "can't $subtract " << typeName(rhs.getType()) << " from Date");
}
} else {
return Status(ErrorCodes::TypeMismatch,