summaryrefslogtreecommitdiff
path: root/src/mongo/db/query/optimizer
diff options
context:
space:
mode:
authorDaniel Segel <daniel_segel@brown.edu>2022-06-21 19:44:19 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2022-07-25 19:31:06 +0000
commita2af071c5a28f2d92179ac20ac2d163043db19b9 (patch)
tree555c79a55d4237b941cc5aaffcda131fa997e65f /src/mongo/db/query/optimizer
parentf442499603a44ed4b3ba7ff7bd57e6fcc2470bdb (diff)
downloadmongo-a2af071c5a28f2d92179ac20ac2d163043db19b9.tar.gz
SERVER-67318 Add addition, subtraction, multiplication to const_eval with overflow
Diffstat (limited to 'src/mongo/db/query/optimizer')
-rw-r--r--src/mongo/db/query/optimizer/optimizer_test.cpp99
-rw-r--r--src/mongo/db/query/optimizer/rewrites/const_eval.cpp155
-rw-r--r--src/mongo/db/query/optimizer/syntax/expr.cpp20
-rw-r--r--src/mongo/db/query/optimizer/syntax/expr.h7
4 files changed, 216 insertions, 65 deletions
diff --git a/src/mongo/db/query/optimizer/optimizer_test.cpp b/src/mongo/db/query/optimizer/optimizer_test.cpp
index 52711f1f137..5e1534fe936 100644
--- a/src/mongo/db/query/optimizer/optimizer_test.cpp
+++ b/src/mongo/db/query/optimizer/optimizer_test.cpp
@@ -31,6 +31,8 @@
#include "mongo/db/query/optimizer/node.h"
#include "mongo/db/query/optimizer/reference_tracker.h"
#include "mongo/db/query/optimizer/rewrites/const_eval.h"
+#include "mongo/db/query/optimizer/syntax/syntax.h"
+#include "mongo/db/query/optimizer/syntax/syntax_fwd_declare.h"
#include "mongo/db/query/optimizer/utils/unit_test_utils.h"
#include "mongo/db/query/optimizer/utils/utils.h"
#include "mongo/unittest/unittest.h"
@@ -38,44 +40,105 @@
namespace mongo::optimizer {
namespace {
-TEST(Optimizer, ConstEval) {
- // 1 + 2
- auto tree = make<BinaryOp>(Operations::Add, Constant::int64(1), Constant::int64(2));
- // Run the evaluator.
+Constant* constEval(ABT& tree) {
auto env = VariableEnvironment::build(tree);
ConstEval evaluator{env};
evaluator.optimize(tree);
// The result must be Constant.
- auto result = tree.cast<Constant>();
+ Constant* result = tree.cast<Constant>();
ASSERT(result != nullptr);
- // And the value must be 3 (i.e. 1+2).
- ASSERT_EQ(result->getValueInt64(), 3);
-
ASSERT_NE(ABT::tagOf<Constant>(), ABT::tagOf<BinaryOp>());
ASSERT_EQ(tree.tagOf(), ABT::tagOf<Constant>());
+ return result;
+}
+
+TEST(Optimizer, ConstEval) {
+ // 1 + 2
+ ABT tree = make<BinaryOp>(Operations::Add, Constant::int64(1), Constant::int64(2));
+ Constant* result = constEval(tree);
+ ASSERT_EQ(result->getValueInt64(), 3);
}
+
TEST(Optimizer, ConstEvalCompose) {
// (1 + 2) + 3
- auto tree =
+ ABT tree =
make<BinaryOp>(Operations::Add,
make<BinaryOp>(Operations::Add, Constant::int64(1), Constant::int64(2)),
Constant::int64(3));
+ Constant* result = constEval(tree);
+ ASSERT_EQ(result->getValueInt64(), 6);
+}
- // Run the evaluator.
- auto env = VariableEnvironment::build(tree);
- ConstEval evaluator{env};
- evaluator.optimize(tree);
- // The result must be Constant.
- auto result = tree.cast<Constant>();
- ASSERT(result != nullptr);
+TEST(Optimizer, ConstEvalCompose2) {
+ // 3 - (5 - 4)
+ auto tree =
+ make<BinaryOp>(Operations::Sub,
+ Constant::int64(3),
+ make<BinaryOp>(Operations::Sub, Constant::int64(5), Constant::int64(4)));
+ Constant* result = constEval(tree);
+ ASSERT_EQ(result->getValueInt64(), 2);
+}
- // And the value must be 6 (i.e. 1+2+3).
- ASSERT_EQ(result->getValueInt64(), 6);
+TEST(Optimizer, ConstEval3) {
+ // 1.5 + 0.5
+ auto tree =
+ make<BinaryOp>(Operations::Add, Constant::fromDouble(1.5), Constant::fromDouble(0.5));
+ Constant* result = constEval(tree);
+ ASSERT_EQ(result->getValueDouble(), 2.0);
+}
+
+TEST(Optimizer, ConstEval4) {
+ // INT32_MAX (as int) + 0 (as double) => INT32_MAX (as double)
+ auto tree =
+ make<BinaryOp>(Operations::Add, Constant::int32(INT32_MAX), Constant::fromDouble(0));
+ Constant* result = constEval(tree);
+ ASSERT_EQ(result->getValueDouble(), INT32_MAX);
+}
+
+TEST(Optimizer, ConstEval5) {
+ // -1 + -2
+ ABT tree1 = make<BinaryOp>(Operations::Add, Constant::int32(-1), Constant::int32(-2));
+ ASSERT_EQ(constEval(tree1)->getValueInt32(), -3);
+ // 1 + -1
+ ABT tree2 = make<BinaryOp>(Operations::Add, Constant::int32(1), Constant::int32(-1));
+ ASSERT_EQ(constEval(tree2)->getValueInt32(), 0);
+ // 1 + INT32_MIN
+ ABT tree3 = make<BinaryOp>(Operations::Add, Constant::int32(1), Constant::int32(INT32_MIN));
+ ASSERT_EQ(constEval(tree3)->getValueInt32(), -2147483647);
+}
+
+TEST(Optimizer, ConstEval6) {
+ // -1 * -2
+ ABT tree1 = make<BinaryOp>(Operations::Mult, Constant::int32(-1), Constant::int32(-2));
+ ASSERT_EQ(constEval(tree1)->getValueInt32(), 2);
+ // 1 * -1
+ ABT tree2 = make<BinaryOp>(Operations::Mult, Constant::int32(1), Constant::int32(-1));
+ ASSERT_EQ(constEval(tree2)->getValueInt32(), -1);
+ // 2 * INT32_MAX
+ ABT tree3 = make<BinaryOp>(Operations::Mult, Constant::int32(2), Constant::int32(INT32_MAX));
+ ASSERT_EQ(constEval(tree3)->getValueInt64(), 4294967294);
+}
+
+
+TEST(Optimizer, IntegerOverflow) {
+ auto int32tree =
+ make<BinaryOp>(Operations::Add, Constant::int32(INT32_MAX), Constant::int32(1));
+ ASSERT_EQ(constEval(int32tree)->getValueInt64(), 2147483648);
+}
+
+TEST(Optimizer, IntegerUnderflow) {
+ auto int32tree =
+ make<BinaryOp>(Operations::Add, Constant::int32(INT32_MIN), Constant::int32(-1));
+ ASSERT_EQ(constEval(int32tree)->getValueInt64(), -2147483649);
+
+ auto tree =
+ make<BinaryOp>(Operations::Add, Constant::int32(INT32_MAX), Constant::int64(INT64_MIN));
+ ASSERT_EQ(constEval(tree)->getValueInt64(), -9223372034707292161);
}
TEST(Optimizer, Tracker1) {
diff --git a/src/mongo/db/query/optimizer/rewrites/const_eval.cpp b/src/mongo/db/query/optimizer/rewrites/const_eval.cpp
index 0278e20700e..86ae4add4d6 100644
--- a/src/mongo/db/query/optimizer/rewrites/const_eval.cpp
+++ b/src/mongo/db/query/optimizer/rewrites/const_eval.cpp
@@ -28,7 +28,15 @@
*/
#include "mongo/db/query/optimizer/rewrites/const_eval.h"
+#include "mongo/db/exec/sbe/values/value.h"
+#include "mongo/db/exec/sbe/vm/vm.h"
#include "mongo/db/query/optimizer/utils/utils.h"
+#include "mongo/platform/decimal128.h"
+#include "mongo/util/assert_util.h"
+#include <bits/floatn-common.h>
+#include <cfloat>
+#include <climits>
+#include <cstdint>
namespace mongo::optimizer {
bool ConstEval::optimize(ABT& n) {
@@ -191,24 +199,83 @@ namespace fold_helpers {
using namespace sbe::value;
template <class T>
-sbe::value::Value constFoldNumberHelper(const sbe::value::TypeTags lhsTag,
- const sbe::value::Value lhsValue,
- const TypeTags rhsTag,
- const sbe::value::Value rhsValue) {
+sbe::value::Value constFoldNumberAdd(const sbe::value::TypeTags lhsTag,
+ const sbe::value::Value lhsValue,
+ const TypeTags rhsTag,
+ const sbe::value::Value rhsValue) {
const auto result = numericCast<T>(lhsTag, lhsValue) + numericCast<T>(rhsTag, rhsValue);
return bitcastFrom<T>(result);
}
template <>
-sbe::value::Value constFoldNumberHelper<Decimal128>(const TypeTags lhsTag,
- const sbe::value::Value lhsValue,
- const TypeTags rhsTag,
- const sbe::value::Value rhsValue) {
- const auto result =
+sbe::value::Value constFoldNumberAdd<Decimal128>(const TypeTags lhsTag,
+ const sbe::value::Value lhsValue,
+ const TypeTags rhsTag,
+ const sbe::value::Value rhsValue) {
+ const Decimal128 result =
numericCast<Decimal128>(lhsTag, lhsValue).add(numericCast<Decimal128>(rhsTag, rhsValue));
return makeCopyDecimal(result).second;
}
+template <class T>
+sbe::value::Value constFoldNumberSubtract(const sbe::value::TypeTags lhsTag,
+ const sbe::value::Value lhsValue,
+ const TypeTags rhsTag,
+ const sbe::value::Value rhsValue) {
+ const auto result = numericCast<T>(lhsTag, lhsValue) - numericCast<T>(rhsTag, rhsValue);
+ return bitcastFrom<T>(result);
+}
+
+template <>
+sbe::value::Value constFoldNumberSubtract<Decimal128>(const TypeTags lhsTag,
+ const sbe::value::Value lhsValue,
+ const TypeTags rhsTag,
+ const sbe::value::Value rhsValue) {
+ const Decimal128 result = numericCast<Decimal128>(lhsTag, lhsValue)
+ .subtract(numericCast<Decimal128>(rhsTag, rhsValue));
+ return makeCopyDecimal(result).second;
+}
+
+template <class T>
+sbe::value::Value constFoldNumberMult(const sbe::value::TypeTags lhsTag,
+ const sbe::value::Value lhsValue,
+ const TypeTags rhsTag,
+ const sbe::value::Value rhsValue) {
+ const auto result = numericCast<T>(lhsTag, lhsValue) * numericCast<T>(rhsTag, rhsValue);
+ return bitcastFrom<T>(result);
+}
+
+template <>
+sbe::value::Value constFoldNumberMult<Decimal128>(const TypeTags lhsTag,
+ const sbe::value::Value lhsValue,
+ const TypeTags rhsTag,
+ const sbe::value::Value rhsValue) {
+ const Decimal128 result = numericCast<Decimal128>(lhsTag, lhsValue)
+ .multiply(numericCast<Decimal128>(rhsTag, rhsValue));
+ return makeCopyDecimal(result).second;
+}
+
+// Checks for Integer Overflow and Underflow
+bool willOverflow(const TypeTags lhsTag,
+ const sbe::value::Value lhsValue,
+ const TypeTags rhsTag,
+ const sbe::value::Value rhsValue,
+ const TypeTags resultType) {
+ if (resultType == TypeTags::NumberInt32) {
+ int castedLHS = numericCast<int32_t>(resultType, lhsValue);
+ int castedRHS = numericCast<int32_t>(resultType, rhsValue);
+ return (castedLHS > 0 && castedRHS > (INT32_MAX - castedLHS)) ||
+ (castedLHS < 0 && castedRHS < (INT32_MIN - castedLHS));
+ } else if (resultType == TypeTags::NumberInt64) {
+ long castedLHS = numericCast<int64_t>(resultType, lhsValue);
+ long castedRHS = numericCast<int64_t>(resultType, rhsValue);
+ return (castedLHS > 0 && castedRHS > (INT64_MAX - castedLHS)) ||
+ (castedLHS < 0 && castedRHS < (INT64_MIN - castedLHS));
+ } else {
+ MONGO_UNREACHABLE;
+ }
+}
+
} // namespace fold_helpers
// Specific transport for binary operation
@@ -221,52 +288,46 @@ void ConstEval::transport(ABT& n, const BinaryOp& op, ABT& lhs, ABT& rhs) {
case Operations::Add: {
// Let say we want to recognize ConstLhs + ConstRhs and replace it with the result of
// addition.
+ Constant* lhsConst = lhs.cast<Constant>();
+ Constant* rhsConst = rhs.cast<Constant>();
+ if (lhsConst && rhsConst) {
+ auto [lhsTag, lhsValue] = lhsConst->get();
+ auto [rhsTag, rhsValue] = rhsConst->get();
+ auto [_, resultType, resultValue] =
+ sbe::vm::ByteCode::genericAdd(lhsTag, lhsValue, rhsTag, rhsValue);
+ swapAndUpdate(n, make<Constant>(resultType, resultValue));
+ }
+ break;
+ }
+
+ case Operations::Sub: {
+ // Let say we want to recognize ConstLhs - ConstRhs and replace it with the result of
+ // subtraction.
auto lhsConst = lhs.cast<Constant>();
auto rhsConst = rhs.cast<Constant>();
if (lhsConst && rhsConst) {
auto [lhsTag, lhsValue] = lhsConst->get();
auto [rhsTag, rhsValue] = rhsConst->get();
+ auto [_, resultType, resultValue] =
+ sbe::vm::ByteCode::genericSub(lhsTag, lhsValue, rhsTag, rhsValue);
+ swapAndUpdate(n, make<Constant>(resultType, resultValue));
+ }
+ break;
+ }
- if (isNumber(lhsTag) && isNumber(rhsTag)) {
- // So this is the addition operation and both arguments are number constants,
- // hence we can compute the result.
-
- const TypeTags resultType = getWidestNumericalType(lhsTag, rhsTag);
- sbe::value::Value resultValue;
-
- switch (resultType) {
- case TypeTags::NumberInt32: {
- resultValue =
- constFoldNumberHelper<int32_t>(lhsTag, lhsValue, rhsTag, rhsValue);
- break;
- }
-
- case TypeTags::NumberInt64: {
- resultValue =
- constFoldNumberHelper<int64_t>(lhsTag, lhsValue, rhsTag, rhsValue);
- break;
- }
-
- case TypeTags::NumberDouble: {
- resultValue =
- constFoldNumberHelper<double>(lhsTag, lhsValue, rhsTag, rhsValue);
- break;
- }
-
- case TypeTags::NumberDecimal: {
- resultValue = constFoldNumberHelper<Decimal128>(
- lhsTag, lhsValue, rhsTag, rhsValue);
- break;
- }
-
- default:
- MONGO_UNREACHABLE;
- }
+ case Operations::Mult: {
+ // Let say we want to recognize ConstLhs * ConstRhs and replace it with the result of
+ // multiplication.
+ auto lhsConst = lhs.cast<Constant>();
+ auto rhsConst = rhs.cast<Constant>();
- // And this is the crucial step - we swap the current node (n) for the result.
- swapAndUpdate(n, make<Constant>(resultType, resultValue));
- }
+ if (lhsConst && rhsConst) {
+ auto [lhsTag, lhsValue] = lhsConst->get();
+ auto [rhsTag, rhsValue] = rhsConst->get();
+ auto [_, resultType, resultValue] =
+ sbe::vm::ByteCode::genericMul(lhsTag, lhsValue, rhsTag, rhsValue);
+ swapAndUpdate(n, make<Constant>(resultType, resultValue));
}
break;
}
diff --git a/src/mongo/db/query/optimizer/syntax/expr.cpp b/src/mongo/db/query/optimizer/syntax/expr.cpp
index abfcbb6d5f2..dd7e0258816 100644
--- a/src/mongo/db/query/optimizer/syntax/expr.cpp
+++ b/src/mongo/db/query/optimizer/syntax/expr.cpp
@@ -29,6 +29,7 @@
#include "mongo/db/query/optimizer/syntax/expr.h"
#include "mongo/db/query/optimizer/node.h"
+#include "mongo/platform/decimal128.h"
namespace mongo::optimizer {
@@ -141,4 +142,23 @@ int32_t Constant::getValueInt32() const {
return bitcastTo<int32_t>(_val);
}
+bool Constant::isValueDouble() const {
+ return _tag == TypeTags::NumberDouble;
+}
+
+double Constant::getValueDouble() const {
+ uassert(673180, "Constant value type is not double", isValueDouble());
+ return bitcastTo<double>(_val);
+}
+
+bool Constant::isValueDecimal() const {
+ return _tag == TypeTags::NumberDecimal;
+}
+
+Decimal128 Constant::getValueDecimal() const {
+ uassert(673181, "Constant value type is not Decimal128", isValueDecimal());
+ return bitcastTo<Decimal128>(_val);
+}
+
+
} // namespace mongo::optimizer
diff --git a/src/mongo/db/query/optimizer/syntax/expr.h b/src/mongo/db/query/optimizer/syntax/expr.h
index 652fbf1c3fb..5fb70be4445 100644
--- a/src/mongo/db/query/optimizer/syntax/expr.h
+++ b/src/mongo/db/query/optimizer/syntax/expr.h
@@ -29,6 +29,7 @@
#pragma once
+#include "mongo/platform/decimal128.h"
#include <ostream>
#include "mongo/db/exec/sbe/values/value.h"
@@ -81,6 +82,12 @@ public:
bool isValueInt32() const;
int32_t getValueInt32() const;
+ bool isValueDouble() const;
+ double getValueDouble() const;
+
+ bool isValueDecimal() const;
+ Decimal128 getValueDecimal() const;
+
bool isNumber() const {
return sbe::value::isNumber(_tag);
}