diff options
author | David Percy <david.percy@mongodb.com> | 2021-02-16 23:00:38 +0000 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-03-04 22:11:37 +0000 |
commit | 8641dd510c5d2b6fe8f58f221ff4fe724f9267a7 (patch) | |
tree | 92edc79d46c313b55aec48a1c7b5b20c87a6fed3 | |
parent | fccad833c9081bf0cd61364afdb4ec01ceeb42fa (diff) | |
download | mongo-8641dd510c5d2b6fe8f58f221ff4fe724f9267a7.tar.gz |
SERVER-54233 Implement $derivative window-function executor
24 files changed, 827 insertions, 104 deletions
diff --git a/jstests/aggregation/bugs/server6239.js b/jstests/aggregation/bugs/server6239.js index a88bf7527d6..c20117651f8 100644 --- a/jstests/aggregation/bugs/server6239.js +++ b/jstests/aggregation/bugs/server6239.js @@ -22,7 +22,7 @@ function fail(expression, code) { test({$subtract: ['$date', '$date']}, NumberLong(0)); test({$subtract: ['$date', '$num']}, new Date(millis - num)); -fail({$subtract: ['$num', '$date']}, 16556); +fail({$subtract: ['$num', '$date']}, [16556, ErrorCodes.TypeMismatch]); fail({$add: ['$date', '$date']}, 16612); test({$add: ['$date', '$num']}, new Date(millis + num)); diff --git a/jstests/aggregation/bugs/server6240.js b/jstests/aggregation/bugs/server6240.js index 8a194e82a99..1a162b9e3da 100644 --- a/jstests/aggregation/bugs/server6240.js +++ b/jstests/aggregation/bugs/server6240.js @@ -28,13 +28,16 @@ db.s6240.save({date: new Date()}); assertErrorCode(db.s6240, {$project: {add: {$add: ["$date", "$date"]}}}, 16612); // Divide -assertErrorCode(db.s6240, {$project: {divide: {$divide: ["$date", 2]}}}, 16609); +assertErrorCode( + db.s6240, {$project: {divide: {$divide: ["$date", 2]}}}, [16609, ErrorCodes.TypeMismatch]); // Mod assertErrorCode(db.s6240, {$project: {mod: {$mod: ["$date", 2]}}}, 16611); // Multiply -assertErrorCode(db.s6240, {$project: {multiply: {$multiply: ["$date", 2]}}}, 16555); +assertErrorCode( + db.s6240, {$project: {multiply: {$multiply: ["$date", 2]}}}, [16555, ErrorCodes.TypeMismatch]); // Subtract -assertErrorCode(db.s6240, {$project: {subtract: {$subtract: [2, "$date"]}}}, 16556); +assertErrorCode( + db.s6240, {$project: {subtract: {$subtract: [2, "$date"]}}}, [16556, ErrorCodes.TypeMismatch]); diff --git a/jstests/aggregation/expressions/divide.js b/jstests/aggregation/expressions/divide.js index e89cd0e23b1..ddd404ec055 100644 --- a/jstests/aggregation/expressions/divide.js +++ b/jstests/aggregation/expressions/divide.js @@ -76,20 +76,20 @@ testCases.forEach(function(testCase) { // Test error codes on incorrect use of $divide. const errorTestCases = [ - {document: {left: 1, right: NumberInt(0)}, errorCode: 16608}, - {document: {left: 1, right: 0.0}, errorCode: 16608}, - {document: {left: 1, right: NumberLong("0")}, errorCode: 16608}, - {document: {left: 1, right: NumberDecimal("0")}, errorCode: 16608}, + {document: {left: 1, right: NumberInt(0)}, errorCodes: [16608, ErrorCodes.BadValue]}, + {document: {left: 1, right: 0.0}, errorCodes: [16608, ErrorCodes.BadValue]}, + {document: {left: 1, right: NumberLong("0")}, errorCodes: [16608, ErrorCodes.BadValue]}, + {document: {left: 1, right: NumberDecimal("0")}, errorCodes: [16608, ErrorCodes.BadValue]}, - {document: {left: 1, right: "not a number"}, errorCode: 16609}, - {document: {left: "not a number", right: 1}, errorCode: 16609}, + {document: {left: 1, right: "not a number"}, errorCodes: [16609, ErrorCodes.TypeMismatch]}, + {document: {left: "not a number", right: 1}, errorCodes: [16609, ErrorCodes.TypeMismatch]}, ]; errorTestCases.forEach(function(testCase) { assert.commandWorked(coll.insert(testCase.document)); assertErrorCode( - coll, {$project: {computed: {$divide: ["$left", "$right"]}}}, testCase.errorCode); + coll, {$project: {computed: {$divide: ["$left", "$right"]}}}, testCase.errorCodes); assert(coll.drop()); }); diff --git a/jstests/aggregation/expressions/multiply.js b/jstests/aggregation/expressions/multiply.js index 2859046bc73..9fa789c3cac 100644 --- a/jstests/aggregation/expressions/multiply.js +++ b/jstests/aggregation/expressions/multiply.js @@ -145,16 +145,16 @@ nAryTestCases.forEach(function(testCase) { // Test error codes on incorrect use of $multiply. const errorTestCases = [ - {document: {left: 1, right: "not a number"}, errorCode: 16555}, - {document: {left: "not a number", right: 1}, errorCode: 16555}, + {document: {left: 1, right: "not a number"}, errorCodes: [16555, ErrorCodes.TypeMismatch]}, + {document: {left: "not a number", right: 1}, errorCodes: [16555, ErrorCodes.TypeMismatch]}, ]; errorTestCases.forEach(function(testCase) { assert.commandWorked(coll.insert(testCase.document)); assertErrorCode( - coll, {$project: {computed: {$multiply: ["$left", "$right"]}}}, testCase.errorCode); + coll, {$project: {computed: {$multiply: ["$left", "$right"]}}}, testCase.errorCodes); assert(coll.drop()); }); -}());
\ No newline at end of file +}()); diff --git a/jstests/aggregation/ifnull.js b/jstests/aggregation/ifnull.js index 392629be9a1..032deaa48ab 100644 --- a/jstests/aggregation/ifnull.js +++ b/jstests/aggregation/ifnull.js @@ -66,7 +66,7 @@ assertResult(6, ['$missingField', {$multiply: ['$two', '$three']}]); assertResult(2, [{$add: ['$one', '$one']}, '$three', '$zero']); // Divide/mod by 0. -assertError([16608, 4848401], [{$divide: ['$one', '$zero']}, '$zero']); +assertError([ErrorCodes.BadValue, 16608, 4848401], [{$divide: ['$one', '$zero']}, '$zero']); assertError([16610, 4848403], [{$mod: ['$one', '$zero']}, '$zero']); // Return undefined. @@ -80,4 +80,4 @@ assertResult('foo', ['$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}]); assert.commandWorked(t.updateMany({}, {$set: {b: 'bar'}})); assertResult('bar', ['$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}]); assertResult('bar', ['$a', {$ifNull: ['$b', {$ifNull: ['$c', '$d']}]}, '$e']); -}());
\ No newline at end of file +}()); diff --git a/jstests/aggregation/sources/graphLookup/error.js b/jstests/aggregation/sources/graphLookup/error.js index 76245fe5ebb..c74fadde4ee 100644 --- a/jstests/aggregation/sources/graphLookup/error.js +++ b/jstests/aggregation/sources/graphLookup/error.js @@ -320,7 +320,7 @@ pipeline = { restrictSearchWithMatch: {$expr: {$divide: [1, "$x"]}} } }; -assertErrorCode(local, pipeline, 16608, "division by zero in $expr"); +assertErrorCode(local, pipeline, [16608, ErrorCodes.BadValue], "division by zero in $expr"); // $graphLookup can only consume at most 100MB of memory. foreign.drop(); diff --git a/jstests/core/expr.js b/jstests/core/expr.js index dabba0f922c..362e06c9a55 100644 --- a/jstests/core/expr.js +++ b/jstests/core/expr.js @@ -125,7 +125,7 @@ let explain = coll.find({$expr: {$divide: [1, "$a"]}}).explain("executionStats") if (!isMongos) { assert(explain.hasOwnProperty("executionStats"), explain); assert.eq(explain.executionStats.executionSuccess, false, explain); - assert.errorCodeEq(explain.executionStats.errorCode, 16609, explain); + assert.errorCodeEq(explain.executionStats.errorCode, [16609, ErrorCodes.TypeMismatch], explain); } // $expr is not allowed in $elemMatch projection. diff --git a/jstests/multiVersion/doc_validation_error_upgrade_downgrade.js b/jstests/multiVersion/doc_validation_error_upgrade_downgrade.js index a8367052df8..a420a7cb428 100644 --- a/jstests/multiVersion/doc_validation_error_upgrade_downgrade.js +++ b/jstests/multiVersion/doc_validation_error_upgrade_downgrade.js @@ -130,7 +130,7 @@ const exprValidator = { }; assert.commandWorked(targetDB.runCommand({collMod: collName, validator: exprValidator})); let exprResponse = targetDB[collName].insert({}); -assert.commandFailedWithCode(exprResponse, 16608, tojson(exprResponse)); +assert.commandFailedWithCode(exprResponse, [16608, ErrorCodes.BadValue], tojson(exprResponse)); // Verify that the insert succeeds when the validator expression throws if the validationAction is // set to 'warn' and the FCV is 4.4. @@ -207,4 +207,4 @@ st.upgradeCluster("last-lts", {upgradeShards: true, upgradeConfigs: true, upgrad testDocumentValidation(sourceDB, targetDB, assertFCV44DocumentValidationFailure); st.stop(); })(); -})();
\ No newline at end of file +})(); diff --git a/jstests/sharding/query/agg_error_reports_shard_host_and_port.js b/jstests/sharding/query/agg_error_reports_shard_host_and_port.js index 17b5546e78f..f6115215846 100644 --- a/jstests/sharding/query/agg_error_reports_shard_host_and_port.js +++ b/jstests/sharding/query/agg_error_reports_shard_host_and_port.js @@ -26,9 +26,9 @@ assert.commandWorked(coll.insert({_id: 0})); // sent to the shard before failing (i.e. "$_id") so that mongos doesn't short-curcuit and // fail during optimization. const pipe = [{$project: {a: {$divide: ["$_id", 0]}}}]; -const divideByZeroErrorCode = 16608; +const divideByZeroErrorCodes = [16608, ErrorCodes.BadValue]; -assertErrCodeAndErrMsgContains(coll, pipe, divideByZeroErrorCode, st.rs1.getPrimary().host); +assertErrCodeAndErrMsgContains(coll, pipe, divideByZeroErrorCodes, st.rs1.getPrimary().host); st.stop(); }()); diff --git a/jstests/sharding/query/sharded_agg_cleanup_on_error.js b/jstests/sharding/query/sharded_agg_cleanup_on_error.js index cb10cb2323c..9180ad042f2 100644 --- a/jstests/sharding/query/sharded_agg_cleanup_on_error.js +++ b/jstests/sharding/query/sharded_agg_cleanup_on_error.js @@ -21,7 +21,7 @@ const kFailpointOptions = { const st = new ShardingTest({shards: 2}); const kDBName = "test"; -const kDivideByZeroErrCode = 16608; +const kDivideByZeroErrCodes = [16608, ErrorCodes.BadValue]; const mongosDB = st.s.getDB(kDBName); const shard0DB = st.shard0.getDB(kDBName); const shard1DB = st.shard1.getDB(kDBName); @@ -89,7 +89,7 @@ try { {$project: {out: {$divide: ["$_id", 0]}}}, {$_internalSplitPipeline: {mergeType: "mongos"}} ], - errCode: kDivideByZeroErrCode + errCode: kDivideByZeroErrCodes }); // Repeat the test above, but this time use $_internalSplitPipeline to force the merge to @@ -99,7 +99,7 @@ try { {$project: {out: {$divide: ["$_id", 0]}}}, {$_internalSplitPipeline: {mergeType: "primaryShard"}} ], - errCode: kDivideByZeroErrCode + errCode: kDivideByZeroErrCodes }); } finally { assert.commandWorked(shard1DB.adminCommand({configureFailPoint: kFailPointName, mode: "off"})); diff --git a/src/mongo/db/matcher/doc_validation_error_test.cpp b/src/mongo/db/matcher/doc_validation_error_test.cpp index b34cda4d391..44eef6ac54f 100644 --- a/src/mongo/db/matcher/doc_validation_error_test.cpp +++ b/src/mongo/db/matcher/doc_validation_error_test.cpp @@ -953,8 +953,8 @@ TEST(MiscellaneousMatchExpression, ExprWhichThrowsGeneratesError) { "specifiedAs: {$expr: {$divide: [10, 0]}}," "reason: 'failed to evaluate aggregation expression'," "details: " - " {code: 16608, " - " codeName: 'Location16608', " + " {code: 2, " + " codeName: 'BadValue', " " errmsg: \"can't $divide by zero\"}}"); doc_validation_error::verifyGeneratedError(query, doc, expectedError, true /* shouldThrow */); } @@ -978,8 +978,8 @@ TEST(MiscellaneousMatchExpression, MultipleExprsWhichThrow) { " specifiedAs: {$expr: {$divide: [10, 0]}}," " reason: 'failed to evaluate aggregation expression'," " details: " - " {code: 16608, " - " codeName: 'Location16608', " + " {code: 2, " + " codeName: 'BadValue', " " errmsg: \"can't $divide by zero\"}}}]}"); doc_validation_error::verifyGeneratedError(query, doc, expectedError, true /* shouldThrow */); } @@ -995,8 +995,8 @@ TEST(MiscellaneousMatchExpression, OneExprThrowsAmongMultiple) { " specifiedAs: {$expr: {$divide: [10, 0]}}," " reason: 'failed to evaluate aggregation expression'," " details: " - " {code: 16608, " - " codeName: 'Location16608', " + " {code: 2, " + " codeName: 'BadValue', " " errmsg: \"can't $divide by zero\"}}}]}"); doc_validation_error::verifyGeneratedError(query, doc, expectedError, true /* shouldThrow */); } @@ -1020,8 +1020,8 @@ TEST(MiscellaneousMatchExpression, ExprsWhichThrowUnderInversion) { " specifiedAs: {$expr: {$divide: [10, 0]}}," " reason: 'failed to evaluate aggregation expression'," " details: " - " {code: 16608, " - " codeName: 'Location16608', " + " {code: 2, " + " codeName: 'BadValue', " " errmsg: \"can't $divide by zero\"}}}]}"); doc_validation_error::verifyGeneratedError(query, doc, expectedError, true /* shouldThrow */); } diff --git a/src/mongo/db/matcher/expression_expr_test.cpp b/src/mongo/db/matcher/expression_expr_test.cpp index 6f88314a6f5..5ff0cbf88e9 100644 --- a/src/mongo/db/matcher/expression_expr_test.cpp +++ b/src/mongo/db/matcher/expression_expr_test.cpp @@ -597,7 +597,7 @@ TEST_F(ExprMatchTest, FailGracefullyOnInvalidExpression) { TEST_F(ExprMatchTest, ReturnsFalseInsteadOfErrorWithFailpointSet) { createMatcher(fromjson("{$expr: {$divide: [10, '$divisor']}}")); - ASSERT_THROWS_CODE(matches(BSON("divisor" << 0)), AssertionException, 16608); + ASSERT_THROWS_CODE(matches(BSON("divisor" << 0)), AssertionException, ErrorCodes::BadValue); FailPointEnableBlock scopedFailpoint("ExprMatchExpressionMatchesReturnsFalseOnException"); createMatcher(fromjson("{$expr: {$divide: [10, '$divisor']}}")); diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index 4914a74c8aa..6ac3ef693bf 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -264,6 +264,7 @@ pipelineEnv.Library( 'tee_buffer.cpp', 'window_function/partition_iterator.cpp', 'window_function/window_function_exec.cpp', + 'window_function/window_function_exec_derivative.cpp', 'window_function/window_function_exec_removable_document.cpp', ], LIBDEPS=[ @@ -428,7 +429,8 @@ env.CppUnitTest( 'tee_buffer_test.cpp', 'window_function/partition_iterator_test.cpp', 'window_function/window_function_add_to_set_test.cpp', - 'window_function/window_function_exec_test.cpp', + 'window_function/window_function_exec_derivative_test.cpp', + 'window_function/window_function_exec_non_removable_test.cpp', 'window_function/window_function_min_max_test.cpp', 'window_function/window_function_push_test.cpp', ], diff --git a/src/mongo/db/pipeline/expression.cpp b/src/mongo/db/pipeline/expression.cpp index f0d5593a71c..0db59301a5c 100644 --- a/src/mongo/db/pipeline/expression.cpp +++ b/src/mongo/db/pipeline/expression.cpp @@ -2048,31 +2048,34 @@ void ExpressionDateDiff::_doAddDependencies(DepsTracker* deps) const { /* ----------------------- ExpressionDivide ---------------------------- */ Value ExpressionDivide::evaluate(const Document& root, Variables* variables) const { - Value lhs = _children[0]->evaluate(root, variables); - Value rhs = _children[1]->evaluate(root, variables); - - auto assertNonZero = [](bool nonZero) { uassert(16608, "can't $divide by zero", nonZero); }; + return uassertStatusOK( + apply(_children[0]->evaluate(root, variables), _children[1]->evaluate(root, variables))); +} +StatusWith<Value> ExpressionDivide::apply(Value lhs, Value rhs) { if (lhs.numeric() && rhs.numeric()) { // If, and only if, either side is decimal, return decimal. if (lhs.getType() == NumberDecimal || rhs.getType() == NumberDecimal) { Decimal128 numer = lhs.coerceToDecimal(); Decimal128 denom = rhs.coerceToDecimal(); - assertNonZero(!denom.isZero()); + if (denom.isZero()) + return Status(ErrorCodes::BadValue, "can't $divide by zero"); return Value(numer.divide(denom)); } double numer = lhs.coerceToDouble(); double denom = rhs.coerceToDouble(); - assertNonZero(denom != 0.0); + if (denom == 0.0) + return Status(ErrorCodes::BadValue, "can't $divide by zero"); return Value(numer / denom); } else if (lhs.nullish() || rhs.nullish()) { return Value(BSONNULL); } else { - uasserted(16609, - str::stream() << "$divide only supports numeric types, not " - << typeName(lhs.getType()) << " and " << typeName(rhs.getType())); + return Status(ErrorCodes::TypeMismatch, + str::stream() + << "$divide only supports numeric types, not " << typeName(lhs.getType()) + << " and " << typeName(rhs.getType())); } } @@ -2972,63 +2975,94 @@ const char* ExpressionMod::getOpName() const { /* ------------------------- ExpressionMultiply ----------------------------- */ -Value ExpressionMultiply::evaluate(const Document& root, Variables* variables) const { - /* - We'll try to return the narrowest possible result value. To do that - without creating intermediate Values, do the arithmetic for double - and integral types in parallel, tracking the current narrowest - type. +namespace { +class MultiplyState { + /** + * We'll try to return the narrowest possible result value. To do that without creating + * intermediate Values, do the arithmetic for double and integral types in parallel, tracking + * the current narrowest type. */ double doubleProduct = 1; long long longProduct = 1; Decimal128 decimalProduct; // This will be initialized on encountering the first decimal. - BSONType productType = NumberInt; - const size_t n = _children.size(); - for (size_t i = 0; i < n; ++i) { - Value val = _children[i]->evaluate(root, variables); - - if (val.numeric()) { - BSONType oldProductType = productType; - productType = Value::getWidestNumeric(productType, val.getType()); - if (productType == NumberDecimal) { - // On finding the first decimal, convert the partial product to decimal. - if (oldProductType != NumberDecimal) { - decimalProduct = oldProductType == NumberDouble - ? Decimal128(doubleProduct, Decimal128::kRoundTo15Digits) - : Decimal128(static_cast<int64_t>(longProduct)); - } - decimalProduct = decimalProduct.multiply(val.coerceToDecimal()); - } else { - doubleProduct *= val.coerceToDouble(); - - if (!std::isfinite(val.coerceToDouble()) || - overflow::mul(longProduct, val.coerceToLong(), &longProduct)) { - // The number is either Infinity or NaN, or the 'longProduct' would have - // overflowed, so we're abandoning it. - productType = NumberDouble; - } +public: + void operator*=(const Value& val) { + tassert(5423304, "MultiplyState::operator*= only supports numbers", val.numeric()); + + BSONType oldProductType = productType; + productType = Value::getWidestNumeric(productType, val.getType()); + if (productType == NumberDecimal) { + // On finding the first decimal, convert the partial product to decimal. + if (oldProductType != NumberDecimal) { + decimalProduct = oldProductType == NumberDouble + ? Decimal128(doubleProduct, Decimal128::kRoundTo15Digits) + : Decimal128(static_cast<int64_t>(longProduct)); } - } else if (val.nullish()) { - return Value(BSONNULL); + decimalProduct = decimalProduct.multiply(val.coerceToDecimal()); } else { - uasserted(16555, - str::stream() << "$multiply only supports numeric types, not " - << typeName(val.getType())); + doubleProduct *= val.coerceToDouble(); + + if (!std::isfinite(val.coerceToDouble()) || + overflow::mul(longProduct, val.coerceToLong(), &longProduct)) { + // The number is either Infinity or NaN, or the 'longProduct' would have + // overflowed, so we're abandoning it. + productType = NumberDouble; + } } } - if (productType == NumberDouble) - return Value(doubleProduct); - else if (productType == NumberLong) - return Value(longProduct); - else if (productType == NumberInt) - return Value::createIntOrLong(longProduct); - else if (productType == NumberDecimal) - return Value(decimalProduct); - else - massert(16418, "$multiply resulted in a non-numeric type", false); + Value getValue() const { + if (productType == NumberDouble) + return Value(doubleProduct); + else if (productType == NumberLong) + return Value(longProduct); + else if (productType == NumberInt) + return Value::createIntOrLong(longProduct); + else if (productType == NumberDecimal) + return Value(decimalProduct); + else + massert(16418, "$multiply resulted in a non-numeric type", false); + } +}; + +Status checkMultiplyNumeric(Value val) { + if (!val.numeric()) + return Status(ErrorCodes::TypeMismatch, + str::stream() << "$multiply only supports numeric types, not " + << typeName(val.getType())); + return Status::OK(); +} +} // namespace + +StatusWith<Value> ExpressionMultiply::apply(Value lhs, Value rhs) { + // evaluate() checks arguments left-to-right, short circuiting on the first null or non-number. + // Imitate that behavior here. + if (lhs.nullish()) + return Value(BSONNULL); + if (Status s = checkMultiplyNumeric(lhs); !s.isOK()) + return s; + if (rhs.nullish()) + return Value(BSONNULL); + if (Status s = checkMultiplyNumeric(rhs); !s.isOK()) + return s; + + MultiplyState state; + state *= lhs; + state *= rhs; + return state.getValue(); +} +Value ExpressionMultiply::evaluate(const Document& root, Variables* variables) const { + MultiplyState state; + for (auto&& child : _children) { + Value val = child->evaluate(root, variables); + if (val.nullish()) + return Value(BSONNULL); + uassertStatusOK(checkMultiplyNumeric(val)); + state *= child->evaluate(root, variables); + } + return state.getValue(); } REGISTER_EXPRESSION(multiply, ExpressionMultiply::parse); @@ -4914,9 +4948,11 @@ const char* ExpressionStrLenCP::getOpName() const { /* ----------------------- ExpressionSubtract ---------------------------- */ Value ExpressionSubtract::evaluate(const Document& root, Variables* variables) const { - Value lhs = _children[0]->evaluate(root, variables); - Value rhs = _children[1]->evaluate(root, variables); + return uassertStatusOK( + apply(_children[0]->evaluate(root, variables), _children[1]->evaluate(root, variables))); +} +StatusWith<Value> ExpressionSubtract::apply(Value lhs, Value rhs) { BSONType diffType = Value::getWidestNumeric(rhs.getType(), lhs.getType()); if (diffType == NumberDecimal) { @@ -4947,14 +4983,14 @@ Value ExpressionSubtract::evaluate(const Document& root, Variables* variables) c } else if (rhs.numeric()) { return Value(lhs.getDate() - Milliseconds(rhs.coerceToLong())); } else { - uasserted(16613, - str::stream() - << "cant $subtract a " << typeName(rhs.getType()) << " from a Date"); + return Status(ErrorCodes::TypeMismatch, + str::stream() + << "cant $subtract a " << typeName(rhs.getType()) << " from a Date"); } } else { - uasserted(16556, - str::stream() << "cant $subtract a" << typeName(rhs.getType()) << " from a " - << typeName(lhs.getType())); + return Status(ErrorCodes::TypeMismatch, + str::stream() << "cant $subtract a" << typeName(rhs.getType()) << " from a " + << typeName(lhs.getType())); } } diff --git a/src/mongo/db/pipeline/expression.h b/src/mongo/db/pipeline/expression.h index b8e6b3261bb..1026bcae7be 100644 --- a/src/mongo/db/pipeline/expression.h +++ b/src/mongo/db/pipeline/expression.h @@ -1453,6 +1453,16 @@ private: class ExpressionDivide final : public ExpressionFixedArity<ExpressionDivide, 2> { public: + /** + * Divides two values as if by {$divide: [{$const: numerator}, {$const: denominator]}. + * + * Returns BSONNULL if either argument is nullish. + * + * Returns ErrorCodes::TypeMismatch if either argument is non-nullish and non-numeric. + * Returns ErrorCodes::BadValue if the denominator is zero. + */ + static StatusWith<Value> apply(Value numerator, Value denominator); + explicit ExpressionDivide(ExpressionContext* const expCtx) : ExpressionFixedArity<ExpressionDivide, 2>(expCtx) {} explicit ExpressionDivide(ExpressionContext* const expCtx, ExpressionVector&& children) @@ -1985,6 +1995,19 @@ public: class ExpressionMultiply final : public ExpressionVariadic<ExpressionMultiply> { public: + /** + * Multiplies two values together as if by evaluate() on + * {$multiply: [{$const: lhs}, {$const: rhs}]}. + * + * Note that evaluate() does not use apply() directly, because when $muliply takes more than + * two arguments, it uses a wider intermediate state than Value. + * + * Returns BSONNULL if either argument is nullish. + * + * Returns ErrorCodes::TypeMismatch if any argument is non-nullish, non-numeric. + */ + static StatusWith<Value> apply(Value lhs, Value rhs); + explicit ExpressionMultiply(ExpressionContext* const expCtx) : ExpressionVariadic<ExpressionMultiply>(expCtx) {} ExpressionMultiply(ExpressionContext* const expCtx, ExpressionVector&& children) @@ -2631,6 +2654,20 @@ public: class ExpressionSubtract final : public ExpressionFixedArity<ExpressionSubtract, 2> { public: + /** + * Subtracts two values as if by {$subtract: [{$const: lhs}, {$const: rhs}]}. + * + * If either argument is nullish, returns BSONNULL. + * + * Otherwise, the arguments can be either: + * (numeric, numeric) + * (Date, Date) Returns the time difference in milliseconds. + * (Date, numeric) Returns the date shifted earlier by that many milliseconds. + * + * Otherwise, returns ErrorCodes::TypeMismatch. + */ + static StatusWith<Value> apply(Value lhs, Value rhs); + explicit ExpressionSubtract(ExpressionContext* const expCtx) : ExpressionFixedArity<ExpressionSubtract, 2>(expCtx) { expCtx->sbeCompatible = false; diff --git a/src/mongo/db/pipeline/expression_date_test.cpp b/src/mongo/db/pipeline/expression_date_test.cpp index 160af9ec383..717fdcca651 100644 --- a/src/mongo/db/pipeline/expression_date_test.cpp +++ b/src/mongo/db/pipeline/expression_date_test.cpp @@ -1382,7 +1382,8 @@ TEST_F(ExpressionDateFromStringTest, OnNullEvaluatedLazily) { ASSERT_EQ( "2018-02-14T00:00:00.000Z", dateExp->evaluate(Document{{"date", "2018-02-14"_sd}}, &expCtx->variables).toString()); - ASSERT_THROWS_CODE(dateExp->evaluate({}, &expCtx->variables), AssertionException, 16608); + ASSERT_THROWS_CODE( + dateExp->evaluate({}, &expCtx->variables), AssertionException, ErrorCodes::BadValue); } TEST_F(ExpressionDateFromStringTest, OnErrorEvaluatedLazily) { @@ -1396,8 +1397,9 @@ TEST_F(ExpressionDateFromStringTest, OnErrorEvaluatedLazily) { ASSERT_EQ( "2018-02-14T00:00:00.000Z", dateExp->evaluate(Document{{"date", "2018-02-14"_sd}}, &expCtx->variables).toString()); - ASSERT_THROWS_CODE( - dateExp->evaluate(Document{{"date", 5}}, &expCtx->variables), AssertionException, 16608); + ASSERT_THROWS_CODE(dateExp->evaluate(Document{{"date", 5}}, &expCtx->variables), + AssertionException, + ErrorCodes::BadValue); } } // namespace ExpressionDateFromStringTest diff --git a/src/mongo/db/pipeline/window_function/partition_iterator.cpp b/src/mongo/db/pipeline/window_function/partition_iterator.cpp index 9f2034c80f5..a9a6296b72f 100644 --- a/src/mongo/db/pipeline/window_function/partition_iterator.cpp +++ b/src/mongo/db/pipeline/window_function/partition_iterator.cpp @@ -30,6 +30,7 @@ #include "mongo/platform/basic.h" #include "mongo/db/pipeline/window_function/partition_iterator.h" +#include "mongo/util/visit_helper.h" namespace mongo { @@ -40,6 +41,8 @@ boost::optional<Document> PartitionIterator::operator[](int index) { return boost::none; // Case 0: Outside of lower bound of partition. + // TODO SERVER-53712: when we add expiry, this should tassert that the caller is not asking + // for a document that the caller promised it wouldn't need. if (desired < 0) return boost::none; @@ -109,6 +112,69 @@ PartitionIterator::AdvanceResult PartitionIterator::advance() { } } +namespace { +boost::optional<int> numericBound(WindowBounds::Bound<int> bound) { + return stdx::visit( + visit_helper::Overloaded{ + [](WindowBounds::Unbounded) -> boost::optional<int> { return boost::none; }, + [](WindowBounds::Current) -> boost::optional<int> { return 0; }, + [](int i) -> boost::optional<int> { return i; }, + }, + bound); +} +} // namespace + +boost::optional<std::pair<int, int>> PartitionIterator::getEndpoints(const WindowBounds& bounds) { + // For range-based bounds, we will need to: + // 1. extract the sortBy for (*this)[0] + // 2. step backwards until we cross bounds.lower + // 3. step forwards until we cross bounds.upper + // This means we'll need to pass in sortBy somewhere. + tassert(5423300, + "TODO SERVER-54294: range-based and time-based bounds", + stdx::holds_alternative<WindowBounds::DocumentBased>(bounds.bounds)); + tassert(5423301, "getEndpoints assumes there is a current document", (*this)[0] != boost::none); + auto docBounds = stdx::get<WindowBounds::DocumentBased>(bounds.bounds); + boost::optional<int> lowerBound = numericBound(docBounds.lower); + boost::optional<int> upperBound = numericBound(docBounds.upper); + tassert(5423302, + "Bounds should never be inverted", + !lowerBound || !upperBound || lowerBound <= upperBound); + + // Pull documents into the cache until it contains the whole window. + if (upperBound) { + // For a right-bounded window we only need to pull in documents up to the bound. + (*this)[*upperBound]; + } else { + // For a right-unbounded window we need to pull in the whole partition. operator[] reports + // end of partition by returning boost::none instead of a document. + for (int i = 0; (*this)[i]; ++i) { + } + } + + // Valid offsets into the cache are any 'i' such that '_cache[_currentIndex + i]' is valid. + // We know the cache is nonempty because it contains the current document. + int cacheOffsetMin = -_currentIndex; + int cacheOffsetMax = cacheOffsetMin + _cache.size() - 1; + + // The window can only be empty if the bounds are shifted completely out of the partition. + if (lowerBound && lowerBound > cacheOffsetMax) + return boost::none; + if (upperBound && upperBound < cacheOffsetMin) + return boost::none; + + // Now we know that the window is nonempty, and the cache contains it. + // All we have to do is clamp the bounds to fall within the cache. + auto clamp = [&](int offset) { + // Return the closest offset from the interval '[cacheOffsetMin, cacheOffsetMax]'. + return std::max(cacheOffsetMin, std::min(offset, cacheOffsetMax)); + }; + int lowerOffset = lowerBound ? clamp(*lowerBound) : cacheOffsetMin; + int upperOffset = upperBound ? clamp(*upperBound) : cacheOffsetMax; + + return {{lowerOffset, upperOffset}}; +} + void PartitionIterator::getNextDocument() { tassert(5340103, "Invalid call to PartitionIterator::getNextDocument", diff --git a/src/mongo/db/pipeline/window_function/partition_iterator.h b/src/mongo/db/pipeline/window_function/partition_iterator.h index 3e645d8a768..5a6b34f91bb 100644 --- a/src/mongo/db/pipeline/window_function/partition_iterator.h +++ b/src/mongo/db/pipeline/window_function/partition_iterator.h @@ -31,6 +31,7 @@ #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/expression.h" +#include "mongo/db/pipeline/window_function/window_bounds.h" namespace mongo { @@ -82,6 +83,25 @@ public: _source = source; } + /** + * Resolve any type of WindowBounds to a concrete pair of indices, '[lower, upper]'. + * + * Both 'lower' and 'upper' are valid offsets, such that '(*this)[lower]' and '(*this)[upper]' + * returns a document. If the window contains one document, then 'lower == upper'. If the + * window contains zero documents, then there are no valid offsets, so we return boost::none. + * (The window can be empty when it is shifted completely past one end of the partition, as in + * [+999, +999] or [-999, -999].) + * + * The offsets can be different after every 'advance()'. Even for simple document-based + * windows, the returned offsets can be different than the user-specified bounds when we + * are close to a partition boundary. For example, at the beginning of a partition, + * 'getEndpoints(DocumentBased{-10, +7})' would be '[0, +7]'. + * + * This method is non-const because it may pull documents into memory up to the end of the + * window. + */ + boost::optional<std::pair<int, int>> getEndpoints(const WindowBounds& bounds); + private: /** * Retrieves the next document from the prior stage and updates the state accordingly. @@ -107,6 +127,7 @@ private: DocumentSource* _source; boost::optional<boost::intrusive_ptr<Expression>> _partitionExpr; std::vector<Document> _cache; + // '_cache[_currentIndex]' is the current document, which '(*this)[0]' returns. int _currentIndex = 0; Value _partitionKey; @@ -121,12 +142,14 @@ private: enum class IteratorState { // Default state, no documents have been pulled into the cache. kNotInitialized, - // Iterating the current partition. + // Iterating the current partition. We don't know where the current partition ends, or + // whether it's the last partition. kIntraPartition, // The first document of the next partition has been retrieved, but the iterator has not // advanced to it yet. kAwaitingAdvanceToNext, - // Similar to the next partition case, except for EOF. + // Similar to the next partition case, except for EOF: we know the current partition is the + // final one, because the underlying iterator has returned EOF. kAwaitingAdvanceToEOF, // The iterator has exhausted the input documents. Any access should be disallowed. kAdvancedToEOF, diff --git a/src/mongo/db/pipeline/window_function/window_function_exec_derivative.cpp b/src/mongo/db/pipeline/window_function/window_function_exec_derivative.cpp new file mode 100644 index 00000000000..c960e40054d --- /dev/null +++ b/src/mongo/db/pipeline/window_function/window_function_exec_derivative.cpp @@ -0,0 +1,111 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/pipeline/window_function/window_function_exec_derivative.h" + +namespace mongo { + +namespace { +// Convert expected error codes to BSONNULL, but uassert other unexpected codes. +Value orNull(StatusWith<Value> val) { + if (val.getStatus().code() == ErrorCodes::BadValue || + val.getStatus().code() == ErrorCodes::TypeMismatch) + return Value(BSONNULL); + return uassertStatusOK(val); +} +} // namespace + +Value WindowFunctionExecDerivative::getNext() { + auto endpoints = _iter->getEndpoints(_bounds); + if (!endpoints) + return kDefault; + + auto [leftOffset, rightOffset] = *endpoints; + const Document& leftDoc = *(*_iter)[leftOffset]; + const Document& rightDoc = *(*_iter)[rightOffset]; + + // Conceptually, $derivative computes 'rise/run' where 'rise' is dimensionless and 'run' is + // a time. The result has dimension 1/time, which doesn't correspond to any BSON type, so + // 'outputUnit' tells us how to express the result as a dimensionless BSON number. + // + // However, BSON also can't represent a time (duration) directly. BSONType::Date represents + // a point in time, but there is no type that represents an amount of time. Subtracting two + // Date values implicitly converts them to milliseconds. + + // So, when we compute 'rise/run', the answer is expressed in units '1/millisecond'. If an + // 'outputUnit' is specified, we scale the answer by 'millisecond/outputUnit' to + // re-express it in '1/outputUnit'. + Value leftTime = _time->evaluate(leftDoc, &_time->getExpressionContext()->variables); + Value rightTime = _time->evaluate(rightDoc, &_time->getExpressionContext()->variables); + if (_outputUnitMillis) { + // If an outputUnit is specified, we require both endpoints to be dates. We don't + // want to interpret bare numbers as milliseconds, when we don't know what unit they + // really represent. + // + // For example: imagine the '_time' field contains floats representing seconds: then + // 'rise/run' will already be expressed in units of 1/second. If you think "my data is + // seconds" and write 'outputUnit: "second"', and we applied the scale factor of + // 'millisecond/outputUnit', then the final answer would be wrong by a factor of 1000. + if (leftTime.getType() != BSONType::Date || rightTime.getType() != BSONType::Date) { + return kDefault; + } + } else { + // Without outputUnit, we require both time values to be numeric. + if (!leftTime.numeric() || !rightTime.numeric()) + return kDefault; + } + // Now leftTime and rightTime are either both numeric, or both dates. + // $subtract on two dates gives us the difference in milliseconds. + Value run = orNull(ExpressionSubtract::apply(std::move(rightTime), std::move(leftTime))); + if (run.nullish()) + return kDefault; + + Value rise = orNull(ExpressionSubtract::apply( + _position->evaluate(rightDoc, &_position->getExpressionContext()->variables), + _position->evaluate(leftDoc, &_position->getExpressionContext()->variables))); + if (rise.nullish()) + return kDefault; + + Value result = orNull(ExpressionDivide::apply(std::move(rise), std::move(run))); + if (result.nullish()) + return kDefault; + + if (_outputUnitMillis) { + // 'result' has units 1/millisecond; scale by millisecond/outputUnit to express in + // 1/outputUnit. + + // tassert because at this point the result should already be numeric, so if + // ExpressionMultiply returns a non-OK Status then something has gone wrong. + auto statusWithResult = ExpressionMultiply::apply(result, Value(*_outputUnitMillis)); + tassert(statusWithResult); + result = statusWithResult.getValue(); + } + return result; +} +} // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/window_function_exec_derivative.h b/src/mongo/db/pipeline/window_function/window_function_exec_derivative.h new file mode 100644 index 00000000000..1fdf4d31630 --- /dev/null +++ b/src/mongo/db/pipeline/window_function/window_function_exec_derivative.h @@ -0,0 +1,83 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#pragma once + +#include "mongo/db/pipeline/document_source.h" +#include "mongo/db/pipeline/expression.h" +#include "mongo/db/pipeline/window_function/partition_iterator.h" +#include "mongo/db/pipeline/window_function/window_bounds.h" +#include "mongo/db/pipeline/window_function/window_function_exec.h" +#include "mongo/db/query/datetime/date_time_support.h" + +namespace mongo { + +/** + * $derivative computes 'rise/run', by comparing the two endpoints of its window. + * + * 'rise' is the difference in 'position' between the endpoints; 'run' is the difference in 'time'. + * + * We assume the 'time' is provided as an expression, even though the surface syntax uses a + * SortPattern. When the WindowFunctionExpression translates itself to an exec, it can also + * translate the SortPattern to an expression. + */ +class WindowFunctionExecDerivative final : public WindowFunctionExec { +public: + // Default value to use when the window is empty. + static inline const Value kDefault = Value(BSONNULL); + + WindowFunctionExecDerivative(PartitionIterator* iter, + boost::intrusive_ptr<Expression> position, + boost::intrusive_ptr<Expression> time, + WindowBounds bounds, + boost::optional<TimeUnit> outputUnit) + : WindowFunctionExec(iter), + _position(std::move(position)), + _time(std::move(time)), + _bounds(std::move(bounds)), + _outputUnitMillis([&]() -> boost::optional<long long> { + if (!outputUnit) + return boost::none; + + auto status = timeUnitTypicalMilliseconds(*outputUnit); + tassert(status); + return status.getValue(); + }()) {} + + Value getNext() final; + void reset() final {} + +private: + boost::intrusive_ptr<Expression> _position; + boost::intrusive_ptr<Expression> _time; + WindowBounds _bounds; + boost::optional<long long> _outputUnitMillis; +}; + +} // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/window_function_exec_derivative_test.cpp b/src/mongo/db/pipeline/window_function/window_function_exec_derivative_test.cpp new file mode 100644 index 00000000000..548dda10a46 --- /dev/null +++ b/src/mongo/db/pipeline/window_function/window_function_exec_derivative_test.cpp @@ -0,0 +1,315 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/platform/basic.h" + +#include "mongo/db/exec/document_value/document_value_test_util.h" +#include "mongo/db/pipeline/accumulator.h" +#include "mongo/db/pipeline/aggregation_context_fixture.h" +#include "mongo/db/pipeline/document_source.h" +#include "mongo/db/pipeline/document_source_mock.h" +#include "mongo/db/pipeline/expression.h" +#include "mongo/db/pipeline/window_function/partition_iterator.h" +#include "mongo/db/pipeline/window_function/window_bounds.h" +#include "mongo/db/pipeline/window_function/window_function_exec_derivative.h" +#include "mongo/unittest/unittest.h" + +namespace mongo { +namespace { + +class WindowFunctionExecDerivativeTest : public AggregationContextFixture { +public: + WindowFunctionExecDerivative createForFieldPath( + std::deque<DocumentSource::GetNextResult> docs, + const std::string& positionPath, + const std::string& timePath, + WindowBounds bounds, + boost::optional<TimeUnit> timeUnit = boost::none) { + _docSource = DocumentSourceMock::createForTest(std::move(docs), getExpCtx()); + _iter = + std::make_unique<PartitionIterator>(getExpCtx().get(), _docSource.get(), boost::none); + + auto position = ExpressionFieldPath::parse( + getExpCtx().get(), positionPath, getExpCtx()->variablesParseState); + auto time = ExpressionFieldPath::parse( + getExpCtx().get(), timePath, getExpCtx()->variablesParseState); + + return WindowFunctionExecDerivative(_iter.get(), + std::move(position), + std::move(time), + std::move(bounds), + std::move(timeUnit)); + } + + auto advanceIterator() { + return _iter->advance(); + } + + Value eval(std::pair<Value, Value> start, + std::pair<Value, Value> end, + boost::optional<TimeUnit> outputUnit = {}) { + const std::deque<DocumentSource::GetNextResult> docs{ + Document{{"t", start.first}, {"y", start.second}}, + Document{{"t", end.first}, {"y", end.second}}}; + auto mgr = createForFieldPath(std::move(docs), + "$y", + "$t", + {WindowBounds::DocumentBased{0, 1}}, + std::move(outputUnit)); + return mgr.getNext(); + } + +private: + boost::intrusive_ptr<DocumentSourceMock> _docSource; + std::unique_ptr<PartitionIterator> _iter; +}; + +TEST_F(WindowFunctionExecDerivativeTest, LookBehind) { + const auto docs = std::deque<DocumentSource::GetNextResult>{ + Document{{"t", 0}, {"y", 1}}, + Document{{"t", 1}, {"y", 2}}, + Document{{"t", 2}, {"y", 4}}, + Document{{"t", 3}, {"y", 8}}, + }; + + // Look behind 1 document. + auto mgr = + createForFieldPath(std::move(docs), "$y", "$t", {WindowBounds::DocumentBased{-1, 0}}); + + // Initially, the window only has one document, so we can't compute a derivative. + ASSERT_VALUE_EQ(Value(BSONNULL), mgr.getNext()); + advanceIterator(); + // Now since t changes by 1 every time, answer should be just dy. + ASSERT_VALUE_EQ(Value(1), mgr.getNext()); + advanceIterator(); + ASSERT_VALUE_EQ(Value(2), mgr.getNext()); + advanceIterator(); + ASSERT_VALUE_EQ(Value(4), mgr.getNext()); +} + +TEST_F(WindowFunctionExecDerivativeTest, LookAhead) { + const auto docs = std::deque<DocumentSource::GetNextResult>{ + Document{{"t", 0}, {"y", 1}}, + Document{{"t", 1}, {"y", 2}}, + Document{{"t", 2}, {"y", 4}}, + Document{{"t", 3}, {"y", 8}}, + }; + + // Look ahead 1 document. + auto mgr = + createForFieldPath(std::move(docs), "$y", "$t", {WindowBounds::DocumentBased{0, +1}}); + // Now the first document's window has two documents. + ASSERT_VALUE_EQ(Value(1), mgr.getNext()); + advanceIterator(); + ASSERT_VALUE_EQ(Value(2), mgr.getNext()); + advanceIterator(); + ASSERT_VALUE_EQ(Value(4), mgr.getNext()); + advanceIterator(); + // At the end of the partition we only have one document. + ASSERT_VALUE_EQ(Value(BSONNULL), mgr.getNext()); +} + +TEST_F(WindowFunctionExecDerivativeTest, LookAround) { + const auto docs = std::deque<DocumentSource::GetNextResult>{ + Document{{"t", 0}, {"y", 1}}, + Document{{"t", 1}, {"y", 2}}, + Document{{"t", 2}, {"y", 4}}, + Document{{"t", 3}, {"y", 8}}, + }; + + // Look around 1 document (look 1 behind and 1 ahead). + // This case is interesting because at the partition boundaries, we can still define a + // derivative, but the window is smaller. + auto mgr = + createForFieldPath(std::move(docs), "$y", "$t", {WindowBounds::DocumentBased{-1, +1}}); + // The first document sees itself and the 1 document following. + // Time changes by 1 and y changes from 1 to 2. + ASSERT_VALUE_EQ(Value(1), mgr.getNext()); + advanceIterator(); + // The second document sees the 1 previous and the 1 following. + // Times changes by 2, and y changes from 1 to 4. + ASSERT_VALUE_EQ(Value(3.0 / 2), mgr.getNext()); + advanceIterator(); + // Next, y goes from 2 to 8. + ASSERT_VALUE_EQ(Value(6.0 / 2), mgr.getNext()); + advanceIterator(); + // Finally, the window shrinks back down to 2 documents. + // y goes from 4 to 8. + ASSERT_VALUE_EQ(Value(4), mgr.getNext()); +} + +TEST_F(WindowFunctionExecDerivativeTest, UnboundedBefore) { + const auto docs = std::deque<DocumentSource::GetNextResult>{ + Document{{"t", 0}, {"y", 1}}, + Document{{"t", 1}, {"y", 10}}, + Document{{"t", 2}, {"y", 100}}, + }; + + auto mgr = createForFieldPath( + std::move(docs), + "$y", + "$t", + { + WindowBounds::DocumentBased{WindowBounds::Unbounded{}, WindowBounds::Current{}}, + }); + // t is 0 to 0. + ASSERT_VALUE_EQ(Value(BSONNULL), mgr.getNext()); + advanceIterator(); + // t is 0 to 1. + ASSERT_VALUE_EQ(Value(9.0 / 1), mgr.getNext()); + advanceIterator(); + // t is 0 to 2. + ASSERT_VALUE_EQ(Value(99.0 / 2), mgr.getNext()); +} + +TEST_F(WindowFunctionExecDerivativeTest, UnboundedAfter) { + const auto docs = std::deque<DocumentSource::GetNextResult>{ + Document{{"t", 0}, {"y", 1}}, + Document{{"t", 1}, {"y", 10}}, + Document{{"t", 2}, {"y", 100}}, + }; + + auto mgr = createForFieldPath( + std::move(docs), + "$y", + "$t", + { + WindowBounds::DocumentBased{WindowBounds::Current{}, WindowBounds::Unbounded{}}, + }); + // t is 0 to 2. + ASSERT_VALUE_EQ(Value(99.0 / 2), mgr.getNext()); + advanceIterator(); + // t is 1 to 2. + ASSERT_VALUE_EQ(Value(90.0 / 1), mgr.getNext()); + advanceIterator(); + // t is 2 to 2. + ASSERT_VALUE_EQ(Value(BSONNULL), mgr.getNext()); +} + +TEST_F(WindowFunctionExecDerivativeTest, Unbounded) { + const auto docs = std::deque<DocumentSource::GetNextResult>{ + Document{{"t", 0}, {"y", 1}}, + Document{{"t", 1}, {"y", 10}}, + Document{{"t", 2}, {"y", 100}}, + }; + + auto mgr = createForFieldPath( + std::move(docs), + "$y", + "$t", + { + WindowBounds::DocumentBased{WindowBounds::Unbounded{}, WindowBounds::Unbounded{}}, + }); + // t is 0 to 2, in all 3 cases. + ASSERT_VALUE_EQ(Value(99.0 / 2), mgr.getNext()); + advanceIterator(); + ASSERT_VALUE_EQ(Value(99.0 / 2), mgr.getNext()); + advanceIterator(); + ASSERT_VALUE_EQ(Value(99.0 / 2), mgr.getNext()); +} + +TEST_F(WindowFunctionExecDerivativeTest, NonNumbers) { + auto t0 = Value{0}; + auto t1 = Value{1}; + auto y0 = Value{5}; + auto y1 = Value{6}; + auto bad = Value{"a string"_sd}; + // If any one value (position or time) is an invalid type, such as a string, the output is null. + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({bad, y0}, {t1, y1})); + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({t0, bad}, {t1, y1})); + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({t0, y0}, {bad, y1})); + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({t0, y0}, {t1, bad})); + + // If any one value is null, the output is null. + bad = Value{BSONNULL}; + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({bad, y0}, {t1, y1})); + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({t0, bad}, {t1, y1})); + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({t0, y0}, {bad, y1})); + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({t0, y0}, {t1, bad})); + + // If any one value is missing, the output is null. + bad = Value{}; + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({bad, y0}, {t1, y1})); + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({t0, bad}, {t1, y1})); + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({t0, y0}, {bad, y1})); + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({t0, y0}, {t1, bad})); +} + +TEST_F(WindowFunctionExecDerivativeTest, DatesAreNonNumbers) { + // When no outputUnit is specified, dates are treated as any other non-numeric type. + + // 'y' increases by 1, over 8ms. + auto t0 = Value{Date_t::fromMillisSinceEpoch(0)}; + auto t1 = Value{Date_t::fromMillisSinceEpoch(8)}; + auto y0 = Value{5}; + auto y1 = Value{6}; + // Each ms, 'y' increases by 1/8. + // This is exact despite floating point, because 8 is a power of 2. + ASSERT_VALUE_EQ(Value(BSONNULL), eval({t0, y0}, {t1, y1})); +} + +TEST_F(WindowFunctionExecDerivativeTest, OutputUnit) { + // 'y' increases by 1, over 8ms. + auto t0 = Value{Date_t::fromMillisSinceEpoch(0)}; + auto t1 = Value{Date_t::fromMillisSinceEpoch(8)}; + auto y0 = Value{5}; + auto y1 = Value{6}; + + // Calculate the derivative, expressed in the given TimeUnit. + auto calc = [&](TimeUnit unit) -> Value { return eval({t0, y0}, {t1, y1}, unit); }; + // Each ms, 'y' increased by 1/8. + // (This should be exact, despite floating point, because 8 is a power of 2.) + ASSERT_VALUE_EQ(calc(TimeUnit::millisecond), Value{1.0 / 8}); + + // Each second, 'y' increases by 1/8 1000 times (once per ms). + ASSERT_VALUE_EQ(Value{(1.0 / 8) * 1000}, calc(TimeUnit::second)); + // And so on, with larger units. + ASSERT_VALUE_EQ(Value{(1.0 / 8) * 1000 * 60}, calc(TimeUnit::minute)); + ASSERT_VALUE_EQ(Value{(1.0 / 8) * 1000 * 60 * 60}, calc(TimeUnit::hour)); + ASSERT_VALUE_EQ(Value{(1.0 / 8) * 1000 * 60 * 60 * 24}, calc(TimeUnit::day)); + ASSERT_VALUE_EQ(Value{(1.0 / 8) * 1000 * 60 * 60 * 24 * 7}, calc(TimeUnit::week)); +} + +TEST_F(WindowFunctionExecDerivativeTest, OutputUnitNonDate) { + // outputUnit requires the time input to be a datetime: non-datetimes are treated like any other + // invalid type ($derivate returns null). + + auto t0 = Value{Date_t::fromMillisSinceEpoch(0)}; + auto t1 = Value{Date_t::fromMillisSinceEpoch(1000)}; + auto y0 = Value{0}; + auto y1 = Value{0}; + auto bad = Value{500}; + + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({bad, y0}, {t1, y1}, TimeUnit::millisecond)); + ASSERT_VALUE_EQ(Value{BSONNULL}, eval({t0, y0}, {bad, y1}, TimeUnit::millisecond)); +} + + +} // namespace +} // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/window_function_exec_test.cpp b/src/mongo/db/pipeline/window_function/window_function_exec_non_removable_test.cpp index 3f36d25123d..3f36d25123d 100644 --- a/src/mongo/db/pipeline/window_function/window_function_exec_test.cpp +++ b/src/mongo/db/pipeline/window_function/window_function_exec_non_removable_test.cpp diff --git a/src/mongo/db/query/datetime/date_time_support.cpp b/src/mongo/db/query/datetime/date_time_support.cpp index 096504a9e46..80778737520 100644 --- a/src/mongo/db/query/datetime/date_time_support.cpp +++ b/src/mongo/db/query/datetime/date_time_support.cpp @@ -590,6 +590,7 @@ auto const kDaysInNonLeapYear = 365LL; auto const kHoursPerDay = 24LL; auto const kMinutesPerHour = 60LL; auto const kSecondsPerMinute = 60LL; +auto const kMillisecondsPerSecond = 1000LL; auto const kDaysPerWeek = 7; auto const kQuartersPerYear = 4LL; auto const kQuarterLengthInMonths = 3LL; @@ -977,4 +978,35 @@ Date_t dateAdd(Date_t date, TimeUnit unit, long long amount, const TimeZone& tim timelib_time_dtor(newTime); return returnDate; } + +StatusWith<long long> timeUnitTypicalMilliseconds(TimeUnit unit) { + auto constexpr millisecond = 1; + auto constexpr second = millisecond * kMillisecondsPerSecond; + auto constexpr minute = second * kSecondsPerMinute; + auto constexpr hour = minute * kMinutesPerHour; + auto constexpr day = hour * kHoursPerDay; + auto constexpr week = day * kDaysPerWeek; + + switch (unit) { + case TimeUnit::millisecond: + return millisecond; + case TimeUnit::second: + return second; + case TimeUnit::minute: + return minute; + case TimeUnit::hour: + return hour; + case TimeUnit::day: + return day; + case TimeUnit::week: + return week; + case TimeUnit::month: + case TimeUnit::quarter: + case TimeUnit::year: + return Status(ErrorCodes::BadValue, + str::stream() << "TimeUnit is too big: " << serializeTimeUnit(unit)); + } + MONGO_UNREACHABLE_TASSERT(5423303); +} + } // namespace mongo diff --git a/src/mongo/db/query/datetime/date_time_support.h b/src/mongo/db/query/datetime/date_time_support.h index b4c1fca1007..7012ca2dc06 100644 --- a/src/mongo/db/query/datetime/date_time_support.h +++ b/src/mongo/db/query/datetime/date_time_support.h @@ -599,4 +599,17 @@ long long dateDiff(Date_t startDate, * timezone - the timezone in which the start date is interpreted */ Date_t dateAdd(Date_t date, TimeUnit unit, long long amount, const TimeZone& timezone); + +/** + * Convert (approximately) a TimeUnit to a number of milliseconds. + * + * The answer is approximate because TimeUnit represents an amount of calendar time: + * for example, some calendar days are 23 or 25 hours long due to daylight savings time. + * This function assumes everything is "typical": days are 24 hours, minutes are 60 seconds. + * + * Large time units, 'month' or longer, are so variable that we don't try to pick a value: we + * return a non-OK Status. + */ +StatusWith<long long> timeUnitTypicalMilliseconds(TimeUnit unit); + } // namespace mongo |