summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--jstests/aggregation/expressions/round_trunc.js146
-rw-r--r--jstests/libs/sbe_assert_error_override.js6
-rw-r--r--src/mongo/db/exec/sbe/SConscript1
-rw-r--r--src/mongo/db/exec/sbe/expressions/expression.cpp2
-rw-r--r--src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp198
-rw-r--r--src/mongo/db/exec/sbe/vm/vm.cpp102
-rw-r--r--src/mongo/db/pipeline/expression.h4
-rw-r--r--src/mongo/db/query/sbe_stage_builder_abt_helpers.cpp40
-rw-r--r--src/mongo/db/query/sbe_stage_builder_abt_helpers.h5
-rw-r--r--src/mongo/db/query/sbe_stage_builder_expression.cpp37
10 files changed, 440 insertions, 101 deletions
diff --git a/jstests/aggregation/expressions/round_trunc.js b/jstests/aggregation/expressions/round_trunc.js
index 1e03eb47e68..97e20cd222f 100644
--- a/jstests/aggregation/expressions/round_trunc.js
+++ b/jstests/aggregation/expressions/round_trunc.js
@@ -5,66 +5,81 @@
// For assertErrorCode.
load("jstests/aggregation/extras/utils.js");
+load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs.
-var coll = db.server19548;
+const coll = db.server19548;
coll.drop();
-// Seed collection so that the pipeline will execute.
-assert.commandWorked(coll.insert({}));
// Helper for testing that op returns expResult.
-function testOp(op, expResult) {
- var pipeline = [{$project: {_id: 0, result: op}}];
+function testOp(exprName, value, expResult, place) {
+ coll.drop();
+ assert.commandWorked(coll.insert({a: value}));
+ const project = place === undefined ? {[exprName]: "$a"} : {[exprName]: ["$a", place]};
+ const pipeline = [{$project: {_id: 0, result: project}}];
assert.eq(coll.aggregate(pipeline).toArray(), [{result: expResult}]);
}
+function testRound(value, expResult, place) {
+ testOp("$round", value, expResult, place);
+}
+
+function testTrunc(value, expResult, place) {
+ testOp("$trunc", value, expResult, place);
+}
+
// Test $trunc and $round with one argument.
-testOp({$trunc: NumberLong(4)}, NumberLong(4));
-testOp({$trunc: NaN}, NaN);
-testOp({$trunc: Infinity}, Infinity);
-testOp({$trunc: -Infinity}, -Infinity);
-testOp({$trunc: null}, null);
-testOp({$trunc: -2.0}, -2.0);
-testOp({$trunc: 0.9}, 0.0);
-testOp({$trunc: -1.2}, -1.0);
-testOp({$trunc: NumberDecimal("-1.6")}, NumberDecimal("-1"));
-
-testOp({$round: NumberLong(4)}, NumberLong(4));
-testOp({$round: NaN}, NaN);
-testOp({$round: Infinity}, Infinity);
-testOp({$round: -Infinity}, -Infinity);
-testOp({$round: null}, null);
-testOp({$round: -2.0}, -2.0);
-testOp({$round: 0.9}, 1.0);
-testOp({$round: -1.2}, -1.0);
-testOp({$round: NumberDecimal("-1.6")}, NumberDecimal("-2"));
+testTrunc(NumberLong(4), NumberLong(4));
+testTrunc(NumberLong(4), NumberLong(4));
+testTrunc(NaN, NaN);
+testTrunc(Infinity, Infinity);
+testTrunc(-Infinity, -Infinity);
+testTrunc(null, null);
+testTrunc(-2.0, -2.0);
+testTrunc(0.9, 0.0);
+testTrunc(-1.2, -1.0);
+testTrunc(NumberDecimal("-1.6"), NumberDecimal("-1"));
+
+testRound(NumberLong(4), NumberLong(4));
+testRound(NaN, NaN);
+testRound(Infinity, Infinity);
+testRound(-Infinity, -Infinity);
+testRound(null, null);
+testRound(-2.0, -2.0);
+testRound(0.9, 1.0);
+testRound(-1.2, -1.0);
+testRound(NumberDecimal("-1.6"), NumberDecimal("-2"));
// Test $trunc and $round with two arguments.
-testOp({$trunc: [1.298, 0]}, 1);
-testOp({$trunc: [1.298, 1]}, 1.2);
-testOp({$trunc: [23.298, -1]}, 20);
-testOp({$trunc: [NumberDecimal("1.298"), 0]}, NumberDecimal("1"));
-testOp({$trunc: [NumberDecimal("1.298"), 1]}, NumberDecimal("1.2"));
-testOp({$trunc: [NumberDecimal("23.298"), -1]}, NumberDecimal("2E+1"));
-testOp({$trunc: [1.298, 100]}, 1.298);
-testOp({$trunc: [NumberDecimal("1.298912343250054252245154325"), NumberLong("20")]},
- NumberDecimal("1.29891234325005425224"));
-testOp({$trunc: [NumberDecimal("1.298"), NumberDecimal("100")]},
- NumberDecimal("1.298000000000000000000000000000000"));
-
-testOp({$round: [1.298, 0]}, 1);
-testOp({$round: [1.298, 1]}, 1.3);
-testOp({$round: [23.298, -1]}, 20);
-testOp({$round: [NumberDecimal("1.298"), 0]}, NumberDecimal("1"));
-testOp({$round: [NumberDecimal("1.298"), 1]}, NumberDecimal("1.3"));
-testOp({$round: [NumberDecimal("23.298"), -1]}, NumberDecimal("2E+1"));
-testOp({$round: [1.298, 100]}, 1.298);
-testOp({$round: [NumberDecimal("1.298912343250054252245154325"), NumberLong("20")]},
- NumberDecimal("1.29891234325005425225"));
-testOp({$round: [NumberDecimal("1.298"), NumberDecimal("100")]},
- NumberDecimal("1.298000000000000000000000000000000"));
+testTrunc(1.298, 1, 0);
+testTrunc(1.298, 1.2, 1);
+testTrunc(23.298, 20, -1);
+testTrunc(NumberDecimal("1.298"), NumberDecimal("1"), 0);
+testTrunc(NumberDecimal("1.298"), NumberDecimal("1.2"), 1);
+testTrunc(NumberDecimal("23.298"), NumberDecimal("2E+1"), -1);
+testTrunc(1.298, 1.298, 100);
+testTrunc(NumberDecimal("1.298912343250054252245154325"),
+ NumberDecimal("1.29891234325005425224"),
+ NumberLong("20"));
+testTrunc(NumberDecimal("1.298"),
+ NumberDecimal("1.298000000000000000000000000000000"),
+ NumberDecimal("100"));
+
+testRound(1.298, 1, 0);
+testRound(1.298, 1.3, 1);
+testRound(23.298, 20, -1);
+testRound(NumberDecimal("1.298"), NumberDecimal("1"), 0);
+testRound(NumberDecimal("1.298"), NumberDecimal("1.3"), 1);
+testRound(NumberDecimal("23.298"), NumberDecimal("2E+1"), -1);
+testRound(1.298, 1.298, 100);
+testRound(NumberDecimal("1.298912343250054252245154325"),
+ NumberDecimal("1.29891234325005425225"),
+ NumberLong("20"));
+testRound(NumberDecimal("1.298"),
+ NumberDecimal("1.298000000000000000000000000000000"),
+ NumberDecimal("100"));
// Test $round overflow.
-testOp({$round: [NumberInt("2147483647"), -1]}, NumberLong("2147483650"));
+testRound(NumberInt("2147483647"), NumberLong("2147483650"), -1);
assertErrorCode(coll, [{$project: {a: {$round: [NumberLong("9223372036854775806"), -1]}}}], 51080);
// Test $trunc and $round with more than 2 arguments.
@@ -76,21 +91,24 @@ assertErrorCode(coll, [{$project: {a: {$round: "string"}}}], 51081);
assertErrorCode(coll, [{$project: {a: {$trunc: "string"}}}], 51081);
// Test NaN and Infinity numeric args.
-testOp({$round: [Infinity, 0]}, Infinity);
-testOp({$round: [-Infinity, 0]}, -Infinity);
-testOp({$round: [NaN, 0]}, NaN);
-testOp({$round: [NumberDecimal("Infinity"), 0]}, NumberDecimal("Infinity"));
-testOp({$round: [NumberDecimal("-Infinity"), 0]}, NumberDecimal("-Infinity"));
-testOp({$round: [NumberDecimal("NaN"), 0]}, NumberDecimal("NaN"));
-
-testOp({$trunc: [Infinity, 0]}, Infinity);
-testOp({$trunc: [-Infinity, 0]}, -Infinity);
-testOp({$trunc: [NaN, 0]}, NaN);
-testOp({$trunc: [NumberDecimal("Infinity"), 0]}, NumberDecimal("Infinity"));
-testOp({$trunc: [NumberDecimal("-Infinity"), 0]}, NumberDecimal("-Infinity"));
-testOp({$trunc: [NumberDecimal("NaN"), 0]}, NumberDecimal("NaN"));
+testRound(Infinity, Infinity, 0);
+testRound(-Infinity, -Infinity, 0);
+testRound(NaN, NaN, 0);
+testRound(NumberDecimal("Infinity"), NumberDecimal("Infinity"), 0);
+testRound(NumberDecimal("-Infinity"), NumberDecimal("-Infinity"), 0);
+testRound(NumberDecimal("NaN"), NumberDecimal("NaN"), 0);
+testRound(null, null, 1);
+testRound(1, null, null);
+
+testTrunc(Infinity, Infinity, 0);
+testTrunc(-Infinity, -Infinity, 0);
+testTrunc(NaN, NaN, 0);
+testTrunc(NumberDecimal("Infinity"), NumberDecimal("Infinity"), 0);
+testTrunc(NumberDecimal("-Infinity"), NumberDecimal("-Infinity"), 0);
+testTrunc(NumberDecimal("NaN"), NumberDecimal("NaN"), 0);
// Test precision arguments that are out of bounds.
+assert.commandWorked(coll.insert({}));
assertErrorCode(coll, [{$project: {a: {$round: [1, NumberLong("101")]}}}], 51083);
assertErrorCode(coll, [{$project: {a: {$round: [1, NumberLong("-21")]}}}], 51083);
assertErrorCode(coll, [{$project: {a: {$round: [1, NumberDecimal("101")]}}}], 51083);
@@ -111,4 +129,10 @@ assertErrorCode(coll, [{$project: {a: {$trunc: [1, -21]}}}], 51083);
// Test non-integral precision arguments.
assertErrorCode(coll, [{$project: {a: {$round: [1, NumberDecimal("1.4")]}}}], 51082);
assertErrorCode(coll, [{$project: {a: {$trunc: [1, 10.5]}}}], 51082);
+assertErrorCode(coll, [{$project: {a: {$round: [0, NaN]}}}], 31109);
+assertErrorCode(coll, [{$project: {a: {$round: [0, NumberDecimal("NaN")]}}}], 51082);
+assertErrorCode(coll, [{$project: {a: {$round: [BinData(0, ""), 0]}}}], 51081);
+assertErrorCode(coll, [{$project: {a: {$round: [0, BinData(0, "")]}}}], 16004);
+assertErrorCode(coll, [{$project: {a: {$round: MinKey}}}], 51081);
+assertErrorCode(coll, [{$project: {a: {$round: MaxKey}}}], 51081);
}());
diff --git a/jstests/libs/sbe_assert_error_override.js b/jstests/libs/sbe_assert_error_override.js
index eafabc51ccf..1cc84cae6b3 100644
--- a/jstests/libs/sbe_assert_error_override.js
+++ b/jstests/libs/sbe_assert_error_override.js
@@ -138,6 +138,12 @@ const equivalentErrorCodesList = [
[17047, 5126900, 7158100],
[17048, 5126900, 7158100],
[17049, 5126900, 7158100],
+ [51081, 5155300],
+ [51080, 5155302],
+ [31109, 5155301],
+ [51082, 5155301],
+ [51083, 5155301],
+ [16004, 5155301]
];
// This map is generated based on the contents of 'equivalentErrorCodesList'. This map should _not_
diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript
index 4798734ad48..bef7254de27 100644
--- a/src/mongo/db/exec/sbe/SConscript
+++ b/src/mongo/db/exec/sbe/SConscript
@@ -198,6 +198,7 @@ env.CppUnitTest(
'expressions/sbe_regex_test.cpp',
'expressions/sbe_replace_one_expression_test.cpp',
'expressions/sbe_reverse_array_builtin_test.cpp',
+ 'expressions/sbe_round_builtin_test.cpp',
'expressions/sbe_runtime_environment_test.cpp',
'expressions/sbe_set_expressions_test.cpp',
'expressions/sbe_shard_filter_builtin_test.cpp',
diff --git a/src/mongo/db/exec/sbe/expressions/expression.cpp b/src/mongo/db/exec/sbe/expressions/expression.cpp
index 6a8dacec67a..a6a1e2ae524 100644
--- a/src/mongo/db/exec/sbe/expressions/expression.cpp
+++ b/src/mongo/db/exec/sbe/expressions/expression.cpp
@@ -682,7 +682,7 @@ static stdx::unordered_map<std::string, BuiltinFn> kBuiltinFunctions = {
{"sinh", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::sinh, false}},
{"tan", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::tan, false}},
{"tanh", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::tanh, false}},
- {"round", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::round, false}},
+ {"round", BuiltinFn{[](size_t n) { return n == 1 || n == 2; }, vm::Builtin::round, false}},
{"concat", BuiltinFn{kAnyNumberOfArgs, vm::Builtin::concat, false}},
{"concatArrays", BuiltinFn{kAnyNumberOfArgs, vm::Builtin::concatArrays, false}},
{"aggConcatArraysCapped",
diff --git a/src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp b/src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp
new file mode 100644
index 00000000000..823dcdf7c9a
--- /dev/null
+++ b/src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp
@@ -0,0 +1,198 @@
+/**
+ * Copyright (C) 2021-present MongoDB, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the Server Side Public License, version 1,
+ * as published by MongoDB, Inc.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * Server Side Public License for more details.
+ *
+ * You should have received a copy of the Server Side Public License
+ * along with this program. If not, see
+ * <http://www.mongodb.com/licensing/server-side-public-license>.
+ *
+ * As a special exception, the copyright holders give permission to link the
+ * code of portions of this program with the OpenSSL library under certain
+ * conditions as described in each individual source file and distribute
+ * linked combinations including the program with the OpenSSL library. You
+ * must comply with the Server Side Public License in all respects for
+ * all of the code used other than as permitted herein. If you modify file(s)
+ * with this exception, you may extend this exception to your version of the
+ * file(s), but you are not obligated to do so. If you do not wish to do so,
+ * delete this exception statement from your version. If you delete this
+ * exception statement from all source files in the program, then also delete
+ * it in the license file.
+ */
+
+#include "mongo/base/string_data.h"
+#include "mongo/db/exec/sbe/expression_test_base.h"
+#include "mongo/db/exec/sbe/values/value.h"
+#include "mongo/unittest/assert.h"
+
+namespace mongo::sbe {
+namespace {
+using namespace value;
+
+static std::pair<TypeTags, Value> makeDecimal(const std::string& n) {
+ return makeCopyDecimal(Decimal128(n));
+}
+
+const std::pair<TypeTags, Value> kNull{TypeTags::Null, 0};
+const std::pair<TypeTags, Value> kNothing{TypeTags::Nothing, 0};
+
+/**
+ * A test for SBE built-in function round with one argument. The "place" argument defaults to 0.
+ */
+TEST_F(EExpressionTestFixture, RoundOneArg) {
+ OwnedValueAccessor numAccessor;
+ auto numSlot = bindAccessor(&numAccessor);
+
+ // Construct an invocation of round function.
+ auto roundExpression =
+ sbe::makeE<sbe::EFunction>("round", sbe::makeEs(makeE<EVariable>(numSlot)));
+ auto compiledRound = compileExpression(*roundExpression);
+
+ struct TestCase {
+ std::pair<TypeTags, Value> num;
+ std::pair<TypeTags, Value> result;
+ };
+
+ std::vector<TestCase> testCases = {
+ {makeInt32(0), makeInt32(0)},
+ {makeInt32(2), makeInt32(2)},
+ {makeInt64(4), makeInt64(4)},
+ {makeDouble(-2.0), makeDouble(-2.0)},
+ {makeDouble(0.9), makeDouble(1.0)},
+ {makeDouble(-1.2), makeDouble(-1.0)},
+ {makeDecimal("-1.6"), makeDecimal("-2")},
+ {makeDouble(1.298), makeDouble(1.0)},
+ {makeDecimal("1.298"), makeDecimal("1")},
+ // Infinity cases.
+ {makeDouble(std::numeric_limits<double>::infinity()),
+ makeDouble(std::numeric_limits<double>::infinity())},
+ // Decimal128 infinity.
+ {makeDecimal("Infinity"), makeDecimal("Infinity")},
+ {makeDecimal("-Infinity"), makeDecimal("-Infinity")},
+ // NaN cases.
+ {makeDouble(std::numeric_limits<double>::quiet_NaN()),
+ makeDouble(std::numeric_limits<double>::quiet_NaN())},
+ {makeDecimal("NaN"), makeDecimal("NaN")},
+ // Null case.
+ {kNull, kNothing},
+ // Nothing case.
+ {kNothing, kNothing},
+ };
+
+ int testNum = 0;
+ for (auto&& testCase : testCases) {
+ numAccessor.reset(testCase.num.first, testCase.num.second);
+ ValueGuard expectedResultGuard(testCase.result.first, testCase.result.second);
+
+ // Execute the round function.
+ auto [resultTag, resultValue] = runCompiledExpression(compiledRound.get());
+ ValueGuard actualResultGuard(resultTag, resultValue);
+
+ auto [compTag, compVal] =
+ compareValue(resultTag, resultValue, testCase.result.first, testCase.result.second);
+ ASSERT_EQUALS(compTag, TypeTags::NumberInt32) << "unexpected tag for test " << testNum;
+ ASSERT_EQUALS(compVal, bitcastFrom<int32_t>(0)) << "unexpected value for test " << testNum;
+
+ testNum++;
+ }
+}
+
+/**
+ * A test for SBE built-in function round with two arguments.
+ */
+TEST_F(EExpressionTestFixture, RoundTwoArgs) {
+ OwnedValueAccessor numAccessor;
+ auto numSlot = bindAccessor(&numAccessor);
+ OwnedValueAccessor placeAccessor;
+ auto placeSlot = bindAccessor(&placeAccessor);
+
+ // Construct an invocation of round function.
+ auto roundExpression = sbe::makeE<sbe::EFunction>(
+ "round", sbe::makeEs(makeE<EVariable>(numSlot), makeE<EVariable>(placeSlot)));
+ auto compiledRound = compileExpression(*roundExpression);
+
+ struct TestCase {
+ std::pair<TypeTags, Value> num;
+ std::pair<TypeTags, Value> place;
+ std::pair<TypeTags, Value> result;
+ };
+
+ std::vector<TestCase> testCases = {
+ {makeInt32(43), makeInt32(-1), makeInt32(40)},
+ // Try rounding with different types for the "place" argument.
+ {makeDouble(1.298), makeInt32(0), makeDouble(1.0)},
+ {makeDouble(1.298), makeInt64(0ull), makeDouble(1.0)},
+ {makeDouble(1.298), makeDouble(0.0), makeDouble(1.0)},
+ // Try rounding with a different value for the "place" argument.
+ {makeDouble(1.298), makeDouble(1.0), makeDouble(1.3)},
+ {makeDouble(23.298), makeDouble(-1.0), makeDouble(20.0)},
+ // Decimal tests.
+ {makeDecimal("1.298"), makeDouble(0.0), makeDecimal("1")},
+ {makeDecimal("1.298"), makeDouble(1.0), makeDecimal("1.3")},
+ {makeDecimal("23.298"), makeDouble(-1.0), makeDecimal("20.0")},
+ {makeDecimal("1.298912343250054252245154325"),
+ makeDouble(20.0),
+ makeDecimal("1.29891234325005425225")},
+ {makeDecimal("1.298"), makeDouble(100.0), makeDecimal("1.298")},
+ // Integer promotion case.
+ {makeInt32(2147483647), makeDouble(-1), makeInt64(2147483650)},
+ // Infinity cases.
+ {makeDouble(std::numeric_limits<double>::infinity()),
+ makeDouble(10),
+ makeDouble(std::numeric_limits<double>::infinity())},
+ {makeDouble(std::numeric_limits<double>::infinity()),
+ makeDouble(-10),
+ makeDouble(std::numeric_limits<double>::infinity())},
+ // Decimal128 infinity.
+ {makeDecimal("Infinity"), makeDouble(0), makeDecimal("Infinity")},
+ {makeDecimal("-Infinity"), makeDouble(0), makeDecimal("-Infinity")},
+ {makeDecimal("Infinity"), makeDouble(10), makeDecimal("Infinity")},
+ {makeDecimal("Infinity"), makeDouble(-10), makeDecimal("Infinity")},
+ // NaN cases.
+ {makeDouble(std::numeric_limits<double>::quiet_NaN()),
+ makeDouble(1),
+ makeDouble(std::numeric_limits<double>::quiet_NaN())},
+ {makeDouble(std::numeric_limits<double>::quiet_NaN()),
+ makeDouble(-2),
+ makeDouble(std::numeric_limits<double>::quiet_NaN())},
+ {makeDecimal("NaN"), makeDouble(0), makeDecimal("NaN")},
+ {makeDecimal("NaN"), makeDouble(2), makeDecimal("NaN")},
+ // Null cases.
+ {kNull, makeDouble(-5), kNothing},
+ {kNull, makeDouble(5), kNothing},
+ {makeDouble(1.1), kNull, kNothing},
+ // Nothing cases.
+ {kNothing, makeDouble(-5), kNothing},
+ {kNothing, makeDouble(5), kNothing},
+ {makeDouble(1.1), kNothing, kNothing},
+ // Try the limits of the "place" arg (-20 and 100).
+ {makeDouble(1.298), makeDouble(100), makeDouble(1.298)},
+ {makeDouble(1.298), makeDouble(-20), makeDouble(0)}};
+
+ int testNum = 0;
+ for (auto&& testCase : testCases) {
+ numAccessor.reset(testCase.num.first, testCase.num.second);
+ placeAccessor.reset(testCase.place.first, testCase.place.second);
+ ValueGuard expectedResultGuard(testCase.result.first, testCase.result.second);
+
+ // Execute the round function.
+ auto [resultTag, resultValue] = runCompiledExpression(compiledRound.get());
+ ValueGuard actualResultGuard(resultTag, resultValue);
+
+ auto [compTag, compVal] =
+ compareValue(resultTag, resultValue, testCase.result.first, testCase.result.second);
+ ASSERT_EQUALS(compTag, TypeTags::NumberInt32) << "unexpected tag for test " << testNum;
+ ASSERT_EQUALS(compVal, bitcastFrom<int32_t>(0)) << "unexpected value for test " << testNum;
+
+ testNum++;
+ }
+}
+} // namespace
+} // namespace mongo::sbe
diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp
index fd8a9561c2e..88a08c30c6e 100644
--- a/src/mongo/db/exec/sbe/vm/vm.cpp
+++ b/src/mongo/db/exec/sbe/vm/vm.cpp
@@ -3639,39 +3639,85 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinTanh(ArityType a
return genericTanh(operandTag, operandValue);
}
-FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinRound(ArityType arity) {
- invariant(arity == 1);
- auto [owned, tag, val] = getFromStack(0);
-
- // Round 'val' to the closest integer, with ties rounding to the closest even integer.
- // If 'val' is +Inf, -Inf, or NaN, this function will simply return 'val' as-is.
+/**
+ * Converts a number to int32 assuming the input fits the range. This is used for $round "place"
+ * argument, which is checked to be a whole number between -20 and 100, but could still be a
+ * non-int32 type.
+ */
+static int32_t convertNumericToInt32(const value::TypeTags tag, const value::Value val) {
switch (tag) {
- case value::TypeTags::NumberInt32:
- case value::TypeTags::NumberInt64:
- // The value is already an integer, so just return it as-is.
- return {false, tag, val};
+ case value::TypeTags::NumberInt32: {
+ return value::bitcastTo<int32_t>(val);
+ }
+ case value::TypeTags::NumberInt64: {
+ return static_cast<int32_t>(value::bitcastTo<int64_t>(val));
+ }
case value::TypeTags::NumberDouble: {
- // std::nearbyint()'s behavior relies on a thread-local "rounding mode", so
- // we use boost::numeric::RoundEven<double>::nearbyint() instead. We should
- // switch over to roundeven() once it becomes available in the standard library.
- // (See https://en.cppreference.com/w/c/experimental/fpext1 for details.)
- auto operand = value::bitcastTo<double>(val);
- auto rounded = boost::numeric::RoundEven<double>::nearbyint(operand);
- return {false, tag, value::bitcastFrom<double>(rounded)};
+ return static_cast<int32_t>(value::bitcastTo<double>(val));
}
case value::TypeTags::NumberDecimal: {
- auto operand = value::bitcastTo<Decimal128>(val);
- auto rounded = operand.round(Decimal128::RoundingMode::kRoundTiesToEven);
- if (operand.isEqual(rounded)) {
- // If the output of rounding is equal to the input, then we can just take
- // ownership of 'operand' and return it. (This is more efficient than calling
- // makeCopyDecimal(), which would allocate memory on the heap.)
- topStack(false, value::TypeTags::Nothing, 0);
- return {owned, tag, val};
- }
+ Decimal128 dec = value::bitcastTo<Decimal128>(val);
+ return dec.toInt(Decimal128::kRoundTiesToEven);
+ }
+ default:
+ MONGO_UNREACHABLE;
+ }
+}
- auto [tag, val] = value::makeCopyDecimal(rounded);
- return {true, tag, val};
+FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinRound(ArityType arity) {
+ invariant(arity == 1 || arity == 2);
+ int32_t place = 0;
+ const auto [numOwn, numTag, numVal] = getFromStack(0);
+ if (arity == 2) {
+ const auto [placeOwn, placeTag, placeVal] = getFromStack(1);
+ if (!value::isNumber(placeTag)) {
+ return {false, value::TypeTags::Nothing, 0};
+ }
+ place = convertNumericToInt32(placeTag, placeVal);
+ }
+
+ // Construct 10^-precisionValue, which will be used as the quantize reference. This is passed to
+ // decimal.quantize() to indicate the precision of our rounding.
+ const auto quantum = Decimal128(0LL, Decimal128::kExponentBias - place, 0LL, 1LL);
+ switch (numTag) {
+ case value::TypeTags::NumberDecimal: {
+ auto dec = value::bitcastTo<Decimal128>(numVal);
+ if (!dec.isInfinite()) {
+ dec = dec.quantize(quantum, Decimal128::kRoundTiesToEven);
+ }
+ auto [resultTag, resultValue] = value::makeCopyDecimal(dec);
+ return {true, resultTag, resultValue};
+ }
+ case value::TypeTags::NumberDouble: {
+ auto asDec = Decimal128(value::bitcastTo<double>(numVal), Decimal128::kRoundTo34Digits);
+ if (!asDec.isInfinite()) {
+ asDec = asDec.quantize(quantum, Decimal128::kRoundTiesToEven);
+ }
+ return {
+ false, value::TypeTags::NumberDouble, value::bitcastFrom<double>(asDec.toDouble())};
+ }
+ case value::TypeTags::NumberInt32:
+ case value::TypeTags::NumberInt64: {
+ if (place >= 0) {
+ return {numOwn, numTag, numVal};
+ }
+ auto numericArgll = numTag == value::TypeTags::NumberInt32
+ ? static_cast<int64_t>(value::bitcastTo<int32_t>(numVal))
+ : value::bitcastTo<int64_t>(numVal);
+ auto out = Decimal128(numericArgll).quantize(quantum, Decimal128::kRoundTiesToEven);
+ uint32_t flags = 0;
+ auto outll = out.toLong(&flags);
+ uassert(5155302,
+ "Invalid conversion to long during $round.",
+ !Decimal128::hasFlag(flags, Decimal128::kInvalid));
+ if (numTag == value::TypeTags::NumberInt64 ||
+ outll > std::numeric_limits<int32_t>::max()) {
+ // Even if the original was an int to begin with - it has to be a long now.
+ return {false, value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(outll)};
+ }
+ return {false,
+ value::TypeTags::NumberInt32,
+ value::bitcastFrom<int32_t>(static_cast<int32_t>(outll))};
}
default:
return {false, value::TypeTags::Nothing, 0};
diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h
index 803102f3760..280417c10df 100644
--- a/src/mongo/db/pipeline/expression.h
+++ b/src/mongo/db/pipeline/expression.h
@@ -3129,9 +3129,7 @@ public:
class ExpressionRound final : public ExpressionRangedArity<ExpressionRound, 1, 2> {
public:
explicit ExpressionRound(ExpressionContext* const expCtx)
- : ExpressionRangedArity<ExpressionRound, 1, 2>(expCtx) {
- expCtx->sbeCompatible = false;
- }
+ : ExpressionRangedArity<ExpressionRound, 1, 2>(expCtx) {}
ExpressionRound(ExpressionContext* const expCtx, ExpressionVector&& children)
: ExpressionRangedArity<ExpressionRound, 1, 2>(expCtx, std::move(children)) {}
diff --git a/src/mongo/db/query/sbe_stage_builder_abt_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_abt_helpers.cpp
index a771269780a..345d6c53c9f 100644
--- a/src/mongo/db/query/sbe_stage_builder_abt_helpers.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_abt_helpers.cpp
@@ -195,17 +195,17 @@ optimizer::ABT generateABTNonStringCheck(optimizer::ProjectionName var) {
}
optimizer::ABT generateABTNonTimestampCheck(optimizer::ProjectionName var) {
- return makeNot(makeABTFunction("isTimestamp"_sd, makeVariable(var)));
+ return makeNot(makeABTFunction("isTimestamp"_sd, makeVariable(std::move(var))));
}
optimizer::ABT generateABTNegativeCheck(optimizer::ProjectionName var) {
return optimizer::make<optimizer::BinaryOp>(
- optimizer::Operations::Lt, makeVariable(var), optimizer::Constant::int32(0));
+ optimizer::Operations::Lt, makeVariable(std::move(var)), optimizer::Constant::int32(0));
}
optimizer::ABT generateABTNonPositiveCheck(optimizer::ProjectionName var) {
return optimizer::make<optimizer::BinaryOp>(
- optimizer::Operations::Lte, makeVariable(var), optimizer::Constant::int32(0));
+ optimizer::Operations::Lte, makeVariable(std::move(var)), optimizer::Constant::int32(0));
}
optimizer::ABT generateABTPositiveCheck(optimizer::ABT var) {
@@ -214,7 +214,7 @@ optimizer::ABT generateABTPositiveCheck(optimizer::ABT var) {
}
optimizer::ABT generateABTNonNumericCheck(optimizer::ProjectionName var) {
- return makeNot(makeABTFunction("isNumber"_sd, makeVariable(var)));
+ return makeNot(makeABTFunction("isNumber"_sd, makeVariable(std::move(var))));
}
optimizer::ABT generateABTLongLongMinCheck(optimizer::ProjectionName var) {
@@ -230,11 +230,11 @@ optimizer::ABT generateABTLongLongMinCheck(optimizer::ProjectionName var) {
}
optimizer::ABT generateABTNonArrayCheck(optimizer::ProjectionName var) {
- return makeNot(makeABTFunction("isArray"_sd, makeVariable(var)));
+ return makeNot(makeABTFunction("isArray"_sd, makeVariable(std::move(var))));
}
optimizer::ABT generateABTNonObjectCheck(optimizer::ProjectionName var) {
- return makeNot(makeABTFunction("isObject"_sd, makeVariable(var)));
+ return makeNot(makeABTFunction("isObject"_sd, makeVariable(std::move(var))));
}
optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::ProjectionName var) {
@@ -248,8 +248,34 @@ optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::Project
sbe::value::TypeTags::NumberInt32))))));
}
+static optimizer::ABT generateIsIntegralType(const optimizer::ProjectionName& var) {
+ return makeABTFunction("typeMatch"_sd,
+ makeVariable(var),
+ optimizer::Constant::int32(getBSONTypeMask(BSONType::NumberInt) |
+ getBSONTypeMask(BSONType::NumberLong)));
+}
+
+optimizer::ABT generateInvalidRoundPlaceArgCheck(const optimizer::ProjectionName& var) {
+ return makeBalancedBooleanOpTree(
+ optimizer::Operations::Or,
+ {
+ // We can perform our numerical test with trunc. trunc will return nothing if we pass a
+ // non-number to it. We return true if the comparison returns nothing, or if
+ // var != trunc(var), indicating this is not a whole number.
+ makeFillEmpty(
+ optimizer::make<optimizer::BinaryOp>(optimizer::Operations::Neq,
+ makeVariable(var),
+ makeABTFunction("trunc", makeVariable(var))),
+ true),
+ optimizer::make<optimizer::BinaryOp>(
+ optimizer::Operations::Lt, makeVariable(var), optimizer::Constant::int32(-20)),
+ optimizer::make<optimizer::BinaryOp>(
+ optimizer::Operations::Gt, makeVariable(var), optimizer::Constant::int32(100)),
+ });
+}
+
optimizer::ABT generateABTNaNCheck(optimizer::ProjectionName var) {
- return makeABTFunction("isNaN"_sd, makeVariable(var));
+ return makeABTFunction("isNaN"_sd, makeVariable(std::move(var)));
}
optimizer::ABT makeABTFail(ErrorCodes::Error error, StringData errorMessage) {
diff --git a/src/mongo/db/query/sbe_stage_builder_abt_helpers.h b/src/mongo/db/query/sbe_stage_builder_abt_helpers.h
index b002ba10e82..abe33d105dc 100644
--- a/src/mongo/db/query/sbe_stage_builder_abt_helpers.h
+++ b/src/mongo/db/query/sbe_stage_builder_abt_helpers.h
@@ -121,6 +121,11 @@ optimizer::ABT generateABTNonStringCheck(optimizer::ABT var);
optimizer::ABT generateABTNonTimestampCheck(optimizer::ProjectionName var);
optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::ProjectionName var);
/**
+ * Generates an ABT to check the given variable is a number between -20 and 100 inclusive, and is a
+ * whole number.
+ */
+optimizer::ABT generateInvalidRoundPlaceArgCheck(const optimizer::ProjectionName& var);
+/**
* Generates an ABT that checks if the input expression is NaN _assuming that_ it has
* already been verified to be numeric.
*/
diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp
index dd8ee3f4655..21aa8acbe65 100644
--- a/src/mongo/db/query/sbe_stage_builder_expression.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp
@@ -2630,7 +2630,42 @@ public:
unsupportedExpression(expr->getOpName());
}
void visit(const ExpressionRound* expr) final {
- unsupportedExpression(expr->getOpName());
+ invariant(expr->getChildren().size() == 1 || expr->getChildren().size() == 2);
+ const bool hasPlaceArg = expr->getChildren().size() == 2;
+ _context->ensureArity(expr->getChildren().size());
+
+ auto inputNumName = makeLocalVariableName(_context->state.frameId(), 0);
+ auto inputPlaceName = makeLocalVariableName(_context->state.frameId(), 0);
+
+ // We always need to validate the number parameter, since it will always exist.
+ std::vector<ABTCaseValuePair> inputValidationCases{
+ generateABTReturnNullIfNullOrMissing(makeVariable(inputNumName)),
+ ABTCaseValuePair{
+ generateABTNonNumericCheck(inputNumName),
+ makeABTFail(ErrorCodes::Error{5155300}, "$round only supports numeric types")}};
+ // Only add these cases if we have a "place" argument.
+ if (hasPlaceArg) {
+ inputValidationCases.emplace_back(
+ generateABTReturnNullIfNullOrMissing(makeVariable(inputPlaceName)));
+ inputValidationCases.emplace_back(
+ generateInvalidRoundPlaceArgCheck(inputPlaceName),
+ makeABTFail(ErrorCodes::Error{5155301},
+ "$round requires \"place\" argument to be "
+ "an integer between -20 and 100"));
+ }
+
+ auto roundExpr = buildABTMultiBranchConditionalFromCaseValuePairs(
+ std::move(inputValidationCases),
+ makeABTFunction("round"_sd, makeVariable(inputNumName), makeVariable(inputPlaceName)));
+
+ // "place" argument defaults to 0.
+ auto placeABT = hasPlaceArg ? _context->popABTExpr() : optimizer::Constant::int32(0);
+ auto inputABT = _context->popABTExpr();
+ pushABT(optimizer::make<optimizer::Let>(
+ std::move(inputNumName),
+ std::move(inputABT),
+ optimizer::make<optimizer::Let>(
+ std::move(inputPlaceName), std::move(placeABT), std::move(roundExpr))));
}
void visit(const ExpressionSplit* expr) final {
invariant(expr->getChildren().size() == 2);