diff options
Diffstat (limited to 'src/mongo/db/pipeline/window_function/window_function_expression.h')
-rw-r--r-- | src/mongo/db/pipeline/window_function/window_function_expression.h | 218 |
1 files changed, 166 insertions, 52 deletions
diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.h b/src/mongo/db/pipeline/window_function/window_function_expression.h index bb1ef7d7704..a34bc68c0d5 100644 --- a/src/mongo/db/pipeline/window_function/window_function_expression.h +++ b/src/mongo/db/pipeline/window_function/window_function_expression.h @@ -31,6 +31,7 @@ #include "mongo/base/initializer.h" #include "mongo/db/pipeline/accumulator.h" +#include "mongo/db/pipeline/accumulator_for_window_functions.h" #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/document_source_set_window_fields_gen.h" #include "mongo/db/pipeline/window_function/window_bounds.h" @@ -397,17 +398,101 @@ protected: boost::optional<Decimal128> _alpha; }; -class ExpressionDerivative : public Expression { +class ExpressionWithOutputUnit : public Expression { public: static constexpr StringData kArgInput = "input"_sd; static constexpr StringData kArgOutputUnit = "outputUnit"_sd; + ExpressionWithOutputUnit(ExpressionContext* expCtx, + std::string accumulatorName, + boost::intrusive_ptr<::mongo::Expression> input, + WindowBounds bounds, + boost::optional<TimeUnit> outputUnit) + : Expression(expCtx, accumulatorName, std::move(input), std::move(bounds)), + _outputUnit(outputUnit) {} + + boost::optional<TimeUnit> outputUnit() const { + return _outputUnit; + } + + Value serialize(boost::optional<ExplainOptions::Verbosity> explain) const final { + MutableDocument result; + result[_accumulatorName][kArgInput] = _input->serialize(static_cast<bool>(explain)); + if (_outputUnit) { + result[_accumulatorName][kArgOutputUnit] = Value(serializeTimeUnit(*_outputUnit)); + } + + MutableDocument windowField; + _bounds.serialize(windowField); + result[kWindowArg] = windowField.freezeToValue(); + return result.freezeToValue(); + } + +protected: + static boost::optional<TimeUnit> parseOutputUnit(const BSONElement& arg) { + boost::optional<TimeUnit> outputUnit; + { + uassert(ErrorCodes::FailedToParse, + str::stream() << kArgOutputUnit << "' must be a string, but got " << arg.type(), + arg.type() == String); + outputUnit = parseTimeUnit(arg.valueStringData()); + switch (*outputUnit) { + // These larger time units vary so much, it doesn't make sense to define a + // fixed conversion from milliseconds. (See 'timeUnitTypicalMilliseconds'.) + case TimeUnit::year: + case TimeUnit::quarter: + case TimeUnit::month: + uasserted(5490704, "outputUnit must be 'week' or smaller"); + // Only these time units are allowed. + case TimeUnit::week: + case TimeUnit::day: + case TimeUnit::hour: + case TimeUnit::minute: + case TimeUnit::second: + case TimeUnit::millisecond: + break; + } + } + return outputUnit; + } + + static void validateSortBy(const boost::optional<SortPattern>& sortBy, + const std::string& accumulatorName) { + uassert(ErrorCodes::FailedToParse, + str::stream() << accumulatorName << " requires a sortBy", + sortBy); + uassert(ErrorCodes::FailedToParse, + str::stream() << accumulatorName << " requires a non-compound sortBy", + sortBy->size() == 1); + uassert(ErrorCodes::FailedToParse, + str::stream() << accumulatorName << " requires a non-expression sortBy", + !sortBy->begin()->expression); + uassert(ErrorCodes::FailedToParse, + str::stream() << accumulatorName << " requires an ascending sortBy", + sortBy->begin()->isAscending); + } + + boost::optional<long long> convertTimeUnitToMillis(boost::optional<TimeUnit> outputUnit) const { + if (!outputUnit) + return boost::none; + + auto status = timeUnitTypicalMilliseconds(*outputUnit); + tassert(status); + + return status.getValue(); + } + + boost::optional<TimeUnit> _outputUnit; +}; + +class ExpressionDerivative : public ExpressionWithOutputUnit { +public: ExpressionDerivative(ExpressionContext* expCtx, boost::intrusive_ptr<::mongo::Expression> input, WindowBounds bounds, boost::optional<TimeUnit> outputUnit) - : Expression(expCtx, "$derivative", std::move(input), std::move(bounds)), - _outputUnit(outputUnit) {} + : ExpressionWithOutputUnit( + expCtx, "$derivative", std::move(input), std::move(bounds), outputUnit) {} static boost::intrusive_ptr<Expression> parse(BSONObj obj, const boost::optional<SortPattern>& sortBy, @@ -419,17 +504,7 @@ public: // } // window: {...} // optional // } - - uassert(ErrorCodes::FailedToParse, "$derivative requires a sortBy", sortBy); - uassert(ErrorCodes::FailedToParse, - "$derivative requires a non-compound sortBy", - sortBy->size() == 1); - uassert(ErrorCodes::FailedToParse, - "$derivative requires a non-expression sortBy", - !sortBy->begin()->expression); - uassert(ErrorCodes::FailedToParse, - "$derivative requires an ascending sortBy", - sortBy->begin()->isAscending); + validateSortBy(sortBy, "$derivative"); boost::optional<WindowBounds> bounds; BSONElement derivativeArgs; @@ -462,27 +537,7 @@ public: if (argName == kArgInput) { input = ::mongo::Expression::parseOperand(expCtx, arg, expCtx->variablesParseState); } else if (argName == kArgOutputUnit) { - uassert(ErrorCodes::FailedToParse, - str::stream() << "$derivative '" << kArgOutputUnit - << "' must be a string, but got " << arg.type(), - arg.type() == String); - outputUnit = parseTimeUnit(arg.valueStringData()); - switch (*outputUnit) { - // These larger time units vary so much, it doesn't make sense to define a - // fixed conversion from milliseconds. (See 'timeUnitTypicalMilliseconds'.) - case TimeUnit::year: - case TimeUnit::quarter: - case TimeUnit::month: - uasserted(5490704, "$derivative outputUnit must be 'week' or smaller"); - // Only these time units are allowed. - case TimeUnit::week: - case TimeUnit::day: - case TimeUnit::hour: - case TimeUnit::minute: - case TimeUnit::second: - case TimeUnit::millisecond: - break; - } + outputUnit = parseOutputUnit(arg); } else { uasserted(ErrorCodes::FailedToParse, str::stream() << "$derivative got unexpected argument: " << argName); @@ -500,19 +555,6 @@ public: expCtx, std::move(input), std::move(*bounds), outputUnit); } - Value serialize(boost::optional<ExplainOptions::Verbosity> explain) const final { - MutableDocument result; - result[_accumulatorName][kArgInput] = _input->serialize(static_cast<bool>(explain)); - if (_outputUnit) { - result[_accumulatorName][kArgOutputUnit] = Value(serializeTimeUnit(*_outputUnit)); - } - - MutableDocument windowField; - _bounds.serialize(windowField); - result[kWindowArg] = windowField.freezeToValue(); - return result.freezeToValue(); - } - boost::intrusive_ptr<AccumulatorState> buildAccumulatorOnly() const final { MONGO_UNREACHABLE_TASSERT(5490701); } @@ -520,13 +562,85 @@ public: std::unique_ptr<WindowFunctionState> buildRemovable() const final { MONGO_UNREACHABLE_TASSERT(5490702); } +}; - auto outputUnit() const { - return _outputUnit; +class ExpressionIntegral : public ExpressionWithOutputUnit { +public: + ExpressionIntegral(ExpressionContext* expCtx, + boost::intrusive_ptr<::mongo::Expression> input, + WindowBounds bounds, + boost::optional<TimeUnit> outputUnit) + : ExpressionWithOutputUnit( + expCtx, "$integral", std::move(input), std::move(bounds), outputUnit) {} + + static boost::intrusive_ptr<Expression> parse(BSONObj obj, + const boost::optional<SortPattern>& sortBy, + ExpressionContext* expCtx) { + // { + // $integral: { + // input: <expr>, + // outputUnit: <string>, // optional + // } + // window: {...} // optional + // } + // + validateSortBy(sortBy, "$integral"); + + boost::optional<WindowBounds> bounds = boost::none; + BSONElement integralArgs; + for (const auto& arg : obj) { + auto argName = arg.fieldNameStringData(); + if (argName == kWindowArg) { + uassert(ErrorCodes::FailedToParse, + "'window' field must be an object", + obj[kWindowArg].type() == BSONType::Object); + uassert(ErrorCodes::FailedToParse, + "There can be only one 'window' field for $integral", + bounds == boost::none); + bounds = WindowBounds::parse(arg.embeddedObject(), sortBy, expCtx); + } else if (argName == "$integral"_sd) { + integralArgs = arg; + } else { + uasserted(ErrorCodes::FailedToParse, + str::stream() << "$integral got unexpected argument: " << argName); + } + } + tassert( + 5558801, "$integral parser called on object with no $integral key", integralArgs.ok()); + uassert(ErrorCodes::FailedToParse, + str::stream() << "$integral expects an object, but got a " << integralArgs.type() + << ": " << integralArgs, + integralArgs.type() == BSONType::Object); + + boost::intrusive_ptr<::mongo::Expression> input; + boost::optional<TimeUnit> outputUnit = boost::none; + for (const auto& arg : integralArgs.Obj()) { + auto argName = arg.fieldNameStringData(); + if (argName == kArgInput) { + input = ::mongo::Expression::parseOperand(expCtx, arg, expCtx->variablesParseState); + } else if (argName == kArgOutputUnit) { + uassert(ErrorCodes::FailedToParse, + "There can be only one 'outputUnit' field for $integral", + outputUnit == boost::none); + outputUnit = parseOutputUnit(arg); + } else { + uasserted(ErrorCodes::FailedToParse, + str::stream() << "$integral got unexpected argument: " << argName); + } + } + uassert(ErrorCodes::FailedToParse, "$integral requires an 'input' expression", input); + + return make_intrusive<ExpressionIntegral>( + expCtx, std::move(input), bounds ? *bounds : WindowBounds(), outputUnit); } -private: - boost::optional<TimeUnit> _outputUnit; + boost::intrusive_ptr<AccumulatorState> buildAccumulatorOnly() const final { + return AccumulatorIntegral::create(_expCtx, convertTimeUnitToMillis(_outputUnit)); + } + + std::unique_ptr<WindowFunctionState> buildRemovable() const final { + return WindowFunctionIntegral::create(_expCtx, convertTimeUnitToMillis(_outputUnit)); + } }; } // namespace mongo::window_function |