diff options
author | Rushan Chen <rushan.chen@mongodb.com> | 2023-05-17 09:22:28 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2023-05-17 10:22:40 +0000 |
commit | 1b4a551a6b8c85611e26857217ce1a1e1363e716 (patch) | |
tree | 3a6159a18a979749843592e34e06343b96d35050 /src/mongo | |
parent | f332018330ee82ec6e240a81896d8432e37609da (diff) | |
download | mongo-1b4a551a6b8c85611e26857217ce1a1e1363e716.tar.gz |
SERVER-75287: fix rounding and overflow detection in Classic for subtract when one operand is Date and another numeric
Diffstat (limited to 'src/mongo')
-rw-r--r-- | src/mongo/db/pipeline/expression.cpp | 55 |
1 files changed, 43 insertions, 12 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index 83fe6400e20..7176e8a0367 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -429,21 +429,21 @@ private: case NumberInt: case NumberLong: if (overflow::add(longTotal, valToAdd.coerceToLong(), &longTotal)) { - uasserted(ErrorCodes::Overflow, "date overflow in $add"); + uasserted(ErrorCodes::Overflow, "date overflow"); } break; case NumberDouble: { using limits = std::numeric_limits<long long>; double doubleToAdd = valToAdd.coerceToDouble(); uassert(ErrorCodes::Overflow, - "date overflow in $add", + "date overflow", // The upper bound is exclusive because it rounds up when it is cast to // a double. doubleToAdd >= static_cast<double>(limits::min()) && doubleToAdd < static_cast<double>(limits::max())); if (overflow::add(longTotal, llround(doubleToAdd), &longTotal)) { - uasserted(ErrorCodes::Overflow, "date overflow in $add"); + uasserted(ErrorCodes::Overflow, "date overflow"); } break; } @@ -454,7 +454,7 @@ private: std::int64_t longToAdd = decimalToAdd.toLong(&signalingFlags); if (signalingFlags != Decimal128::SignalingFlag::kNoFlag || overflow::add(longTotal, longToAdd, &longTotal)) { - uasserted(ErrorCodes::Overflow, "date overflow in $add"); + uasserted(ErrorCodes::Overflow, "date overflow"); } break; } @@ -5691,14 +5691,45 @@ StatusWith<Value> ExpressionSubtract::apply(Value lhs, Value rhs) { } else if (lhs.nullish() || rhs.nullish()) { return Value(BSONNULL); } else if (lhs.getType() == Date) { - if (rhs.getType() == Date) { - return Value(durationCount<Milliseconds>(lhs.getDate() - rhs.getDate())); - } else if (rhs.numeric()) { - return Value(lhs.getDate() - Milliseconds(rhs.coerceToLong())); - } else { - return Status(ErrorCodes::TypeMismatch, - str::stream() - << "can't $subtract " << typeName(rhs.getType()) << " from Date"); + BSONType rhsType = rhs.getType(); + switch (rhsType) { + case Date: + return Value(durationCount<Milliseconds>(lhs.getDate() - rhs.getDate())); + case NumberInt: + case NumberLong: { + long long longDiff = lhs.getDate().toMillisSinceEpoch(); + if (overflow::sub(longDiff, rhs.coerceToLong(), &longDiff)) { + return Status(ErrorCodes::Overflow, str::stream() << "date overflow"); + } + return Value(Date_t::fromMillisSinceEpoch(longDiff)); + } + case NumberDouble: { + using limits = std::numeric_limits<long long>; + long long longDiff = lhs.getDate().toMillisSinceEpoch(); + double doubleRhs = rhs.coerceToDouble(); + // check the doubleRhs should not exceed int64 limit and result will not overflow + if (doubleRhs < static_cast<double>(limits::min()) || + doubleRhs >= static_cast<double>(limits::max()) || + overflow::sub(longDiff, llround(doubleRhs), &longDiff)) { + return Status(ErrorCodes::Overflow, str::stream() << "date overflow"); + } + return Value(Date_t::fromMillisSinceEpoch(longDiff)); + } + case NumberDecimal: { + long long longDiff = lhs.getDate().toMillisSinceEpoch(); + Decimal128 decimalRhs = rhs.coerceToDecimal(); + std::uint32_t signalingFlags = Decimal128::SignalingFlag::kNoFlag; + std::int64_t longRhs = decimalRhs.toLong(&signalingFlags); + if (signalingFlags != Decimal128::SignalingFlag::kNoFlag || + overflow::sub(longDiff, longRhs, &longDiff)) { + return Status(ErrorCodes::Overflow, str::stream() << "date overflow"); + } + return Value(Date_t::fromMillisSinceEpoch(longDiff)); + } + default: + return Status(ErrorCodes::TypeMismatch, + str::stream() + << "can't $subtract " << typeName(rhs.getType()) << " from Date"); } } else { return Status(ErrorCodes::TypeMismatch, |