summaryrefslogtreecommitdiff
path: root/src/mongo/db/pipeline/expression.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/pipeline/expression.cpp')
-rw-r--r--src/mongo/db/pipeline/expression.cpp55
1 files changed, 43 insertions, 12 deletions
diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp
index 83fe6400e20..7176e8a0367 100644
--- a/src/mongo/db/pipeline/expression.cpp
+++ b/src/mongo/db/pipeline/expression.cpp
@@ -429,21 +429,21 @@ private:
case NumberInt:
case NumberLong:
if (overflow::add(longTotal, valToAdd.coerceToLong(), &longTotal)) {
- uasserted(ErrorCodes::Overflow, "date overflow in $add");
+ uasserted(ErrorCodes::Overflow, "date overflow");
}
break;
case NumberDouble: {
using limits = std::numeric_limits<long long>;
double doubleToAdd = valToAdd.coerceToDouble();
uassert(ErrorCodes::Overflow,
- "date overflow in $add",
+ "date overflow",
// The upper bound is exclusive because it rounds up when it is cast to
// a double.
doubleToAdd >= static_cast<double>(limits::min()) &&
doubleToAdd < static_cast<double>(limits::max()));
if (overflow::add(longTotal, llround(doubleToAdd), &longTotal)) {
- uasserted(ErrorCodes::Overflow, "date overflow in $add");
+ uasserted(ErrorCodes::Overflow, "date overflow");
}
break;
}
@@ -454,7 +454,7 @@ private:
std::int64_t longToAdd = decimalToAdd.toLong(&signalingFlags);
if (signalingFlags != Decimal128::SignalingFlag::kNoFlag ||
overflow::add(longTotal, longToAdd, &longTotal)) {
- uasserted(ErrorCodes::Overflow, "date overflow in $add");
+ uasserted(ErrorCodes::Overflow, "date overflow");
}
break;
}
@@ -5691,14 +5691,45 @@ StatusWith<Value> ExpressionSubtract::apply(Value lhs, Value rhs) {
} else if (lhs.nullish() || rhs.nullish()) {
return Value(BSONNULL);
} else if (lhs.getType() == Date) {
- if (rhs.getType() == Date) {
- return Value(durationCount<Milliseconds>(lhs.getDate() - rhs.getDate()));
- } else if (rhs.numeric()) {
- return Value(lhs.getDate() - Milliseconds(rhs.coerceToLong()));
- } else {
- return Status(ErrorCodes::TypeMismatch,
- str::stream()
- << "can't $subtract " << typeName(rhs.getType()) << " from Date");
+ BSONType rhsType = rhs.getType();
+ switch (rhsType) {
+ case Date:
+ return Value(durationCount<Milliseconds>(lhs.getDate() - rhs.getDate()));
+ case NumberInt:
+ case NumberLong: {
+ long long longDiff = lhs.getDate().toMillisSinceEpoch();
+ if (overflow::sub(longDiff, rhs.coerceToLong(), &longDiff)) {
+ return Status(ErrorCodes::Overflow, str::stream() << "date overflow");
+ }
+ return Value(Date_t::fromMillisSinceEpoch(longDiff));
+ }
+ case NumberDouble: {
+ using limits = std::numeric_limits<long long>;
+ long long longDiff = lhs.getDate().toMillisSinceEpoch();
+ double doubleRhs = rhs.coerceToDouble();
+ // check the doubleRhs should not exceed int64 limit and result will not overflow
+ if (doubleRhs < static_cast<double>(limits::min()) ||
+ doubleRhs >= static_cast<double>(limits::max()) ||
+ overflow::sub(longDiff, llround(doubleRhs), &longDiff)) {
+ return Status(ErrorCodes::Overflow, str::stream() << "date overflow");
+ }
+ return Value(Date_t::fromMillisSinceEpoch(longDiff));
+ }
+ case NumberDecimal: {
+ long long longDiff = lhs.getDate().toMillisSinceEpoch();
+ Decimal128 decimalRhs = rhs.coerceToDecimal();
+ std::uint32_t signalingFlags = Decimal128::SignalingFlag::kNoFlag;
+ std::int64_t longRhs = decimalRhs.toLong(&signalingFlags);
+ if (signalingFlags != Decimal128::SignalingFlag::kNoFlag ||
+ overflow::sub(longDiff, longRhs, &longDiff)) {
+ return Status(ErrorCodes::Overflow, str::stream() << "date overflow");
+ }
+ return Value(Date_t::fromMillisSinceEpoch(longDiff));
+ }
+ default:
+ return Status(ErrorCodes::TypeMismatch,
+ str::stream()
+ << "can't $subtract " << typeName(rhs.getType()) << " from Date");
}
} else {
return Status(ErrorCodes::TypeMismatch,