summaryrefslogtreecommitdiff
path: root/src/mongo/db/pipeline/window_function/window_function_expression.h
diff options
context:
space:
mode:
Diffstat (limited to 'src/mongo/db/pipeline/window_function/window_function_expression.h')
-rw-r--r--src/mongo/db/pipeline/window_function/window_function_expression.h218
1 files changed, 166 insertions, 52 deletions
diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.h b/src/mongo/db/pipeline/window_function/window_function_expression.h
index bb1ef7d7704..a34bc68c0d5 100644
--- a/src/mongo/db/pipeline/window_function/window_function_expression.h
+++ b/src/mongo/db/pipeline/window_function/window_function_expression.h
@@ -31,6 +31,7 @@
#include "mongo/base/initializer.h"
#include "mongo/db/pipeline/accumulator.h"
+#include "mongo/db/pipeline/accumulator_for_window_functions.h"
#include "mongo/db/pipeline/document_source.h"
#include "mongo/db/pipeline/document_source_set_window_fields_gen.h"
#include "mongo/db/pipeline/window_function/window_bounds.h"
@@ -397,17 +398,101 @@ protected:
boost::optional<Decimal128> _alpha;
};
-class ExpressionDerivative : public Expression {
+class ExpressionWithOutputUnit : public Expression {
public:
static constexpr StringData kArgInput = "input"_sd;
static constexpr StringData kArgOutputUnit = "outputUnit"_sd;
+ ExpressionWithOutputUnit(ExpressionContext* expCtx,
+ std::string accumulatorName,
+ boost::intrusive_ptr<::mongo::Expression> input,
+ WindowBounds bounds,
+ boost::optional<TimeUnit> outputUnit)
+ : Expression(expCtx, accumulatorName, std::move(input), std::move(bounds)),
+ _outputUnit(outputUnit) {}
+
+ boost::optional<TimeUnit> outputUnit() const {
+ return _outputUnit;
+ }
+
+ Value serialize(boost::optional<ExplainOptions::Verbosity> explain) const final {
+ MutableDocument result;
+ result[_accumulatorName][kArgInput] = _input->serialize(static_cast<bool>(explain));
+ if (_outputUnit) {
+ result[_accumulatorName][kArgOutputUnit] = Value(serializeTimeUnit(*_outputUnit));
+ }
+
+ MutableDocument windowField;
+ _bounds.serialize(windowField);
+ result[kWindowArg] = windowField.freezeToValue();
+ return result.freezeToValue();
+ }
+
+protected:
+ static boost::optional<TimeUnit> parseOutputUnit(const BSONElement& arg) {
+ boost::optional<TimeUnit> outputUnit;
+ {
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << kArgOutputUnit << "' must be a string, but got " << arg.type(),
+ arg.type() == String);
+ outputUnit = parseTimeUnit(arg.valueStringData());
+ switch (*outputUnit) {
+ // These larger time units vary so much, it doesn't make sense to define a
+ // fixed conversion from milliseconds. (See 'timeUnitTypicalMilliseconds'.)
+ case TimeUnit::year:
+ case TimeUnit::quarter:
+ case TimeUnit::month:
+ uasserted(5490704, "outputUnit must be 'week' or smaller");
+ // Only these time units are allowed.
+ case TimeUnit::week:
+ case TimeUnit::day:
+ case TimeUnit::hour:
+ case TimeUnit::minute:
+ case TimeUnit::second:
+ case TimeUnit::millisecond:
+ break;
+ }
+ }
+ return outputUnit;
+ }
+
+ static void validateSortBy(const boost::optional<SortPattern>& sortBy,
+ const std::string& accumulatorName) {
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << accumulatorName << " requires a sortBy",
+ sortBy);
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << accumulatorName << " requires a non-compound sortBy",
+ sortBy->size() == 1);
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << accumulatorName << " requires a non-expression sortBy",
+ !sortBy->begin()->expression);
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << accumulatorName << " requires an ascending sortBy",
+ sortBy->begin()->isAscending);
+ }
+
+ boost::optional<long long> convertTimeUnitToMillis(boost::optional<TimeUnit> outputUnit) const {
+ if (!outputUnit)
+ return boost::none;
+
+ auto status = timeUnitTypicalMilliseconds(*outputUnit);
+ tassert(status);
+
+ return status.getValue();
+ }
+
+ boost::optional<TimeUnit> _outputUnit;
+};
+
+class ExpressionDerivative : public ExpressionWithOutputUnit {
+public:
ExpressionDerivative(ExpressionContext* expCtx,
boost::intrusive_ptr<::mongo::Expression> input,
WindowBounds bounds,
boost::optional<TimeUnit> outputUnit)
- : Expression(expCtx, "$derivative", std::move(input), std::move(bounds)),
- _outputUnit(outputUnit) {}
+ : ExpressionWithOutputUnit(
+ expCtx, "$derivative", std::move(input), std::move(bounds), outputUnit) {}
static boost::intrusive_ptr<Expression> parse(BSONObj obj,
const boost::optional<SortPattern>& sortBy,
@@ -419,17 +504,7 @@ public:
// }
// window: {...} // optional
// }
-
- uassert(ErrorCodes::FailedToParse, "$derivative requires a sortBy", sortBy);
- uassert(ErrorCodes::FailedToParse,
- "$derivative requires a non-compound sortBy",
- sortBy->size() == 1);
- uassert(ErrorCodes::FailedToParse,
- "$derivative requires a non-expression sortBy",
- !sortBy->begin()->expression);
- uassert(ErrorCodes::FailedToParse,
- "$derivative requires an ascending sortBy",
- sortBy->begin()->isAscending);
+ validateSortBy(sortBy, "$derivative");
boost::optional<WindowBounds> bounds;
BSONElement derivativeArgs;
@@ -462,27 +537,7 @@ public:
if (argName == kArgInput) {
input = ::mongo::Expression::parseOperand(expCtx, arg, expCtx->variablesParseState);
} else if (argName == kArgOutputUnit) {
- uassert(ErrorCodes::FailedToParse,
- str::stream() << "$derivative '" << kArgOutputUnit
- << "' must be a string, but got " << arg.type(),
- arg.type() == String);
- outputUnit = parseTimeUnit(arg.valueStringData());
- switch (*outputUnit) {
- // These larger time units vary so much, it doesn't make sense to define a
- // fixed conversion from milliseconds. (See 'timeUnitTypicalMilliseconds'.)
- case TimeUnit::year:
- case TimeUnit::quarter:
- case TimeUnit::month:
- uasserted(5490704, "$derivative outputUnit must be 'week' or smaller");
- // Only these time units are allowed.
- case TimeUnit::week:
- case TimeUnit::day:
- case TimeUnit::hour:
- case TimeUnit::minute:
- case TimeUnit::second:
- case TimeUnit::millisecond:
- break;
- }
+ outputUnit = parseOutputUnit(arg);
} else {
uasserted(ErrorCodes::FailedToParse,
str::stream() << "$derivative got unexpected argument: " << argName);
@@ -500,19 +555,6 @@ public:
expCtx, std::move(input), std::move(*bounds), outputUnit);
}
- Value serialize(boost::optional<ExplainOptions::Verbosity> explain) const final {
- MutableDocument result;
- result[_accumulatorName][kArgInput] = _input->serialize(static_cast<bool>(explain));
- if (_outputUnit) {
- result[_accumulatorName][kArgOutputUnit] = Value(serializeTimeUnit(*_outputUnit));
- }
-
- MutableDocument windowField;
- _bounds.serialize(windowField);
- result[kWindowArg] = windowField.freezeToValue();
- return result.freezeToValue();
- }
-
boost::intrusive_ptr<AccumulatorState> buildAccumulatorOnly() const final {
MONGO_UNREACHABLE_TASSERT(5490701);
}
@@ -520,13 +562,85 @@ public:
std::unique_ptr<WindowFunctionState> buildRemovable() const final {
MONGO_UNREACHABLE_TASSERT(5490702);
}
+};
- auto outputUnit() const {
- return _outputUnit;
+class ExpressionIntegral : public ExpressionWithOutputUnit {
+public:
+ ExpressionIntegral(ExpressionContext* expCtx,
+ boost::intrusive_ptr<::mongo::Expression> input,
+ WindowBounds bounds,
+ boost::optional<TimeUnit> outputUnit)
+ : ExpressionWithOutputUnit(
+ expCtx, "$integral", std::move(input), std::move(bounds), outputUnit) {}
+
+ static boost::intrusive_ptr<Expression> parse(BSONObj obj,
+ const boost::optional<SortPattern>& sortBy,
+ ExpressionContext* expCtx) {
+ // {
+ // $integral: {
+ // input: <expr>,
+ // outputUnit: <string>, // optional
+ // }
+ // window: {...} // optional
+ // }
+ //
+ validateSortBy(sortBy, "$integral");
+
+ boost::optional<WindowBounds> bounds = boost::none;
+ BSONElement integralArgs;
+ for (const auto& arg : obj) {
+ auto argName = arg.fieldNameStringData();
+ if (argName == kWindowArg) {
+ uassert(ErrorCodes::FailedToParse,
+ "'window' field must be an object",
+ obj[kWindowArg].type() == BSONType::Object);
+ uassert(ErrorCodes::FailedToParse,
+ "There can be only one 'window' field for $integral",
+ bounds == boost::none);
+ bounds = WindowBounds::parse(arg.embeddedObject(), sortBy, expCtx);
+ } else if (argName == "$integral"_sd) {
+ integralArgs = arg;
+ } else {
+ uasserted(ErrorCodes::FailedToParse,
+ str::stream() << "$integral got unexpected argument: " << argName);
+ }
+ }
+ tassert(
+ 5558801, "$integral parser called on object with no $integral key", integralArgs.ok());
+ uassert(ErrorCodes::FailedToParse,
+ str::stream() << "$integral expects an object, but got a " << integralArgs.type()
+ << ": " << integralArgs,
+ integralArgs.type() == BSONType::Object);
+
+ boost::intrusive_ptr<::mongo::Expression> input;
+ boost::optional<TimeUnit> outputUnit = boost::none;
+ for (const auto& arg : integralArgs.Obj()) {
+ auto argName = arg.fieldNameStringData();
+ if (argName == kArgInput) {
+ input = ::mongo::Expression::parseOperand(expCtx, arg, expCtx->variablesParseState);
+ } else if (argName == kArgOutputUnit) {
+ uassert(ErrorCodes::FailedToParse,
+ "There can be only one 'outputUnit' field for $integral",
+ outputUnit == boost::none);
+ outputUnit = parseOutputUnit(arg);
+ } else {
+ uasserted(ErrorCodes::FailedToParse,
+ str::stream() << "$integral got unexpected argument: " << argName);
+ }
+ }
+ uassert(ErrorCodes::FailedToParse, "$integral requires an 'input' expression", input);
+
+ return make_intrusive<ExpressionIntegral>(
+ expCtx, std::move(input), bounds ? *bounds : WindowBounds(), outputUnit);
}
-private:
- boost::optional<TimeUnit> _outputUnit;
+ boost::intrusive_ptr<AccumulatorState> buildAccumulatorOnly() const final {
+ return AccumulatorIntegral::create(_expCtx, convertTimeUnitToMillis(_outputUnit));
+ }
+
+ std::unique_ptr<WindowFunctionState> buildRemovable() const final {
+ return WindowFunctionIntegral::create(_expCtx, convertTimeUnitToMillis(_outputUnit));
+ }
};
} // namespace mongo::window_function