diff options
author | Daniel Segel <daniel_segel@brown.edu> | 2022-06-21 19:44:19 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2022-07-25 19:31:06 +0000 |
commit | a2af071c5a28f2d92179ac20ac2d163043db19b9 (patch) | |
tree | 555c79a55d4237b941cc5aaffcda131fa997e65f | |
parent | f442499603a44ed4b3ba7ff7bd57e6fcc2470bdb (diff) | |
download | mongo-a2af071c5a28f2d92179ac20ac2d163043db19b9.tar.gz |
SERVER-67318 Add addition, subtraction, multiplication to const_eval with overflow
-rw-r--r-- | src/mongo/db/exec/sbe/SConscript | 2 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/vm/vm.h | 25 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/optimizer_test.cpp | 99 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/rewrites/const_eval.cpp | 155 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/syntax/expr.cpp | 20 | ||||
-rw-r--r-- | src/mongo/db/query/optimizer/syntax/expr.h | 7 |
6 files changed, 230 insertions, 78 deletions
diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript index 9805906a59c..532c89d7f21 100644 --- a/src/mongo/db/exec/sbe/SConscript +++ b/src/mongo/db/exec/sbe/SConscript @@ -18,6 +18,7 @@ env.Library( 'values/bson.cpp', 'values/value.cpp', 'values/value_printer.cpp', + 'vm/arith.cpp', ], LIBDEPS=[ '$BUILD_DIR/mongo/base', @@ -47,7 +48,6 @@ sbeEnv.Library( 'values/sbe_pattern_value_cmp.cpp', 'values/slot.cpp', 'values/slot_printer.cpp', - 'vm/arith.cpp', 'vm/datetime.cpp', 'vm/vm.cpp', ], diff --git a/src/mongo/db/exec/sbe/vm/vm.h b/src/mongo/db/exec/sbe/vm/vm.h index 2fec8265bfd..0f02a4d1121 100644 --- a/src/mongo/db/exec/sbe/vm/vm.h +++ b/src/mongo/db/exec/sbe/vm/vm.h @@ -828,6 +828,19 @@ public: std::tuple<uint8_t, value::TypeTags, value::Value> run(const CodeFragment* code); bool runPredicate(const CodeFragment* code); + static std::tuple<bool, value::TypeTags, value::Value> genericAdd(value::TypeTags lhsTag, + value::Value lhsValue, + value::TypeTags rhsTag, + value::Value rhsValue); + static std::tuple<bool, value::TypeTags, value::Value> genericSub(value::TypeTags lhsTag, + value::Value lhsValue, + value::TypeTags rhsTag, + value::Value rhsValue); + static std::tuple<bool, value::TypeTags, value::Value> genericMul(value::TypeTags lhsTag, + value::Value lhsValue, + value::TypeTags rhsTag, + value::Value rhsValue); + private: // The VM stack is used to pass inputs to instructions and hold the outputs produced by // instructions. Each element of the VM stack is 3-tuple comprised of a boolean ('owned'), @@ -915,18 +928,6 @@ private: std::tuple<bool, value::TypeTags, value::Value> runLambdaInternal(const CodeFragment* code, int64_t position); - std::tuple<bool, value::TypeTags, value::Value> genericAdd(value::TypeTags lhsTag, - value::Value lhsValue, - value::TypeTags rhsTag, - value::Value rhsValue); - std::tuple<bool, value::TypeTags, value::Value> genericSub(value::TypeTags lhsTag, - value::Value lhsValue, - value::TypeTags rhsTag, - value::Value rhsValue); - std::tuple<bool, value::TypeTags, value::Value> genericMul(value::TypeTags lhsTag, - value::Value lhsValue, - value::TypeTags rhsTag, - value::Value rhsValue); std::tuple<bool, value::TypeTags, value::Value> genericDiv(value::TypeTags lhsTag, value::Value lhsValue, value::TypeTags rhsTag, diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp index 52711f1f137..5e1534fe936 100644 --- a/src/mongo/db/query/optimizer/optimizer_test.cpp +++ b/src/mongo/db/query/optimizer/optimizer_test.cpp @@ -31,6 +31,8 @@ #include "mongo/db/query/optimizer/node.h" #include "mongo/db/query/optimizer/reference_tracker.h" #include "mongo/db/query/optimizer/rewrites/const_eval.h" +#include "mongo/db/query/optimizer/syntax/syntax.h" +#include "mongo/db/query/optimizer/syntax/syntax_fwd_declare.h" #include "mongo/db/query/optimizer/utils/unit_test_utils.h" #include "mongo/db/query/optimizer/utils/utils.h" #include "mongo/unittest/unittest.h" @@ -38,44 +40,105 @@ namespace mongo::optimizer { namespace { -TEST(Optimizer, ConstEval) { - // 1 + 2 - auto tree = make<BinaryOp>(Operations::Add, Constant::int64(1), Constant::int64(2)); - // Run the evaluator. +Constant* constEval(ABT& tree) { auto env = VariableEnvironment::build(tree); ConstEval evaluator{env}; evaluator.optimize(tree); // The result must be Constant. - auto result = tree.cast<Constant>(); + Constant* result = tree.cast<Constant>(); ASSERT(result != nullptr); - // And the value must be 3 (i.e. 1+2). - ASSERT_EQ(result->getValueInt64(), 3); - ASSERT_NE(ABT::tagOf<Constant>(), ABT::tagOf<BinaryOp>()); ASSERT_EQ(tree.tagOf(), ABT::tagOf<Constant>()); + return result; +} + +TEST(Optimizer, ConstEval) { + // 1 + 2 + ABT tree = make<BinaryOp>(Operations::Add, Constant::int64(1), Constant::int64(2)); + Constant* result = constEval(tree); + ASSERT_EQ(result->getValueInt64(), 3); } + TEST(Optimizer, ConstEvalCompose) { // (1 + 2) + 3 - auto tree = + ABT tree = make<BinaryOp>(Operations::Add, make<BinaryOp>(Operations::Add, Constant::int64(1), Constant::int64(2)), Constant::int64(3)); + Constant* result = constEval(tree); + ASSERT_EQ(result->getValueInt64(), 6); +} - // Run the evaluator. - auto env = VariableEnvironment::build(tree); - ConstEval evaluator{env}; - evaluator.optimize(tree); - // The result must be Constant. - auto result = tree.cast<Constant>(); - ASSERT(result != nullptr); +TEST(Optimizer, ConstEvalCompose2) { + // 3 - (5 - 4) + auto tree = + make<BinaryOp>(Operations::Sub, + Constant::int64(3), + make<BinaryOp>(Operations::Sub, Constant::int64(5), Constant::int64(4))); + Constant* result = constEval(tree); + ASSERT_EQ(result->getValueInt64(), 2); +} - // And the value must be 6 (i.e. 1+2+3). - ASSERT_EQ(result->getValueInt64(), 6); +TEST(Optimizer, ConstEval3) { + // 1.5 + 0.5 + auto tree = + make<BinaryOp>(Operations::Add, Constant::fromDouble(1.5), Constant::fromDouble(0.5)); + Constant* result = constEval(tree); + ASSERT_EQ(result->getValueDouble(), 2.0); +} + +TEST(Optimizer, ConstEval4) { + // INT32_MAX (as int) + 0 (as double) => INT32_MAX (as double) + auto tree = + make<BinaryOp>(Operations::Add, Constant::int32(INT32_MAX), Constant::fromDouble(0)); + Constant* result = constEval(tree); + ASSERT_EQ(result->getValueDouble(), INT32_MAX); +} + +TEST(Optimizer, ConstEval5) { + // -1 + -2 + ABT tree1 = make<BinaryOp>(Operations::Add, Constant::int32(-1), Constant::int32(-2)); + ASSERT_EQ(constEval(tree1)->getValueInt32(), -3); + // 1 + -1 + ABT tree2 = make<BinaryOp>(Operations::Add, Constant::int32(1), Constant::int32(-1)); + ASSERT_EQ(constEval(tree2)->getValueInt32(), 0); + // 1 + INT32_MIN + ABT tree3 = make<BinaryOp>(Operations::Add, Constant::int32(1), Constant::int32(INT32_MIN)); + ASSERT_EQ(constEval(tree3)->getValueInt32(), -2147483647); +} + +TEST(Optimizer, ConstEval6) { + // -1 * -2 + ABT tree1 = make<BinaryOp>(Operations::Mult, Constant::int32(-1), Constant::int32(-2)); + ASSERT_EQ(constEval(tree1)->getValueInt32(), 2); + // 1 * -1 + ABT tree2 = make<BinaryOp>(Operations::Mult, Constant::int32(1), Constant::int32(-1)); + ASSERT_EQ(constEval(tree2)->getValueInt32(), -1); + // 2 * INT32_MAX + ABT tree3 = make<BinaryOp>(Operations::Mult, Constant::int32(2), Constant::int32(INT32_MAX)); + ASSERT_EQ(constEval(tree3)->getValueInt64(), 4294967294); +} + + +TEST(Optimizer, IntegerOverflow) { + auto int32tree = + make<BinaryOp>(Operations::Add, Constant::int32(INT32_MAX), Constant::int32(1)); + ASSERT_EQ(constEval(int32tree)->getValueInt64(), 2147483648); +} + +TEST(Optimizer, IntegerUnderflow) { + auto int32tree = + make<BinaryOp>(Operations::Add, Constant::int32(INT32_MIN), Constant::int32(-1)); + ASSERT_EQ(constEval(int32tree)->getValueInt64(), -2147483649); + + auto tree = + make<BinaryOp>(Operations::Add, Constant::int32(INT32_MAX), Constant::int64(INT64_MIN)); + ASSERT_EQ(constEval(tree)->getValueInt64(), -9223372034707292161); } TEST(Optimizer, Tracker1) { diff --git a/src/mongo/db/query/optimizer/rewrites/const_eval.cpp b/src/mongo/db/query/optimizer/rewrites/const_eval.cpp index 0278e20700e..86ae4add4d6 100644 --- a/src/mongo/db/query/optimizer/rewrites/const_eval.cpp +++ b/src/mongo/db/query/optimizer/rewrites/const_eval.cpp @@ -28,7 +28,15 @@ */ #include "mongo/db/query/optimizer/rewrites/const_eval.h" +#include "mongo/db/exec/sbe/values/value.h" +#include "mongo/db/exec/sbe/vm/vm.h" #include "mongo/db/query/optimizer/utils/utils.h" +#include "mongo/platform/decimal128.h" +#include "mongo/util/assert_util.h" +#include <bits/floatn-common.h> +#include <cfloat> +#include <climits> +#include <cstdint> namespace mongo::optimizer { bool ConstEval::optimize(ABT& n) { @@ -191,24 +199,83 @@ namespace fold_helpers { using namespace sbe::value; template <class T> -sbe::value::Value constFoldNumberHelper(const sbe::value::TypeTags lhsTag, - const sbe::value::Value lhsValue, - const TypeTags rhsTag, - const sbe::value::Value rhsValue) { +sbe::value::Value constFoldNumberAdd(const sbe::value::TypeTags lhsTag, + const sbe::value::Value lhsValue, + const TypeTags rhsTag, + const sbe::value::Value rhsValue) { const auto result = numericCast<T>(lhsTag, lhsValue) + numericCast<T>(rhsTag, rhsValue); return bitcastFrom<T>(result); } template <> -sbe::value::Value constFoldNumberHelper<Decimal128>(const TypeTags lhsTag, - const sbe::value::Value lhsValue, - const TypeTags rhsTag, - const sbe::value::Value rhsValue) { - const auto result = +sbe::value::Value constFoldNumberAdd<Decimal128>(const TypeTags lhsTag, + const sbe::value::Value lhsValue, + const TypeTags rhsTag, + const sbe::value::Value rhsValue) { + const Decimal128 result = numericCast<Decimal128>(lhsTag, lhsValue).add(numericCast<Decimal128>(rhsTag, rhsValue)); return makeCopyDecimal(result).second; } +template <class T> +sbe::value::Value constFoldNumberSubtract(const sbe::value::TypeTags lhsTag, + const sbe::value::Value lhsValue, + const TypeTags rhsTag, + const sbe::value::Value rhsValue) { + const auto result = numericCast<T>(lhsTag, lhsValue) - numericCast<T>(rhsTag, rhsValue); + return bitcastFrom<T>(result); +} + +template <> +sbe::value::Value constFoldNumberSubtract<Decimal128>(const TypeTags lhsTag, + const sbe::value::Value lhsValue, + const TypeTags rhsTag, + const sbe::value::Value rhsValue) { + const Decimal128 result = numericCast<Decimal128>(lhsTag, lhsValue) + .subtract(numericCast<Decimal128>(rhsTag, rhsValue)); + return makeCopyDecimal(result).second; +} + +template <class T> +sbe::value::Value constFoldNumberMult(const sbe::value::TypeTags lhsTag, + const sbe::value::Value lhsValue, + const TypeTags rhsTag, + const sbe::value::Value rhsValue) { + const auto result = numericCast<T>(lhsTag, lhsValue) * numericCast<T>(rhsTag, rhsValue); + return bitcastFrom<T>(result); +} + +template <> +sbe::value::Value constFoldNumberMult<Decimal128>(const TypeTags lhsTag, + const sbe::value::Value lhsValue, + const TypeTags rhsTag, + const sbe::value::Value rhsValue) { + const Decimal128 result = numericCast<Decimal128>(lhsTag, lhsValue) + .multiply(numericCast<Decimal128>(rhsTag, rhsValue)); + return makeCopyDecimal(result).second; +} + +// Checks for Integer Overflow and Underflow +bool willOverflow(const TypeTags lhsTag, + const sbe::value::Value lhsValue, + const TypeTags rhsTag, + const sbe::value::Value rhsValue, + const TypeTags resultType) { + if (resultType == TypeTags::NumberInt32) { + int castedLHS = numericCast<int32_t>(resultType, lhsValue); + int castedRHS = numericCast<int32_t>(resultType, rhsValue); + return (castedLHS > 0 && castedRHS > (INT32_MAX - castedLHS)) || + (castedLHS < 0 && castedRHS < (INT32_MIN - castedLHS)); + } else if (resultType == TypeTags::NumberInt64) { + long castedLHS = numericCast<int64_t>(resultType, lhsValue); + long castedRHS = numericCast<int64_t>(resultType, rhsValue); + return (castedLHS > 0 && castedRHS > (INT64_MAX - castedLHS)) || + (castedLHS < 0 && castedRHS < (INT64_MIN - castedLHS)); + } else { + MONGO_UNREACHABLE; + } +} + } // namespace fold_helpers // Specific transport for binary operation @@ -221,52 +288,46 @@ void ConstEval::transport(ABT& n, const BinaryOp& op, ABT& lhs, ABT& rhs) { case Operations::Add: { // Let say we want to recognize ConstLhs + ConstRhs and replace it with the result of // addition. + Constant* lhsConst = lhs.cast<Constant>(); + Constant* rhsConst = rhs.cast<Constant>(); + if (lhsConst && rhsConst) { + auto [lhsTag, lhsValue] = lhsConst->get(); + auto [rhsTag, rhsValue] = rhsConst->get(); + auto [_, resultType, resultValue] = + sbe::vm::ByteCode::genericAdd(lhsTag, lhsValue, rhsTag, rhsValue); + swapAndUpdate(n, make<Constant>(resultType, resultValue)); + } + break; + } + + case Operations::Sub: { + // Let say we want to recognize ConstLhs - ConstRhs and replace it with the result of + // subtraction. auto lhsConst = lhs.cast<Constant>(); auto rhsConst = rhs.cast<Constant>(); if (lhsConst && rhsConst) { auto [lhsTag, lhsValue] = lhsConst->get(); auto [rhsTag, rhsValue] = rhsConst->get(); + auto [_, resultType, resultValue] = + sbe::vm::ByteCode::genericSub(lhsTag, lhsValue, rhsTag, rhsValue); + swapAndUpdate(n, make<Constant>(resultType, resultValue)); + } + break; + } - if (isNumber(lhsTag) && isNumber(rhsTag)) { - // So this is the addition operation and both arguments are number constants, - // hence we can compute the result. - - const TypeTags resultType = getWidestNumericalType(lhsTag, rhsTag); - sbe::value::Value resultValue; - - switch (resultType) { - case TypeTags::NumberInt32: { - resultValue = - constFoldNumberHelper<int32_t>(lhsTag, lhsValue, rhsTag, rhsValue); - break; - } - - case TypeTags::NumberInt64: { - resultValue = - constFoldNumberHelper<int64_t>(lhsTag, lhsValue, rhsTag, rhsValue); - break; - } - - case TypeTags::NumberDouble: { - resultValue = - constFoldNumberHelper<double>(lhsTag, lhsValue, rhsTag, rhsValue); - break; - } - - case TypeTags::NumberDecimal: { - resultValue = constFoldNumberHelper<Decimal128>( - lhsTag, lhsValue, rhsTag, rhsValue); - break; - } - - default: - MONGO_UNREACHABLE; - } + case Operations::Mult: { + // Let say we want to recognize ConstLhs * ConstRhs and replace it with the result of + // multiplication. + auto lhsConst = lhs.cast<Constant>(); + auto rhsConst = rhs.cast<Constant>(); - // And this is the crucial step - we swap the current node (n) for the result. - swapAndUpdate(n, make<Constant>(resultType, resultValue)); - } + if (lhsConst && rhsConst) { + auto [lhsTag, lhsValue] = lhsConst->get(); + auto [rhsTag, rhsValue] = rhsConst->get(); + auto [_, resultType, resultValue] = + sbe::vm::ByteCode::genericMul(lhsTag, lhsValue, rhsTag, rhsValue); + swapAndUpdate(n, make<Constant>(resultType, resultValue)); } break; } diff --git a/src/mongo/db/query/optimizer/syntax/expr.cpp b/src/mongo/db/query/optimizer/syntax/expr.cpp index abfcbb6d5f2..dd7e0258816 100644 --- a/src/mongo/db/query/optimizer/syntax/expr.cpp +++ b/src/mongo/db/query/optimizer/syntax/expr.cpp @@ -29,6 +29,7 @@ #include "mongo/db/query/optimizer/syntax/expr.h" #include "mongo/db/query/optimizer/node.h" +#include "mongo/platform/decimal128.h" namespace mongo::optimizer { @@ -141,4 +142,23 @@ int32_t Constant::getValueInt32() const { return bitcastTo<int32_t>(_val); } +bool Constant::isValueDouble() const { + return _tag == TypeTags::NumberDouble; +} + +double Constant::getValueDouble() const { + uassert(673180, "Constant value type is not double", isValueDouble()); + return bitcastTo<double>(_val); +} + +bool Constant::isValueDecimal() const { + return _tag == TypeTags::NumberDecimal; +} + +Decimal128 Constant::getValueDecimal() const { + uassert(673181, "Constant value type is not Decimal128", isValueDecimal()); + return bitcastTo<Decimal128>(_val); +} + + } // namespace mongo::optimizer diff --git a/src/mongo/db/query/optimizer/syntax/expr.h b/src/mongo/db/query/optimizer/syntax/expr.h index 652fbf1c3fb..5fb70be4445 100644 --- a/src/mongo/db/query/optimizer/syntax/expr.h +++ b/src/mongo/db/query/optimizer/syntax/expr.h @@ -29,6 +29,7 @@ #pragma once +#include "mongo/platform/decimal128.h" #include <ostream> #include "mongo/db/exec/sbe/values/value.h" @@ -81,6 +82,12 @@ public: bool isValueInt32() const; int32_t getValueInt32() const; + bool isValueDouble() const; + double getValueDouble() const; + + bool isValueDecimal() const; + Decimal128 getValueDecimal() const; + bool isNumber() const { return sbe::value::isNumber(_tag); } |