diff options
-rw-r--r-- | jstests/aggregation/expressions/round_trunc.js | 146 | ||||
-rw-r--r-- | jstests/libs/sbe_assert_error_override.js | 6 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/SConscript | 1 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/expressions/expression.cpp | 2 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp | 198 | ||||
-rw-r--r-- | src/mongo/db/exec/sbe/vm/vm.cpp | 102 | ||||
-rw-r--r-- | src/mongo/db/pipeline/expression.h | 4 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_abt_helpers.cpp | 40 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_abt_helpers.h | 5 | ||||
-rw-r--r-- | src/mongo/db/query/sbe_stage_builder_expression.cpp | 37 |
10 files changed, 440 insertions, 101 deletions
diff --git a/jstests/aggregation/expressions/round_trunc.js b/jstests/aggregation/expressions/round_trunc.js index 1e03eb47e68..97e20cd222f 100644 --- a/jstests/aggregation/expressions/round_trunc.js +++ b/jstests/aggregation/expressions/round_trunc.js @@ -5,66 +5,81 @@ // For assertErrorCode. load("jstests/aggregation/extras/utils.js"); +load("jstests/libs/sbe_assert_error_override.js"); // Override error-code-checking APIs. -var coll = db.server19548; +const coll = db.server19548; coll.drop(); -// Seed collection so that the pipeline will execute. -assert.commandWorked(coll.insert({})); // Helper for testing that op returns expResult. -function testOp(op, expResult) { - var pipeline = [{$project: {_id: 0, result: op}}]; +function testOp(exprName, value, expResult, place) { + coll.drop(); + assert.commandWorked(coll.insert({a: value})); + const project = place === undefined ? {[exprName]: "$a"} : {[exprName]: ["$a", place]}; + const pipeline = [{$project: {_id: 0, result: project}}]; assert.eq(coll.aggregate(pipeline).toArray(), [{result: expResult}]); } +function testRound(value, expResult, place) { + testOp("$round", value, expResult, place); +} + +function testTrunc(value, expResult, place) { + testOp("$trunc", value, expResult, place); +} + // Test $trunc and $round with one argument. -testOp({$trunc: NumberLong(4)}, NumberLong(4)); -testOp({$trunc: NaN}, NaN); -testOp({$trunc: Infinity}, Infinity); -testOp({$trunc: -Infinity}, -Infinity); -testOp({$trunc: null}, null); -testOp({$trunc: -2.0}, -2.0); -testOp({$trunc: 0.9}, 0.0); -testOp({$trunc: -1.2}, -1.0); -testOp({$trunc: NumberDecimal("-1.6")}, NumberDecimal("-1")); - -testOp({$round: NumberLong(4)}, NumberLong(4)); -testOp({$round: NaN}, NaN); -testOp({$round: Infinity}, Infinity); -testOp({$round: -Infinity}, -Infinity); -testOp({$round: null}, null); -testOp({$round: -2.0}, -2.0); -testOp({$round: 0.9}, 1.0); -testOp({$round: -1.2}, -1.0); -testOp({$round: NumberDecimal("-1.6")}, NumberDecimal("-2")); +testTrunc(NumberLong(4), NumberLong(4)); +testTrunc(NumberLong(4), NumberLong(4)); +testTrunc(NaN, NaN); +testTrunc(Infinity, Infinity); +testTrunc(-Infinity, -Infinity); +testTrunc(null, null); +testTrunc(-2.0, -2.0); +testTrunc(0.9, 0.0); +testTrunc(-1.2, -1.0); +testTrunc(NumberDecimal("-1.6"), NumberDecimal("-1")); + +testRound(NumberLong(4), NumberLong(4)); +testRound(NaN, NaN); +testRound(Infinity, Infinity); +testRound(-Infinity, -Infinity); +testRound(null, null); +testRound(-2.0, -2.0); +testRound(0.9, 1.0); +testRound(-1.2, -1.0); +testRound(NumberDecimal("-1.6"), NumberDecimal("-2")); // Test $trunc and $round with two arguments. -testOp({$trunc: [1.298, 0]}, 1); -testOp({$trunc: [1.298, 1]}, 1.2); -testOp({$trunc: [23.298, -1]}, 20); -testOp({$trunc: [NumberDecimal("1.298"), 0]}, NumberDecimal("1")); -testOp({$trunc: [NumberDecimal("1.298"), 1]}, NumberDecimal("1.2")); -testOp({$trunc: [NumberDecimal("23.298"), -1]}, NumberDecimal("2E+1")); -testOp({$trunc: [1.298, 100]}, 1.298); -testOp({$trunc: [NumberDecimal("1.298912343250054252245154325"), NumberLong("20")]}, - NumberDecimal("1.29891234325005425224")); -testOp({$trunc: [NumberDecimal("1.298"), NumberDecimal("100")]}, - NumberDecimal("1.298000000000000000000000000000000")); - -testOp({$round: [1.298, 0]}, 1); -testOp({$round: [1.298, 1]}, 1.3); -testOp({$round: [23.298, -1]}, 20); -testOp({$round: [NumberDecimal("1.298"), 0]}, NumberDecimal("1")); -testOp({$round: [NumberDecimal("1.298"), 1]}, NumberDecimal("1.3")); -testOp({$round: [NumberDecimal("23.298"), -1]}, NumberDecimal("2E+1")); -testOp({$round: [1.298, 100]}, 1.298); -testOp({$round: [NumberDecimal("1.298912343250054252245154325"), NumberLong("20")]}, - NumberDecimal("1.29891234325005425225")); -testOp({$round: [NumberDecimal("1.298"), NumberDecimal("100")]}, - NumberDecimal("1.298000000000000000000000000000000")); +testTrunc(1.298, 1, 0); +testTrunc(1.298, 1.2, 1); +testTrunc(23.298, 20, -1); +testTrunc(NumberDecimal("1.298"), NumberDecimal("1"), 0); +testTrunc(NumberDecimal("1.298"), NumberDecimal("1.2"), 1); +testTrunc(NumberDecimal("23.298"), NumberDecimal("2E+1"), -1); +testTrunc(1.298, 1.298, 100); +testTrunc(NumberDecimal("1.298912343250054252245154325"), + NumberDecimal("1.29891234325005425224"), + NumberLong("20")); +testTrunc(NumberDecimal("1.298"), + NumberDecimal("1.298000000000000000000000000000000"), + NumberDecimal("100")); + +testRound(1.298, 1, 0); +testRound(1.298, 1.3, 1); +testRound(23.298, 20, -1); +testRound(NumberDecimal("1.298"), NumberDecimal("1"), 0); +testRound(NumberDecimal("1.298"), NumberDecimal("1.3"), 1); +testRound(NumberDecimal("23.298"), NumberDecimal("2E+1"), -1); +testRound(1.298, 1.298, 100); +testRound(NumberDecimal("1.298912343250054252245154325"), + NumberDecimal("1.29891234325005425225"), + NumberLong("20")); +testRound(NumberDecimal("1.298"), + NumberDecimal("1.298000000000000000000000000000000"), + NumberDecimal("100")); // Test $round overflow. -testOp({$round: [NumberInt("2147483647"), -1]}, NumberLong("2147483650")); +testRound(NumberInt("2147483647"), NumberLong("2147483650"), -1); assertErrorCode(coll, [{$project: {a: {$round: [NumberLong("9223372036854775806"), -1]}}}], 51080); // Test $trunc and $round with more than 2 arguments. @@ -76,21 +91,24 @@ assertErrorCode(coll, [{$project: {a: {$round: "string"}}}], 51081); assertErrorCode(coll, [{$project: {a: {$trunc: "string"}}}], 51081); // Test NaN and Infinity numeric args. -testOp({$round: [Infinity, 0]}, Infinity); -testOp({$round: [-Infinity, 0]}, -Infinity); -testOp({$round: [NaN, 0]}, NaN); -testOp({$round: [NumberDecimal("Infinity"), 0]}, NumberDecimal("Infinity")); -testOp({$round: [NumberDecimal("-Infinity"), 0]}, NumberDecimal("-Infinity")); -testOp({$round: [NumberDecimal("NaN"), 0]}, NumberDecimal("NaN")); - -testOp({$trunc: [Infinity, 0]}, Infinity); -testOp({$trunc: [-Infinity, 0]}, -Infinity); -testOp({$trunc: [NaN, 0]}, NaN); -testOp({$trunc: [NumberDecimal("Infinity"), 0]}, NumberDecimal("Infinity")); -testOp({$trunc: [NumberDecimal("-Infinity"), 0]}, NumberDecimal("-Infinity")); -testOp({$trunc: [NumberDecimal("NaN"), 0]}, NumberDecimal("NaN")); +testRound(Infinity, Infinity, 0); +testRound(-Infinity, -Infinity, 0); +testRound(NaN, NaN, 0); +testRound(NumberDecimal("Infinity"), NumberDecimal("Infinity"), 0); +testRound(NumberDecimal("-Infinity"), NumberDecimal("-Infinity"), 0); +testRound(NumberDecimal("NaN"), NumberDecimal("NaN"), 0); +testRound(null, null, 1); +testRound(1, null, null); + +testTrunc(Infinity, Infinity, 0); +testTrunc(-Infinity, -Infinity, 0); +testTrunc(NaN, NaN, 0); +testTrunc(NumberDecimal("Infinity"), NumberDecimal("Infinity"), 0); +testTrunc(NumberDecimal("-Infinity"), NumberDecimal("-Infinity"), 0); +testTrunc(NumberDecimal("NaN"), NumberDecimal("NaN"), 0); // Test precision arguments that are out of bounds. +assert.commandWorked(coll.insert({})); assertErrorCode(coll, [{$project: {a: {$round: [1, NumberLong("101")]}}}], 51083); assertErrorCode(coll, [{$project: {a: {$round: [1, NumberLong("-21")]}}}], 51083); assertErrorCode(coll, [{$project: {a: {$round: [1, NumberDecimal("101")]}}}], 51083); @@ -111,4 +129,10 @@ assertErrorCode(coll, [{$project: {a: {$trunc: [1, -21]}}}], 51083); // Test non-integral precision arguments. assertErrorCode(coll, [{$project: {a: {$round: [1, NumberDecimal("1.4")]}}}], 51082); assertErrorCode(coll, [{$project: {a: {$trunc: [1, 10.5]}}}], 51082); +assertErrorCode(coll, [{$project: {a: {$round: [0, NaN]}}}], 31109); +assertErrorCode(coll, [{$project: {a: {$round: [0, NumberDecimal("NaN")]}}}], 51082); +assertErrorCode(coll, [{$project: {a: {$round: [BinData(0, ""), 0]}}}], 51081); +assertErrorCode(coll, [{$project: {a: {$round: [0, BinData(0, "")]}}}], 16004); +assertErrorCode(coll, [{$project: {a: {$round: MinKey}}}], 51081); +assertErrorCode(coll, [{$project: {a: {$round: MaxKey}}}], 51081); }()); diff --git a/jstests/libs/sbe_assert_error_override.js b/jstests/libs/sbe_assert_error_override.js index eafabc51ccf..1cc84cae6b3 100644 --- a/jstests/libs/sbe_assert_error_override.js +++ b/jstests/libs/sbe_assert_error_override.js @@ -138,6 +138,12 @@ const equivalentErrorCodesList = [ [17047, 5126900, 7158100], [17048, 5126900, 7158100], [17049, 5126900, 7158100], + [51081, 5155300], + [51080, 5155302], + [31109, 5155301], + [51082, 5155301], + [51083, 5155301], + [16004, 5155301] ]; // This map is generated based on the contents of 'equivalentErrorCodesList'. This map should _not_ diff --git a/src/mongo/db/exec/sbe/SConscript b/src/mongo/db/exec/sbe/SConscript index 4798734ad48..bef7254de27 100644 --- a/src/mongo/db/exec/sbe/SConscript +++ b/src/mongo/db/exec/sbe/SConscript @@ -198,6 +198,7 @@ env.CppUnitTest( 'expressions/sbe_regex_test.cpp', 'expressions/sbe_replace_one_expression_test.cpp', 'expressions/sbe_reverse_array_builtin_test.cpp', + 'expressions/sbe_round_builtin_test.cpp', 'expressions/sbe_runtime_environment_test.cpp', 'expressions/sbe_set_expressions_test.cpp', 'expressions/sbe_shard_filter_builtin_test.cpp', diff --git a/src/mongo/db/exec/sbe/expressions/expression.cpp b/src/mongo/db/exec/sbe/expressions/expression.cpp index 6a8dacec67a..a6a1e2ae524 100644 --- a/src/mongo/db/exec/sbe/expressions/expression.cpp +++ b/src/mongo/db/exec/sbe/expressions/expression.cpp @@ -682,7 +682,7 @@ static stdx::unordered_map<std::string, BuiltinFn> kBuiltinFunctions = { {"sinh", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::sinh, false}}, {"tan", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::tan, false}}, {"tanh", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::tanh, false}}, - {"round", BuiltinFn{[](size_t n) { return n == 1; }, vm::Builtin::round, false}}, + {"round", BuiltinFn{[](size_t n) { return n == 1 || n == 2; }, vm::Builtin::round, false}}, {"concat", BuiltinFn{kAnyNumberOfArgs, vm::Builtin::concat, false}}, {"concatArrays", BuiltinFn{kAnyNumberOfArgs, vm::Builtin::concatArrays, false}}, {"aggConcatArraysCapped", diff --git a/src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp b/src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp new file mode 100644 index 00000000000..823dcdf7c9a --- /dev/null +++ b/src/mongo/db/exec/sbe/expressions/sbe_round_builtin_test.cpp @@ -0,0 +1,198 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/base/string_data.h" +#include "mongo/db/exec/sbe/expression_test_base.h" +#include "mongo/db/exec/sbe/values/value.h" +#include "mongo/unittest/assert.h" + +namespace mongo::sbe { +namespace { +using namespace value; + +static std::pair<TypeTags, Value> makeDecimal(const std::string& n) { + return makeCopyDecimal(Decimal128(n)); +} + +const std::pair<TypeTags, Value> kNull{TypeTags::Null, 0}; +const std::pair<TypeTags, Value> kNothing{TypeTags::Nothing, 0}; + +/** + * A test for SBE built-in function round with one argument. The "place" argument defaults to 0. + */ +TEST_F(EExpressionTestFixture, RoundOneArg) { + OwnedValueAccessor numAccessor; + auto numSlot = bindAccessor(&numAccessor); + + // Construct an invocation of round function. + auto roundExpression = + sbe::makeE<sbe::EFunction>("round", sbe::makeEs(makeE<EVariable>(numSlot))); + auto compiledRound = compileExpression(*roundExpression); + + struct TestCase { + std::pair<TypeTags, Value> num; + std::pair<TypeTags, Value> result; + }; + + std::vector<TestCase> testCases = { + {makeInt32(0), makeInt32(0)}, + {makeInt32(2), makeInt32(2)}, + {makeInt64(4), makeInt64(4)}, + {makeDouble(-2.0), makeDouble(-2.0)}, + {makeDouble(0.9), makeDouble(1.0)}, + {makeDouble(-1.2), makeDouble(-1.0)}, + {makeDecimal("-1.6"), makeDecimal("-2")}, + {makeDouble(1.298), makeDouble(1.0)}, + {makeDecimal("1.298"), makeDecimal("1")}, + // Infinity cases. + {makeDouble(std::numeric_limits<double>::infinity()), + makeDouble(std::numeric_limits<double>::infinity())}, + // Decimal128 infinity. + {makeDecimal("Infinity"), makeDecimal("Infinity")}, + {makeDecimal("-Infinity"), makeDecimal("-Infinity")}, + // NaN cases. + {makeDouble(std::numeric_limits<double>::quiet_NaN()), + makeDouble(std::numeric_limits<double>::quiet_NaN())}, + {makeDecimal("NaN"), makeDecimal("NaN")}, + // Null case. + {kNull, kNothing}, + // Nothing case. + {kNothing, kNothing}, + }; + + int testNum = 0; + for (auto&& testCase : testCases) { + numAccessor.reset(testCase.num.first, testCase.num.second); + ValueGuard expectedResultGuard(testCase.result.first, testCase.result.second); + + // Execute the round function. + auto [resultTag, resultValue] = runCompiledExpression(compiledRound.get()); + ValueGuard actualResultGuard(resultTag, resultValue); + + auto [compTag, compVal] = + compareValue(resultTag, resultValue, testCase.result.first, testCase.result.second); + ASSERT_EQUALS(compTag, TypeTags::NumberInt32) << "unexpected tag for test " << testNum; + ASSERT_EQUALS(compVal, bitcastFrom<int32_t>(0)) << "unexpected value for test " << testNum; + + testNum++; + } +} + +/** + * A test for SBE built-in function round with two arguments. + */ +TEST_F(EExpressionTestFixture, RoundTwoArgs) { + OwnedValueAccessor numAccessor; + auto numSlot = bindAccessor(&numAccessor); + OwnedValueAccessor placeAccessor; + auto placeSlot = bindAccessor(&placeAccessor); + + // Construct an invocation of round function. + auto roundExpression = sbe::makeE<sbe::EFunction>( + "round", sbe::makeEs(makeE<EVariable>(numSlot), makeE<EVariable>(placeSlot))); + auto compiledRound = compileExpression(*roundExpression); + + struct TestCase { + std::pair<TypeTags, Value> num; + std::pair<TypeTags, Value> place; + std::pair<TypeTags, Value> result; + }; + + std::vector<TestCase> testCases = { + {makeInt32(43), makeInt32(-1), makeInt32(40)}, + // Try rounding with different types for the "place" argument. + {makeDouble(1.298), makeInt32(0), makeDouble(1.0)}, + {makeDouble(1.298), makeInt64(0ull), makeDouble(1.0)}, + {makeDouble(1.298), makeDouble(0.0), makeDouble(1.0)}, + // Try rounding with a different value for the "place" argument. + {makeDouble(1.298), makeDouble(1.0), makeDouble(1.3)}, + {makeDouble(23.298), makeDouble(-1.0), makeDouble(20.0)}, + // Decimal tests. + {makeDecimal("1.298"), makeDouble(0.0), makeDecimal("1")}, + {makeDecimal("1.298"), makeDouble(1.0), makeDecimal("1.3")}, + {makeDecimal("23.298"), makeDouble(-1.0), makeDecimal("20.0")}, + {makeDecimal("1.298912343250054252245154325"), + makeDouble(20.0), + makeDecimal("1.29891234325005425225")}, + {makeDecimal("1.298"), makeDouble(100.0), makeDecimal("1.298")}, + // Integer promotion case. + {makeInt32(2147483647), makeDouble(-1), makeInt64(2147483650)}, + // Infinity cases. + {makeDouble(std::numeric_limits<double>::infinity()), + makeDouble(10), + makeDouble(std::numeric_limits<double>::infinity())}, + {makeDouble(std::numeric_limits<double>::infinity()), + makeDouble(-10), + makeDouble(std::numeric_limits<double>::infinity())}, + // Decimal128 infinity. + {makeDecimal("Infinity"), makeDouble(0), makeDecimal("Infinity")}, + {makeDecimal("-Infinity"), makeDouble(0), makeDecimal("-Infinity")}, + {makeDecimal("Infinity"), makeDouble(10), makeDecimal("Infinity")}, + {makeDecimal("Infinity"), makeDouble(-10), makeDecimal("Infinity")}, + // NaN cases. + {makeDouble(std::numeric_limits<double>::quiet_NaN()), + makeDouble(1), + makeDouble(std::numeric_limits<double>::quiet_NaN())}, + {makeDouble(std::numeric_limits<double>::quiet_NaN()), + makeDouble(-2), + makeDouble(std::numeric_limits<double>::quiet_NaN())}, + {makeDecimal("NaN"), makeDouble(0), makeDecimal("NaN")}, + {makeDecimal("NaN"), makeDouble(2), makeDecimal("NaN")}, + // Null cases. + {kNull, makeDouble(-5), kNothing}, + {kNull, makeDouble(5), kNothing}, + {makeDouble(1.1), kNull, kNothing}, + // Nothing cases. + {kNothing, makeDouble(-5), kNothing}, + {kNothing, makeDouble(5), kNothing}, + {makeDouble(1.1), kNothing, kNothing}, + // Try the limits of the "place" arg (-20 and 100). + {makeDouble(1.298), makeDouble(100), makeDouble(1.298)}, + {makeDouble(1.298), makeDouble(-20), makeDouble(0)}}; + + int testNum = 0; + for (auto&& testCase : testCases) { + numAccessor.reset(testCase.num.first, testCase.num.second); + placeAccessor.reset(testCase.place.first, testCase.place.second); + ValueGuard expectedResultGuard(testCase.result.first, testCase.result.second); + + // Execute the round function. + auto [resultTag, resultValue] = runCompiledExpression(compiledRound.get()); + ValueGuard actualResultGuard(resultTag, resultValue); + + auto [compTag, compVal] = + compareValue(resultTag, resultValue, testCase.result.first, testCase.result.second); + ASSERT_EQUALS(compTag, TypeTags::NumberInt32) << "unexpected tag for test " << testNum; + ASSERT_EQUALS(compVal, bitcastFrom<int32_t>(0)) << "unexpected value for test " << testNum; + + testNum++; + } +} +} // namespace +} // namespace mongo::sbe diff --git a/src/mongo/db/exec/sbe/vm/vm.cpp b/src/mongo/db/exec/sbe/vm/vm.cpp index fd8a9561c2e..88a08c30c6e 100644 --- a/src/mongo/db/exec/sbe/vm/vm.cpp +++ b/src/mongo/db/exec/sbe/vm/vm.cpp @@ -3639,39 +3639,85 @@ FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinTanh(ArityType a return genericTanh(operandTag, operandValue); } -FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinRound(ArityType arity) { - invariant(arity == 1); - auto [owned, tag, val] = getFromStack(0); - - // Round 'val' to the closest integer, with ties rounding to the closest even integer. - // If 'val' is +Inf, -Inf, or NaN, this function will simply return 'val' as-is. +/** + * Converts a number to int32 assuming the input fits the range. This is used for $round "place" + * argument, which is checked to be a whole number between -20 and 100, but could still be a + * non-int32 type. + */ +static int32_t convertNumericToInt32(const value::TypeTags tag, const value::Value val) { switch (tag) { - case value::TypeTags::NumberInt32: - case value::TypeTags::NumberInt64: - // The value is already an integer, so just return it as-is. - return {false, tag, val}; + case value::TypeTags::NumberInt32: { + return value::bitcastTo<int32_t>(val); + } + case value::TypeTags::NumberInt64: { + return static_cast<int32_t>(value::bitcastTo<int64_t>(val)); + } case value::TypeTags::NumberDouble: { - // std::nearbyint()'s behavior relies on a thread-local "rounding mode", so - // we use boost::numeric::RoundEven<double>::nearbyint() instead. We should - // switch over to roundeven() once it becomes available in the standard library. - // (See https://en.cppreference.com/w/c/experimental/fpext1 for details.) - auto operand = value::bitcastTo<double>(val); - auto rounded = boost::numeric::RoundEven<double>::nearbyint(operand); - return {false, tag, value::bitcastFrom<double>(rounded)}; + return static_cast<int32_t>(value::bitcastTo<double>(val)); } case value::TypeTags::NumberDecimal: { - auto operand = value::bitcastTo<Decimal128>(val); - auto rounded = operand.round(Decimal128::RoundingMode::kRoundTiesToEven); - if (operand.isEqual(rounded)) { - // If the output of rounding is equal to the input, then we can just take - // ownership of 'operand' and return it. (This is more efficient than calling - // makeCopyDecimal(), which would allocate memory on the heap.) - topStack(false, value::TypeTags::Nothing, 0); - return {owned, tag, val}; - } + Decimal128 dec = value::bitcastTo<Decimal128>(val); + return dec.toInt(Decimal128::kRoundTiesToEven); + } + default: + MONGO_UNREACHABLE; + } +} - auto [tag, val] = value::makeCopyDecimal(rounded); - return {true, tag, val}; +FastTuple<bool, value::TypeTags, value::Value> ByteCode::builtinRound(ArityType arity) { + invariant(arity == 1 || arity == 2); + int32_t place = 0; + const auto [numOwn, numTag, numVal] = getFromStack(0); + if (arity == 2) { + const auto [placeOwn, placeTag, placeVal] = getFromStack(1); + if (!value::isNumber(placeTag)) { + return {false, value::TypeTags::Nothing, 0}; + } + place = convertNumericToInt32(placeTag, placeVal); + } + + // Construct 10^-precisionValue, which will be used as the quantize reference. This is passed to + // decimal.quantize() to indicate the precision of our rounding. + const auto quantum = Decimal128(0LL, Decimal128::kExponentBias - place, 0LL, 1LL); + switch (numTag) { + case value::TypeTags::NumberDecimal: { + auto dec = value::bitcastTo<Decimal128>(numVal); + if (!dec.isInfinite()) { + dec = dec.quantize(quantum, Decimal128::kRoundTiesToEven); + } + auto [resultTag, resultValue] = value::makeCopyDecimal(dec); + return {true, resultTag, resultValue}; + } + case value::TypeTags::NumberDouble: { + auto asDec = Decimal128(value::bitcastTo<double>(numVal), Decimal128::kRoundTo34Digits); + if (!asDec.isInfinite()) { + asDec = asDec.quantize(quantum, Decimal128::kRoundTiesToEven); + } + return { + false, value::TypeTags::NumberDouble, value::bitcastFrom<double>(asDec.toDouble())}; + } + case value::TypeTags::NumberInt32: + case value::TypeTags::NumberInt64: { + if (place >= 0) { + return {numOwn, numTag, numVal}; + } + auto numericArgll = numTag == value::TypeTags::NumberInt32 + ? static_cast<int64_t>(value::bitcastTo<int32_t>(numVal)) + : value::bitcastTo<int64_t>(numVal); + auto out = Decimal128(numericArgll).quantize(quantum, Decimal128::kRoundTiesToEven); + uint32_t flags = 0; + auto outll = out.toLong(&flags); + uassert(5155302, + "Invalid conversion to long during $round.", + !Decimal128::hasFlag(flags, Decimal128::kInvalid)); + if (numTag == value::TypeTags::NumberInt64 || + outll > std::numeric_limits<int32_t>::max()) { + // Even if the original was an int to begin with - it has to be a long now. + return {false, value::TypeTags::NumberInt64, value::bitcastFrom<int64_t>(outll)}; + } + return {false, + value::TypeTags::NumberInt32, + value::bitcastFrom<int32_t>(static_cast<int32_t>(outll))}; } default: return {false, value::TypeTags::Nothing, 0}; diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index 803102f3760..280417c10df 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -3129,9 +3129,7 @@ public: class ExpressionRound final : public ExpressionRangedArity<ExpressionRound, 1, 2> { public: explicit ExpressionRound(ExpressionContext* const expCtx) - : ExpressionRangedArity<ExpressionRound, 1, 2>(expCtx) { - expCtx->sbeCompatible = false; - } + : ExpressionRangedArity<ExpressionRound, 1, 2>(expCtx) {} ExpressionRound(ExpressionContext* const expCtx, ExpressionVector&& children) : ExpressionRangedArity<ExpressionRound, 1, 2>(expCtx, std::move(children)) {} diff --git a/src/mongo/db/query/sbe_stage_builder_abt_helpers.cpp b/src/mongo/db/query/sbe_stage_builder_abt_helpers.cpp index a771269780a..345d6c53c9f 100644 --- a/src/mongo/db/query/sbe_stage_builder_abt_helpers.cpp +++ b/src/mongo/db/query/sbe_stage_builder_abt_helpers.cpp @@ -195,17 +195,17 @@ optimizer::ABT generateABTNonStringCheck(optimizer::ProjectionName var) { } optimizer::ABT generateABTNonTimestampCheck(optimizer::ProjectionName var) { - return makeNot(makeABTFunction("isTimestamp"_sd, makeVariable(var))); + return makeNot(makeABTFunction("isTimestamp"_sd, makeVariable(std::move(var)))); } optimizer::ABT generateABTNegativeCheck(optimizer::ProjectionName var) { return optimizer::make<optimizer::BinaryOp>( - optimizer::Operations::Lt, makeVariable(var), optimizer::Constant::int32(0)); + optimizer::Operations::Lt, makeVariable(std::move(var)), optimizer::Constant::int32(0)); } optimizer::ABT generateABTNonPositiveCheck(optimizer::ProjectionName var) { return optimizer::make<optimizer::BinaryOp>( - optimizer::Operations::Lte, makeVariable(var), optimizer::Constant::int32(0)); + optimizer::Operations::Lte, makeVariable(std::move(var)), optimizer::Constant::int32(0)); } optimizer::ABT generateABTPositiveCheck(optimizer::ABT var) { @@ -214,7 +214,7 @@ optimizer::ABT generateABTPositiveCheck(optimizer::ABT var) { } optimizer::ABT generateABTNonNumericCheck(optimizer::ProjectionName var) { - return makeNot(makeABTFunction("isNumber"_sd, makeVariable(var))); + return makeNot(makeABTFunction("isNumber"_sd, makeVariable(std::move(var)))); } optimizer::ABT generateABTLongLongMinCheck(optimizer::ProjectionName var) { @@ -230,11 +230,11 @@ optimizer::ABT generateABTLongLongMinCheck(optimizer::ProjectionName var) { } optimizer::ABT generateABTNonArrayCheck(optimizer::ProjectionName var) { - return makeNot(makeABTFunction("isArray"_sd, makeVariable(var))); + return makeNot(makeABTFunction("isArray"_sd, makeVariable(std::move(var)))); } optimizer::ABT generateABTNonObjectCheck(optimizer::ProjectionName var) { - return makeNot(makeABTFunction("isObject"_sd, makeVariable(var))); + return makeNot(makeABTFunction("isObject"_sd, makeVariable(std::move(var)))); } optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::ProjectionName var) { @@ -248,8 +248,34 @@ optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::Project sbe::value::TypeTags::NumberInt32)))))); } +static optimizer::ABT generateIsIntegralType(const optimizer::ProjectionName& var) { + return makeABTFunction("typeMatch"_sd, + makeVariable(var), + optimizer::Constant::int32(getBSONTypeMask(BSONType::NumberInt) | + getBSONTypeMask(BSONType::NumberLong))); +} + +optimizer::ABT generateInvalidRoundPlaceArgCheck(const optimizer::ProjectionName& var) { + return makeBalancedBooleanOpTree( + optimizer::Operations::Or, + { + // We can perform our numerical test with trunc. trunc will return nothing if we pass a + // non-number to it. We return true if the comparison returns nothing, or if + // var != trunc(var), indicating this is not a whole number. + makeFillEmpty( + optimizer::make<optimizer::BinaryOp>(optimizer::Operations::Neq, + makeVariable(var), + makeABTFunction("trunc", makeVariable(var))), + true), + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::Lt, makeVariable(var), optimizer::Constant::int32(-20)), + optimizer::make<optimizer::BinaryOp>( + optimizer::Operations::Gt, makeVariable(var), optimizer::Constant::int32(100)), + }); +} + optimizer::ABT generateABTNaNCheck(optimizer::ProjectionName var) { - return makeABTFunction("isNaN"_sd, makeVariable(var)); + return makeABTFunction("isNaN"_sd, makeVariable(std::move(var))); } optimizer::ABT makeABTFail(ErrorCodes::Error error, StringData errorMessage) { diff --git a/src/mongo/db/query/sbe_stage_builder_abt_helpers.h b/src/mongo/db/query/sbe_stage_builder_abt_helpers.h index b002ba10e82..abe33d105dc 100644 --- a/src/mongo/db/query/sbe_stage_builder_abt_helpers.h +++ b/src/mongo/db/query/sbe_stage_builder_abt_helpers.h @@ -121,6 +121,11 @@ optimizer::ABT generateABTNonStringCheck(optimizer::ABT var); optimizer::ABT generateABTNonTimestampCheck(optimizer::ProjectionName var); optimizer::ABT generateABTNullishOrNotRepresentableInt32Check(optimizer::ProjectionName var); /** + * Generates an ABT to check the given variable is a number between -20 and 100 inclusive, and is a + * whole number. + */ +optimizer::ABT generateInvalidRoundPlaceArgCheck(const optimizer::ProjectionName& var); +/** * Generates an ABT that checks if the input expression is NaN _assuming that_ it has * already been verified to be numeric. */ diff --git a/src/mongo/db/query/sbe_stage_builder_expression.cpp b/src/mongo/db/query/sbe_stage_builder_expression.cpp index dd8ee3f4655..21aa8acbe65 100644 --- a/src/mongo/db/query/sbe_stage_builder_expression.cpp +++ b/src/mongo/db/query/sbe_stage_builder_expression.cpp @@ -2630,7 +2630,42 @@ public: unsupportedExpression(expr->getOpName()); } void visit(const ExpressionRound* expr) final { - unsupportedExpression(expr->getOpName()); + invariant(expr->getChildren().size() == 1 || expr->getChildren().size() == 2); + const bool hasPlaceArg = expr->getChildren().size() == 2; + _context->ensureArity(expr->getChildren().size()); + + auto inputNumName = makeLocalVariableName(_context->state.frameId(), 0); + auto inputPlaceName = makeLocalVariableName(_context->state.frameId(), 0); + + // We always need to validate the number parameter, since it will always exist. + std::vector<ABTCaseValuePair> inputValidationCases{ + generateABTReturnNullIfNullOrMissing(makeVariable(inputNumName)), + ABTCaseValuePair{ + generateABTNonNumericCheck(inputNumName), + makeABTFail(ErrorCodes::Error{5155300}, "$round only supports numeric types")}}; + // Only add these cases if we have a "place" argument. + if (hasPlaceArg) { + inputValidationCases.emplace_back( + generateABTReturnNullIfNullOrMissing(makeVariable(inputPlaceName))); + inputValidationCases.emplace_back( + generateInvalidRoundPlaceArgCheck(inputPlaceName), + makeABTFail(ErrorCodes::Error{5155301}, + "$round requires \"place\" argument to be " + "an integer between -20 and 100")); + } + + auto roundExpr = buildABTMultiBranchConditionalFromCaseValuePairs( + std::move(inputValidationCases), + makeABTFunction("round"_sd, makeVariable(inputNumName), makeVariable(inputPlaceName))); + + // "place" argument defaults to 0. + auto placeABT = hasPlaceArg ? _context->popABTExpr() : optimizer::Constant::int32(0); + auto inputABT = _context->popABTExpr(); + pushABT(optimizer::make<optimizer::Let>( + std::move(inputNumName), + std::move(inputABT), + optimizer::make<optimizer::Let>( + std::move(inputPlaceName), std::move(placeABT), std::move(roundExpr)))); } void visit(const ExpressionSplit* expr) final { invariant(expr->getChildren().size() == 2); |