summaryrefslogtreecommitdiff
path: root/src/mongo/util/safe_num.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/util/safe_num.cpp')
-rw-r--r--src/mongo/util/safe_num.cpp50
1 files changed, 50 insertions, 0 deletions
diff --git a/src/mongo/util/safe_num.cpp b/src/mongo/util/safe_num.cpp
index 8976f27a251..f49fcd7b53a 100644
--- a/src/mongo/util/safe_num.cpp
+++ b/src/mongo/util/safe_num.cpp
@@ -51,6 +51,10 @@ SafeNum::SafeNum(const BSONElement& element) {
_type = NumberDouble;
_value.doubleVal = element.Double();
break;
+ case NumberDecimal:
+ _type = NumberDecimal;
+ _value.decimalVal = element.Decimal().getValue();
+ break;
default:
_type = EOO;
}
@@ -68,6 +72,9 @@ std::string SafeNum::debugString() const {
case NumberDouble:
os << "(NumberDouble)" << _value.doubleVal;
break;
+ case NumberDecimal:
+ os << "(NumberDecimal)" << getDecimal(*this).toString();
+ break;
case EOO:
os << "(EOO)";
break;
@@ -99,6 +106,14 @@ bool SafeNum::isEquivalent(const SafeNum& rhs) const {
// If the types of either side are mixed, we'll try to find the shortest type we
// can upconvert to that would not sacrifice the accuracy in the process.
+ // If one side is a decimal, compare both sides as decimals.
+ if (_type == NumberDecimal || rhs._type == NumberDecimal) {
+ // Note: isEqual is faster than using compareDecimals, however it does not handle
+ // comparing NaN as equal (differing from BSONElement::woCompare). This case
+ // is not handled for double comparison above eihter.
+ return getDecimal(*this).isEqual(getDecimal(rhs));
+ }
+
// If none of the sides is a double, compare them as long's.
if (_type != NumberDouble && rhs._type != NumberDouble) {
return getLongLong(*this) == getLongLong(rhs);
@@ -134,6 +149,8 @@ bool SafeNum::isIdentical(const SafeNum& rhs) const {
return _value.int64Val == rhs._value.int64Val;
case NumberDouble:
return _value.doubleVal == rhs._value.doubleVal;
+ case NumberDecimal:
+ return Decimal128(_value.decimalVal).isEqual(rhs._value.decimalVal);
case EOO:
// EOO doesn't match anything, including itself.
default:
@@ -160,6 +177,23 @@ double SafeNum::getDouble(const SafeNum& snum) {
return snum._value.int64Val;
case NumberDouble:
return snum._value.doubleVal;
+ case NumberDecimal:
+ return Decimal128(snum._value.decimalVal).toDouble();
+ default:
+ return 0.0;
+ }
+}
+
+Decimal128 SafeNum::getDecimal(const SafeNum& snum) {
+ switch (snum._type) {
+ case NumberInt:
+ return snum._value.int32Val;
+ case NumberLong:
+ return snum._value.int64Val;
+ case NumberDouble:
+ return snum._value.doubleVal;
+ case NumberDecimal:
+ return snum._value.decimalVal;
default:
return 0.0;
}
@@ -211,6 +245,10 @@ SafeNum addFloats(double lDouble, double rDouble) {
return SafeNum(sum);
}
+SafeNum addDecimals(Decimal128 lDecimal, Decimal128 rDecimal) {
+ return SafeNum(lDecimal.add(rDecimal));
+}
+
SafeNum mulInt32Int32(int lInt32, int rInt32) {
// NOTE: Please see "Secure Coding in C and C++", Second Edition, page 264-265 for
// details on this algorithm (for an alternative resources, see
@@ -274,6 +312,10 @@ SafeNum mulFloats(double lDouble, double rDouble) {
return SafeNum(product);
}
+SafeNum mulDecimals(Decimal128 lDecimal, Decimal128 rDecimal) {
+ return SafeNum(lDecimal.multiply(rDecimal));
+}
+
} // namespace
SafeNum SafeNum::addInternal(const SafeNum& lhs, const SafeNum& rhs) {
@@ -296,6 +338,10 @@ SafeNum SafeNum::addInternal(const SafeNum& lhs, const SafeNum& rhs) {
return addInt64Int64(lhs._value.int64Val, rhs._value.int64Val);
}
+ if (lType == NumberDecimal || rType == NumberDecimal) {
+ return addDecimals(getDecimal(lhs), getDecimal(rhs));
+ }
+
if ((lType == NumberInt || lType == NumberLong || lType == NumberDouble) &&
(rType == NumberInt || rType == NumberLong || rType == NumberDouble)) {
return addFloats(getDouble(lhs), getDouble(rhs));
@@ -324,6 +370,10 @@ SafeNum SafeNum::mulInternal(const SafeNum& lhs, const SafeNum& rhs) {
return mulInt64Int64(lhs._value.int64Val, rhs._value.int64Val);
}
+ if (lType == NumberDecimal || rType == NumberDecimal) {
+ return mulDecimals(getDecimal(lhs), getDecimal(rhs));
+ }
+
if ((lType == NumberInt || lType == NumberLong || lType == NumberDouble) &&
(rType == NumberInt || rType == NumberLong || rType == NumberDouble)) {
return mulFloats(getDouble(lhs), getDouble(rhs));