diff options
Diffstat (limited to 'src/mongo/util/safe_num.cpp')
-rw-r--r-- | src/mongo/util/safe_num.cpp | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/src/mongo/util/safe_num.cpp b/src/mongo/util/safe_num.cpp index 8976f27a251..f49fcd7b53a 100644 --- a/src/mongo/util/safe_num.cpp +++ b/src/mongo/util/safe_num.cpp @@ -51,6 +51,10 @@ SafeNum::SafeNum(const BSONElement& element) { _type = NumberDouble; _value.doubleVal = element.Double(); break; + case NumberDecimal: + _type = NumberDecimal; + _value.decimalVal = element.Decimal().getValue(); + break; default: _type = EOO; } @@ -68,6 +72,9 @@ std::string SafeNum::debugString() const { case NumberDouble: os << "(NumberDouble)" << _value.doubleVal; break; + case NumberDecimal: + os << "(NumberDecimal)" << getDecimal(*this).toString(); + break; case EOO: os << "(EOO)"; break; @@ -99,6 +106,14 @@ bool SafeNum::isEquivalent(const SafeNum& rhs) const { // If the types of either side are mixed, we'll try to find the shortest type we // can upconvert to that would not sacrifice the accuracy in the process. + // If one side is a decimal, compare both sides as decimals. + if (_type == NumberDecimal || rhs._type == NumberDecimal) { + // Note: isEqual is faster than using compareDecimals, however it does not handle + // comparing NaN as equal (differing from BSONElement::woCompare). This case + // is not handled for double comparison above eihter. + return getDecimal(*this).isEqual(getDecimal(rhs)); + } + // If none of the sides is a double, compare them as long's. if (_type != NumberDouble && rhs._type != NumberDouble) { return getLongLong(*this) == getLongLong(rhs); @@ -134,6 +149,8 @@ bool SafeNum::isIdentical(const SafeNum& rhs) const { return _value.int64Val == rhs._value.int64Val; case NumberDouble: return _value.doubleVal == rhs._value.doubleVal; + case NumberDecimal: + return Decimal128(_value.decimalVal).isEqual(rhs._value.decimalVal); case EOO: // EOO doesn't match anything, including itself. default: @@ -160,6 +177,23 @@ double SafeNum::getDouble(const SafeNum& snum) { return snum._value.int64Val; case NumberDouble: return snum._value.doubleVal; + case NumberDecimal: + return Decimal128(snum._value.decimalVal).toDouble(); + default: + return 0.0; + } +} + +Decimal128 SafeNum::getDecimal(const SafeNum& snum) { + switch (snum._type) { + case NumberInt: + return snum._value.int32Val; + case NumberLong: + return snum._value.int64Val; + case NumberDouble: + return snum._value.doubleVal; + case NumberDecimal: + return snum._value.decimalVal; default: return 0.0; } @@ -211,6 +245,10 @@ SafeNum addFloats(double lDouble, double rDouble) { return SafeNum(sum); } +SafeNum addDecimals(Decimal128 lDecimal, Decimal128 rDecimal) { + return SafeNum(lDecimal.add(rDecimal)); +} + SafeNum mulInt32Int32(int lInt32, int rInt32) { // NOTE: Please see "Secure Coding in C and C++", Second Edition, page 264-265 for // details on this algorithm (for an alternative resources, see @@ -274,6 +312,10 @@ SafeNum mulFloats(double lDouble, double rDouble) { return SafeNum(product); } +SafeNum mulDecimals(Decimal128 lDecimal, Decimal128 rDecimal) { + return SafeNum(lDecimal.multiply(rDecimal)); +} + } // namespace SafeNum SafeNum::addInternal(const SafeNum& lhs, const SafeNum& rhs) { @@ -296,6 +338,10 @@ SafeNum SafeNum::addInternal(const SafeNum& lhs, const SafeNum& rhs) { return addInt64Int64(lhs._value.int64Val, rhs._value.int64Val); } + if (lType == NumberDecimal || rType == NumberDecimal) { + return addDecimals(getDecimal(lhs), getDecimal(rhs)); + } + if ((lType == NumberInt || lType == NumberLong || lType == NumberDouble) && (rType == NumberInt || rType == NumberLong || rType == NumberDouble)) { return addFloats(getDouble(lhs), getDouble(rhs)); @@ -324,6 +370,10 @@ SafeNum SafeNum::mulInternal(const SafeNum& lhs, const SafeNum& rhs) { return mulInt64Int64(lhs._value.int64Val, rhs._value.int64Val); } + if (lType == NumberDecimal || rType == NumberDecimal) { + return mulDecimals(getDecimal(lhs), getDecimal(rhs)); + } + if ((lType == NumberInt || lType == NumberLong || lType == NumberDouble) && (rType == NumberInt || rType == NumberLong || rType == NumberDouble)) { return mulFloats(getDouble(lhs), getDouble(rhs)); |