summaryrefslogtreecommitdiff
path: root/src/mongo/db/exec
diff options
context:
space:
mode:
authorMatt Boros <matt.boros@mongodb.com>2023-03-16 19:52:27 +0000
committerEvergreen Agent <no-reply@evergreen.mongodb.com>2023-03-16 21:03:35 +0000
commitb06f88e1d2d4d9f96d9237597a1937f13faf9dae (patch)
treeb57a72a9050762de164466ec0d3e9fb1be06345d /src/mongo/db/exec
parent12d6a66e432c74defdc91417e2c00a5251a36d5e (diff)
downloadmongo-b06f88e1d2d4d9f96d9237597a1937f13faf9dae.tar.gz
SERVER-51553 Support expression $round in SBE
Diffstat (limited to 'src/mongo/db/exec')
-rw-r--r--src/mongo/db/exec/sbe/SConscript1
-rw-r--r--src/mongo/db/exec/sbe/expressions/expression.cpp2
-rw-r--r--src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp198
-rw-r--r--src/mongo/db/exec/sbe/vm/vm.cpp102
4 files changed, 274 insertions, 29 deletions
diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript
index 4798734ad48..bef7254de27 100644
--- a/src/mongo/db/exec/sbe/SConscript
+++ b/src/mongo/db/exec/sbe/SConscript
@@ -198,6 +198,7 @@ env.CppUnitTest(
'expressions/sbe_regex_test.cpp',
'expressions/sbe_replace_one_expression_test.cpp',
'expressions/sbe_reverse_array_builtin_test.cpp',
+ 'expressions/sbe_round_builtin_test.cpp',
'expressions/sbe_runtime_environment_test.cpp',
'expressions/sbe_set_expressions_test.cpp',
'expressions/sbe_shard_filter_builtin_test.cpp',
diff --git a/src/mongo/db/exec/sbe/expressions/expression.cpp b/src/mongo/db/exec/sbe/expressions/expression.cpp
index 6a8dacec67a..a6a1e2ae524 100644
--- a/src/mongo/db/exec/sbe/expressions/expression.cpp
+++ b/src/mongo/db/exec/sbe/expressions/expression.cpp
@@ -682,7 +682,7 @@ static stdx::unordered_map<std::string, BuiltinFn> kBuiltinFunctions = {
{"sinh", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::sinh, false}},
{"tan", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::tan, false}},
{"tanh", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::tanh, false}},
- {"round", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::round, false}},
+ {"round", BuiltinFn{[](size_t n) { return n == 1 || n == 2; }, vm::Builtin::round, false}},
{"concat", BuiltinFn{kAnyNumberOfArgs, vm::Builtin::concat, false}},
{"concatArrays", BuiltinFn{kAnyNumberOfArgs, vm::Builtin::concatArrays, false}},
{"aggConcatArraysCapped",
diff --git a/src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp b/src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp
new file mode 100644
index 00000000000..823dcdf7c9a
--- /dev/null
+++ b/src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp
@@ -0,0 +1,198 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/base/string_data.h"
+#include "mongo/db/exec/sbe/expression_test_base.h"
+#include "mongo/db/exec/sbe/values/value.h"
+#include "mongo/unittest/assert.h"
+
+namespace mongo::sbe {
+namespace {
+using namespace value;
+
+static std::pair<TypeTags, Value> makeDecimal(const std::string& n) {
+ return makeCopyDecimal(Decimal128(n));
+}
+
+const std::pair<TypeTags, Value> kNull{TypeTags::Null, 0};
+const std::pair<TypeTags, Value> kNothing{TypeTags::Nothing, 0};
+
+/**
+ * A test for SBE built-in function round with one argument. The "place" argument defaults to 0.
+ */
+TEST_F(EExpressionTestFixture, RoundOneArg) {
+ OwnedValueAccessor numAccessor;
+ auto numSlot = bindAccessor(&numAccessor);
+
+ // Construct an invocation of round function.
+ auto roundExpression =
+ sbe::makeE<sbe::EFunction>("round", sbe::makeEs(makeE<EVariable>(numSlot)));
+ auto compiledRound = compileExpression(*roundExpression);
+
+ struct TestCase {
+ std::pair<TypeTags, Value> num;
+ std::pair<TypeTags, Value> result;
+ };
+
+ std::vector<TestCase> testCases = {
+ {makeInt32(0), makeInt32(0)},
+ {makeInt32(2), makeInt32(2)},
+ {makeInt64(4), makeInt64(4)},
+ {makeDouble(-2.0), makeDouble(-2.0)},
+ {makeDouble(0.9), makeDouble(1.0)},
+ {makeDouble(-1.2), makeDouble(-1.0)},
+ {makeDecimal("-1.6"), makeDecimal("-2")},
+ {makeDouble(1.298), makeDouble(1.0)},
+ {makeDecimal("1.298"), makeDecimal("1")},
+ // Infinity cases.
+ {makeDouble(std::numeric_limits<double>::infinity()),
+ makeDouble(std::numeric_limits<double>::infinity())},
+ // Decimal128 infinity.
+ {makeDecimal("Infinity"), makeDecimal("Infinity")},
+ {makeDecimal("-Infinity"), makeDecimal("-Infinity")},
+ // NaN cases.
+ {makeDouble(std::numeric_limits<double>::quiet_NaN()),
+ makeDouble(std::numeric_limits<double>::quiet_NaN())},
+ {makeDecimal("NaN"), makeDecimal("NaN")},
+ // Null case.
+ {kNull, kNothing},
+ // Nothing case.
+ {kNothing, kNothing},
+ };
+
+ int testNum = 0;
+ for (auto&& testCase : testCases) {
+ numAccessor.reset(testCase.num.first, testCase.num.second);
+ ValueGuard expectedResultGuard(testCase.result.first, testCase.result.second);
+
+ // Execute the round function.
+ auto [resultTag, resultValue] = runCompiledExpression(compiledRound.get());
+ ValueGuard actualResultGuard(resultTag, resultValue);
+
+ auto [compTag, compVal] =
+ compareValue(resultTag, resultValue, testCase.result.first, testCase.result.second);
+ ASSERT_EQUALS(compTag, TypeTags::NumberInt32) << "unexpected tag for test " << testNum;
+ ASSERT_EQUALS(compVal, bitcastFrom<int32_t>(0)) << "unexpected value for test " << testNum;
+
+ testNum++;
+ }
+}
+
+/**
+ * A test for SBE built-in function round with two arguments.
+ */
+TEST_F(EExpressionTestFixture, RoundTwoArgs) {
+ OwnedValueAccessor numAccessor;
+ auto numSlot = bindAccessor(&numAccessor);
+ OwnedValueAccessor placeAccessor;
+ auto placeSlot = bindAccessor(&placeAccessor);
+
+ // Construct an invocation of round function.
+ auto roundExpression = sbe::makeE<sbe::EFunction>(
+ "round", sbe::makeEs(makeE<EVariable>(numSlot), makeE<EVariable>(placeSlot)));
+ auto compiledRound = compileExpression(*roundExpression);
+
+ struct TestCase {
+ std::pair<TypeTags, Value> num;
+ std::pair<TypeTags, Value> place;
+ std::pair<TypeTags, Value> result;
+ };
+
+ std::vector<TestCase> testCases = {
+ {makeInt32(43), makeInt32(-1), makeInt32(40)},
+ // Try rounding with different types for the "place" argument.
+ {makeDouble(1.298), makeInt32(0), makeDouble(1.0)},
+ {makeDouble(1.298), makeInt64(0ull), makeDouble(1.0)},
+ {makeDouble(1.298), makeDouble(0.0), makeDouble(1.0)},
+ // Try rounding with a different value for the "place" argument.
+ {makeDouble(1.298), makeDouble(1.0), makeDouble(1.3)},
+ {makeDouble(23.298), makeDouble(-1.0), makeDouble(20.0)},
+ // Decimal tests.
+ {makeDecimal("1.298"), makeDouble(0.0), makeDecimal("1")},
+ {makeDecimal("1.298"), makeDouble(1.0), makeDecimal("1.3")},
+ {makeDecimal("23.298"), makeDouble(-1.0), makeDecimal("20.0")},
+ {makeDecimal("1.298912343250054252245154325"),
+ makeDouble(20.0),
+ makeDecimal("1.29891234325005425225")},
+ {makeDecimal("1.298"), makeDouble(100.0), makeDecimal("1.298")},
+ // Integer promotion case.
+ {makeInt32(2147483647), makeDouble(-1), makeInt64(2147483650)},
+ // Infinity cases.
+ {makeDouble(std::numeric_limits<double>::infinity()),
+ makeDouble(10),
+ makeDouble(std::numeric_limits<double>::infinity())},
+ {makeDouble(std::numeric_limits<double>::infinity()),
+ makeDouble(-10),
+ makeDouble(std::numeric_limits<double>::infinity())},
+ // Decimal128 infinity.
+ {makeDecimal("Infinity"), makeDouble(0), makeDecimal("Infinity")},
+ {makeDecimal("-Infinity"), makeDouble(0), makeDecimal("-Infinity")},
+ {makeDecimal("Infinity"), makeDouble(10), makeDecimal("Infinity")},
+ {makeDecimal("Infinity"), makeDouble(-10), makeDecimal("Infinity")},
+ // NaN cases.
+ {makeDouble(std::numeric_limits<double>::quiet_NaN()),
+ makeDouble(1),
+ makeDouble(std::numeric_limits<double>::quiet_NaN())},
+ {makeDouble(std::numeric_limits<double>::quiet_NaN()),
+ makeDouble(-2),
+ makeDouble(std::numeric_limits<double>::quiet_NaN())},
+ {makeDecimal("NaN"), makeDouble(0), makeDecimal("NaN")},
+ {makeDecimal("NaN"), makeDouble(2), makeDecimal("NaN")},
+ // Null cases.
+ {kNull, makeDouble(-5), kNothing},
+ {kNull, makeDouble(5), kNothing},
+ {makeDouble(1.1), kNull, kNothing},
+ // Nothing cases.
+ {kNothing, makeDouble(-5), kNothing},
+ {kNothing, makeDouble(5), kNothing},
+ {makeDouble(1.1), kNothing, kNothing},
+ // Try the limits of the "place" arg (-20 and 100).
+ {makeDouble(1.298), makeDouble(100), makeDouble(1.298)},
+ {makeDouble(1.298), makeDouble(-20), makeDouble(0)}};
+
+ int testNum = 0;
+ for (auto&& testCase : testCases) {
+ numAccessor.reset(testCase.num.first, testCase.num.second);
+ placeAccessor.reset(testCase.place.first, testCase.place.second);
+ ValueGuard expectedResultGuard(testCase.result.first, testCase.result.second);
+
+ // Execute the round function.
+ auto [resultTag, resultValue] = runCompiledExpression(compiledRound.get());
+ ValueGuard actualResultGuard(resultTag, resultValue);
+
+ auto [compTag, compVal] =
+ compareValue(resultTag, resultValue, testCase.result.first, testCase.result.second);
+ ASSERT_EQUALS(compTag, TypeTags::NumberInt32) << "unexpected tag for test " << testNum;
+ ASSERT_EQUALS(compVal, bitcastFrom<int32_t>(0)) << "unexpected value for test " << testNum;
+
+ testNum++;
+ }
+}
+} // namespace
+} // namespace mongo::sbe
diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp
index fd8a9561c2e..88a08c30c6e 100644
--- a/src/mongo/db/exec/sbe/vm/vm.cpp
+++ b/src/mongo/db/exec/sbe/vm/vm.cpp
@@ -3639,39 +3639,85 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinTanh(ArityType a
return genericTanh(operandTag, operandValue);
}
-FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinRound(ArityType arity) {
- invariant(arity == 1);
- auto [owned, tag, val] = getFromStack(0);
-
- // Round 'val' to the closest integer, with ties rounding to the closest even integer.
- // If 'val' is +Inf, -Inf, or NaN, this function will simply return 'val' as-is.
+/**
+ * Converts a number to int32 assuming the input fits the range. This is used for $round "place"
+ * argument, which is checked to be a whole number between -20 and 100, but could still be a
+ * non-int32 type.
+ */
+static int32_t convertNumericToInt32(const value::TypeTags tag, const value::Value val) {
switch (tag) {
- case value::TypeTags::NumberInt32:
- case value::TypeTags::NumberInt64:
- // The value is already an integer, so just return it as-is.
- return {false, tag, val};
+ case value::TypeTags::NumberInt32: {
+ return value::bitcastTo<int32_t>(val);
+ }
+ case value::TypeTags::NumberInt64: {
+ return static_cast<int32_t>(value::bitcastTo<int64_t>(val));
+ }
case value::TypeTags::NumberDouble: {
- // std::nearbyint()'s behavior relies on a thread-local "rounding mode", so
- // we use boost::numeric::RoundEven<double>::nearbyint() instead. We should
- // switch over to roundeven() once it becomes available in the standard library.
- // (See https://en.cppreference.com/w/c/experimental/fpext1 for details.)
- auto operand = value::bitcastTo<double>(val);
- auto rounded = boost::numeric::RoundEven<double>::nearbyint(operand);
- return {false, tag, value::bitcastFrom<double>(rounded)};
+ return static_cast<int32_t>(value::bitcastTo<double>(val));
}
case value::TypeTags::NumberDecimal: {
- auto operand = value::bitcastTo<Decimal128>(val);
- auto rounded = operand.round(Decimal128::RoundingMode::kRoundTiesToEven);
- if (operand.isEqual(rounded)) {
- // If the output of rounding is equal to the input, then we can just take
- // ownership of 'operand' and return it. (This is more efficient than calling
- // makeCopyDecimal(), which would allocate memory on the heap.)
- topStack(false, value::TypeTags::Nothing, 0);
- return {owned, tag, val};
- }
+ Decimal128 dec = value::bitcastTo<Decimal128>(val);
+ return dec.toInt(Decimal128::kRoundTiesToEven);
+ }
+ default:
+ MONGO_UNREACHABLE;
+ }
+}
- auto [tag, val] = value::makeCopyDecimal(rounded);
- return {true, tag, val};
+FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinRound(ArityType arity) {
+ invariant(arity == 1 || arity == 2);
+ int32_t place = 0;
+ const auto [numOwn, numTag, numVal] = getFromStack(0);
+ if (arity == 2) {
+ const auto [placeOwn, placeTag, placeVal] = getFromStack(1);
+ if (!value::isNumber(placeTag)) {
+ return {false, value::TypeTags::Nothing, 0};
+ }
+ place = convertNumericToInt32(placeTag, placeVal);
+ }
+
+ // Construct 10^-precisionValue, which will be used as the quantize reference. This is passed to
+ // decimal.quantize() to indicate the precision of our rounding.
+ const auto quantum = Decimal128(0LL, Decimal128::kExponentBias - place, 0LL, 1LL);
+ switch (numTag) {
+ case value::TypeTags::NumberDecimal: {
+ auto dec = value::bitcastTo<Decimal128>(numVal);
+ if (!dec.isInfinite()) {
+ dec = dec.quantize(quantum, Decimal128::kRoundTiesToEven);
+ }
+ auto [resultTag, resultValue] = value::makeCopyDecimal(dec);
+ return {true, resultTag, resultValue};
+ }
+ case value::TypeTags::NumberDouble: {
+ auto asDec = Decimal128(value::bitcastTo<double>(numVal), Decimal128::kRoundTo34Digits);
+ if (!asDec.isInfinite()) {
+ asDec = asDec.quantize(quantum, Decimal128::kRoundTiesToEven);
+ }
+ return {
+ false, value::TypeTags::NumberDouble, value::bitcastFrom<double>(asDec.toDouble())};
+ }
+ case value::TypeTags::NumberInt32:
+ case value::TypeTags::NumberInt64: {
+ if (place >= 0) {
+ return {numOwn, numTag, numVal};
+ }
+ auto numericArgll = numTag == value::TypeTags::NumberInt32
+ ? static_cast<int64_t>(value::bitcastTo<int32_t>(numVal))
+ : value::bitcastTo<int64_t>(numVal);
+ auto out = Decimal128(numericArgll).quantize(quantum, Decimal128::kRoundTiesToEven);
+ uint32_t flags = 0;
+ auto outll = out.toLong(&flags);
+ uassert(5155302,
+ "Invalid conversion to long during $round.",
+ !Decimal128::hasFlag(flags, Decimal128::kInvalid));
+ if (numTag == value::TypeTags::NumberInt64 ||
+ outll > std::numeric_limits<int32_t>::max()) {
+ // Even if the original was an int to begin with - it has to be a long now.
+ return {false, value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(outll)};
+ }
+ return {false,
+ value::TypeTags::NumberInt32,
+ value::bitcastFrom<int32_t>(static_cast<int32_t>(outll))};
}
default:
return {false, value::TypeTags::Nothing, 0};