diff options
author | Nick Zolnierz <nicholas.zolnierz@mongodb.com> | 2021-02-23 08:02:58 -0500 |
---|---|---|
committer | Evergreen Agent <no-reply@evergreen.mongodb.com> | 2021-02-23 14:00:45 +0000 |
commit | e8134e5da39f19b661efb488bb392abbbf436c74 (patch) | |
tree | 9e96b80b61c87c01dc8c3a0565630830d99ce705 /src/mongo | |
parent | cabf5479aaf8cfe8b2bd420fbabff3ee0fb6e229 (diff) | |
download | mongo-e8134e5da39f19b661efb488bb392abbbf436c74.tar.gz |
SERVER-54049 Add translation phase for accumulator-style window functions
Diffstat (limited to 'src/mongo')
16 files changed, 187 insertions, 94 deletions
diff --git a/src/mongo/db/pipeline/SConscript b/src/mongo/db/pipeline/SConscript index 310479ecb3b..57911834dcf 100644 --- a/src/mongo/db/pipeline/SConscript +++ b/src/mongo/db/pipeline/SConscript @@ -262,6 +262,7 @@ pipelineEnv.Library( 'skip_and_limit.cpp', 'tee_buffer.cpp', 'window_function/partition_iterator.cpp', + 'window_function/window_function_exec.cpp', 'window_function/window_function_exec_removable_document.cpp', ], LIBDEPS=[ diff --git a/src/mongo/db/pipeline/accumulator_add_to_set.cpp b/src/mongo/db/pipeline/accumulator_add_to_set.cpp index cca39c44294..0608b006dd2 100644 --- a/src/mongo/db/pipeline/accumulator_add_to_set.cpp +++ b/src/mongo/db/pipeline/accumulator_add_to_set.cpp @@ -43,7 +43,8 @@ using boost::intrusive_ptr; using std::vector; REGISTER_ACCUMULATOR(addToSet, genericParseSingleExpressionAccumulator<AccumulatorAddToSet>); -REGISTER_WINDOW_FUNCTION(addToSet, window_function::ExpressionFromAccumulator::parse); +REGISTER_WINDOW_FUNCTION(addToSet, + window_function::ExpressionFromAccumulator<AccumulatorAddToSet>::parse); const char* AccumulatorAddToSet::getOpName() const { return "$addToSet"; diff --git a/src/mongo/db/pipeline/accumulator_avg.cpp b/src/mongo/db/pipeline/accumulator_avg.cpp index fff8cbc2e2b..a0d66419749 100644 --- a/src/mongo/db/pipeline/accumulator_avg.cpp +++ b/src/mongo/db/pipeline/accumulator_avg.cpp @@ -45,7 +45,7 @@ using boost::intrusive_ptr; REGISTER_ACCUMULATOR(avg, genericParseSingleExpressionAccumulator<AccumulatorAvg>); REGISTER_EXPRESSION(avg, ExpressionFromAccumulator<AccumulatorAvg>::parse); -REGISTER_WINDOW_FUNCTION(avg, window_function::ExpressionFromAccumulator::parse); +REGISTER_WINDOW_FUNCTION(avg, window_function::ExpressionFromAccumulator<AccumulatorAvg>::parse); const char* AccumulatorAvg::getOpName() const { return "$avg"; diff --git a/src/mongo/db/pipeline/accumulator_min_max.cpp b/src/mongo/db/pipeline/accumulator_min_max.cpp index 3326b9b8450..7bb38340bd7 100644 --- a/src/mongo/db/pipeline/accumulator_min_max.cpp +++ b/src/mongo/db/pipeline/accumulator_min_max.cpp @@ -44,8 +44,8 @@ REGISTER_ACCUMULATOR(max, genericParseSingleExpressionAccumulator<AccumulatorMax REGISTER_ACCUMULATOR(min, genericParseSingleExpressionAccumulator<AccumulatorMin>); REGISTER_EXPRESSION(max, ExpressionFromAccumulator<AccumulatorMax>::parse); REGISTER_EXPRESSION(min, ExpressionFromAccumulator<AccumulatorMin>::parse); -REGISTER_WINDOW_FUNCTION(max, window_function::ExpressionFromAccumulator::parse); -REGISTER_WINDOW_FUNCTION(min, window_function::ExpressionFromAccumulator::parse); +REGISTER_WINDOW_FUNCTION(max, window_function::ExpressionFromAccumulator<AccumulatorMax>::parse); +REGISTER_WINDOW_FUNCTION(min, window_function::ExpressionFromAccumulator<AccumulatorMin>::parse); const char* AccumulatorMinMax::getOpName() const { if (_sense == Sense::kMin) diff --git a/src/mongo/db/pipeline/accumulator_push.cpp b/src/mongo/db/pipeline/accumulator_push.cpp index 1c92ba2af6b..38fa0a8d03f 100644 --- a/src/mongo/db/pipeline/accumulator_push.cpp +++ b/src/mongo/db/pipeline/accumulator_push.cpp @@ -43,7 +43,7 @@ using boost::intrusive_ptr; using std::vector; REGISTER_ACCUMULATOR(push, genericParseSingleExpressionAccumulator<AccumulatorPush>); -REGISTER_WINDOW_FUNCTION(push, window_function::ExpressionFromAccumulator::parse); +REGISTER_WINDOW_FUNCTION(push, window_function::ExpressionFromAccumulator<AccumulatorPush>::parse); const char* AccumulatorPush::getOpName() const { return "$push"; diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp index 12e07b55775..2a708bbbf40 100644 --- a/src/mongo/db/pipeline/accumulator_std_dev.cpp +++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp @@ -45,8 +45,10 @@ REGISTER_ACCUMULATOR(stdDevPop, genericParseSingleExpressionAccumulator<Accumula REGISTER_ACCUMULATOR(stdDevSamp, genericParseSingleExpressionAccumulator<AccumulatorStdDevSamp>); REGISTER_EXPRESSION(stdDevPop, ExpressionFromAccumulator<AccumulatorStdDevPop>::parse); REGISTER_EXPRESSION(stdDevSamp, ExpressionFromAccumulator<AccumulatorStdDevSamp>::parse); -REGISTER_WINDOW_FUNCTION(stdDevPop, window_function::ExpressionFromAccumulator::parse); -REGISTER_WINDOW_FUNCTION(stdDevSamp, window_function::ExpressionFromAccumulator::parse); +REGISTER_WINDOW_FUNCTION(stdDevPop, + window_function::ExpressionFromAccumulator<AccumulatorStdDevPop>::parse); +REGISTER_WINDOW_FUNCTION(stdDevSamp, + window_function::ExpressionFromAccumulator<AccumulatorStdDevSamp>::parse); const char* AccumulatorStdDev::getOpName() const { return (_isSamp ? "$stdDevSamp" : "$stdDevPop"); diff --git a/src/mongo/db/pipeline/accumulator_sum.cpp b/src/mongo/db/pipeline/accumulator_sum.cpp index 175de17aabf..b4045333acb 100644 --- a/src/mongo/db/pipeline/accumulator_sum.cpp +++ b/src/mongo/db/pipeline/accumulator_sum.cpp @@ -46,7 +46,7 @@ using boost::intrusive_ptr; REGISTER_ACCUMULATOR(sum, genericParseSingleExpressionAccumulator<AccumulatorSum>); REGISTER_EXPRESSION(sum, ExpressionFromAccumulator<AccumulatorSum>::parse); -REGISTER_WINDOW_FUNCTION(sum, window_function::ExpressionFromAccumulator::parse); +REGISTER_WINDOW_FUNCTION(sum, window_function::ExpressionFromAccumulator<AccumulatorSum>::parse); const char* AccumulatorSum::getOpName() const { return "$sum"; diff --git a/src/mongo/db/pipeline/document_source_set_window_fields.cpp b/src/mongo/db/pipeline/document_source_set_window_fields.cpp index 9a4f1e0314e..eebfeafcfe0 100644 --- a/src/mongo/db/pipeline/document_source_set_window_fields.cpp +++ b/src/mongo/db/pipeline/document_source_set_window_fields.cpp @@ -268,40 +268,7 @@ boost::intrusive_ptr<DocumentSource> DocumentSourceInternalSetWindowFields::crea void DocumentSourceInternalSetWindowFields::initialize() { for (auto& wfs : _outputFields) { uassert(5397900, "Window function must be $sum", wfs.expr->getOpName() == "$sum"); - // TODO: SERVER-54340 Remove this check. - uassert(5397905, - "Window functions cannot set to dotted paths", - wfs.fieldName.find('.') == std::string::npos); - auto windowBounds = wfs.expr->bounds(); - stdx::visit( - visit_helper::Overloaded{ - [](const WindowBounds::DocumentBased& docBase) { - stdx::visit( - visit_helper::Overloaded{ - [](const WindowBounds::Unbounded) { /* pass */ }, - [](auto&& other) { - uasserted(5397904, - "Only 'unbounded' lower bound is currently supported"); - }}, - docBase.lower); - stdx::visit( - visit_helper::Overloaded{ - [](const WindowBounds::Current) { /* pass */ }, - [](auto&& other) { - uasserted(5397903, - "Only 'current' upper bound is currently supported"); - }}, - docBase.upper); - }, - [](const WindowBounds::RangeBased& rangeBase) { - uasserted(5397901, "Ranged based windows not currently supported"); - }, - [](const WindowBounds::TimeBased& timeBase) { - uasserted(5397902, "Time based windows are not currently supported"); - }}, - windowBounds.bounds); - _executableOutputs.push_back(ExecutableWindowFunction( - wfs.fieldName, AccumulatorSum::create(pExpCtx.get()), windowBounds, wfs.expr->input())); + _executableOutputs[wfs.fieldName] = WindowFunctionExec::create(&_iterator, wfs); } _init = true; } @@ -311,22 +278,28 @@ DocumentSource::GetNextResult DocumentSourceInternalSetWindowFields::doGetNext() initialize(); } - auto curStat = pSource->getNext(); - if (!curStat.isAdvanced()) { - return curStat; - } - auto curDoc = curStat.getDocument(); - if (_partitionBy) { - uassert(ErrorCodes::TypeMismatch, - "Cannot 'partitionBy' an expression of type array", - !_partitionBy->get()->evaluate(curDoc, &pExpCtx->variables).isArray()); + if (_eof) + return DocumentSource::GetNextResult::makeEOF(); + + // Populate the output document with the result from each window function. + MutableDocument outDoc(_iterator[0].get()); + for (auto&& [fieldName, function] : _executableOutputs) { + outDoc.setNestedField(fieldName, function->getNext()); } - MutableDocument outDoc(curDoc); - for (auto& output : _executableOutputs) { - // Currently only support unbounded windows and run on the merging shard -- we don't need - // to reset accumulators, merge states, or partition into multiple groups. - output.accumulator->process(output.inputExpr->evaluate(curDoc, &pExpCtx->variables), false); - outDoc.setNestedField(output.fieldName, output.accumulator->getValue(false)); + + // Advance the iterator and handle partition/EOF edge cases. + switch (_iterator.advance()) { + case PartitionIterator::AdvanceResult::kAdvanced: + break; + case PartitionIterator::AdvanceResult::kNewPartition: + // We've advanced to a new partition, reset the state of every function. + for (auto&& [_, function] : _executableOutputs) { + function->reset(); + } + break; + case PartitionIterator::AdvanceResult::kEOF: + _eof = true; + break; } return outDoc.freeze(); } diff --git a/src/mongo/db/pipeline/document_source_set_window_fields.h b/src/mongo/db/pipeline/document_source_set_window_fields.h index b5a31223c88..d293418a5bc 100644 --- a/src/mongo/db/pipeline/document_source_set_window_fields.h +++ b/src/mongo/db/pipeline/document_source_set_window_fields.h @@ -33,11 +33,15 @@ #include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/document_source_set_window_fields_gen.h" +#include "mongo/db/pipeline/window_function/partition_iterator.h" #include "mongo/db/pipeline/window_function/window_bounds.h" +#include "mongo/db/pipeline/window_function/window_function_exec.h" #include "mongo/db/pipeline/window_function/window_function_expression.h" namespace mongo { +class WindowFunctionExec; + struct WindowFunctionStatement { std::string fieldName; // top-level fieldname, not a path boost::intrusive_ptr<window_function::Expression> expr; @@ -53,22 +57,6 @@ struct WindowFunctionStatement { boost::optional<ExplainOptions::Verbosity> explain) const; }; -struct ExecutableWindowFunction { - std::string fieldName; - boost::intrusive_ptr<AccumulatorState> accumulator; - WindowBounds bounds; - boost::intrusive_ptr<Expression> inputExpr; - - ExecutableWindowFunction(std::string fieldName, - boost::intrusive_ptr<AccumulatorState> accumulator, - WindowBounds bounds, - boost::intrusive_ptr<Expression> input) - : fieldName(std::move(fieldName)), - accumulator(std::move(accumulator)), - bounds(std::move(bounds)), - inputExpr(std::move(input)) {} -}; - /** * $setWindowFields is an alias: it desugars to some combination of projection, sorting, * and $_internalSetWindowFields. @@ -106,7 +94,8 @@ public: : DocumentSource(kStageName, expCtx), _partitionBy(partitionBy), _sortBy(std::move(sortBy)), - _outputFields(std::move(outputFields)) {} + _outputFields(std::move(outputFields)), + _iterator(expCtx.get(), pSource, std::move(partitionBy)) {} StageConstraints constraints(Pipeline::SplitState pipeState) const final { return StageConstraints(StreamType::kBlocking, @@ -132,6 +121,11 @@ public: DocumentSource::GetNextResult doGetNext(); + void setSource(DocumentSource* source) final { + pSource = source; + _iterator.setSource(source); + } + private: DocumentSource::GetNextResult getNextInput(); void initialize(); @@ -139,8 +133,10 @@ private: boost::optional<boost::intrusive_ptr<Expression>> _partitionBy; boost::optional<SortPattern> _sortBy; std::vector<WindowFunctionStatement> _outputFields; - std::vector<ExecutableWindowFunction> _executableOutputs; + PartitionIterator _iterator; + StringMap<std::unique_ptr<WindowFunctionExec>> _executableOutputs; bool _init = false; + bool _eof = false; }; } // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/partition_iterator.cpp b/src/mongo/db/pipeline/window_function/partition_iterator.cpp index 435cb7a765b..9f2034c80f5 100644 --- a/src/mongo/db/pipeline/window_function/partition_iterator.cpp +++ b/src/mongo/db/pipeline/window_function/partition_iterator.cpp @@ -125,7 +125,7 @@ void PartitionIterator::getNextDocument() { auto doc = getNextRes.releaseDocument(); if (_partitionExpr) { - auto curKey = _partitionExpr->evaluate(doc, &_expCtx->variables); + auto curKey = (*_partitionExpr)->evaluate(doc, &_expCtx->variables); uassert(ErrorCodes::TypeMismatch, "Cannot 'partitionBy' an expression of type array", !curKey.isArray()); diff --git a/src/mongo/db/pipeline/window_function/partition_iterator.h b/src/mongo/db/pipeline/window_function/partition_iterator.h index bf7f0bd1097..3e645d8a768 100644 --- a/src/mongo/db/pipeline/window_function/partition_iterator.h +++ b/src/mongo/db/pipeline/window_function/partition_iterator.h @@ -42,7 +42,7 @@ class PartitionIterator { public: PartitionIterator(ExpressionContext* expCtx, DocumentSource* source, - boost::optional<const ExpressionFieldPath&> partitionExpr) + boost::optional<boost::intrusive_ptr<Expression>> partitionExpr) : _expCtx(expCtx), _source(source), _partitionExpr(partitionExpr), @@ -75,6 +75,13 @@ public: return _currentIndex; } + /** + * Sets the input DocumentSource for this iterator to 'source'. + */ + void setSource(DocumentSource* source) { + _source = source; + } + private: /** * Retrieves the next document from the prior stage and updates the state accordingly. @@ -98,7 +105,7 @@ private: ExpressionContext* _expCtx; DocumentSource* _source; - boost::optional<const ExpressionFieldPath&> _partitionExpr; + boost::optional<boost::intrusive_ptr<Expression>> _partitionExpr; std::vector<Document> _cache; int _currentIndex = 0; Value _partitionKey; diff --git a/src/mongo/db/pipeline/window_function/partition_iterator_test.cpp b/src/mongo/db/pipeline/window_function/partition_iterator_test.cpp index f0666f12f39..7c6f80b3031 100644 --- a/src/mongo/db/pipeline/window_function/partition_iterator_test.cpp +++ b/src/mongo/db/pipeline/window_function/partition_iterator_test.cpp @@ -84,7 +84,8 @@ TEST_F(PartitionIteratorTest, LookaheadOutOfRangeAccessNewPartition) { const auto mock = DocumentSourceMock::createForTest(docs, getExpCtx()); auto key = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "key", getExpCtx()->variablesParseState); - auto partIter = PartitionIterator(getExpCtx().get(), mock.get(), *key); + auto partIter = PartitionIterator( + getExpCtx().get(), mock.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)); ASSERT_DOCUMENT_EQ(docs[0].getDocument(), *partIter[0]); ASSERT_DOCUMENT_EQ(docs[1].getDocument(), *partIter[1]); ASSERT_FALSE(partIter[2]); @@ -98,7 +99,8 @@ TEST_F(PartitionIteratorTest, AdvanceMovesCurrent) { const auto mock = DocumentSourceMock::createForTest(docs, getExpCtx()); auto key = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "key", getExpCtx()->variablesParseState); - auto partIter = PartitionIterator(getExpCtx().get(), mock.get(), *key); + auto partIter = PartitionIterator( + getExpCtx().get(), mock.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)); ASSERT_DOCUMENT_EQ(docs[0].getDocument(), *partIter[0]); ASSERT_DOCUMENT_EQ(docs[1].getDocument(), *partIter[1]); ASSERT_FALSE(partIter[2]); @@ -116,7 +118,8 @@ TEST_F(PartitionIteratorTest, AdvanceOverPartitionBoundary) { const auto mock = DocumentSourceMock::createForTest(docs, getExpCtx()); auto key = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "key", getExpCtx()->variablesParseState); - auto partIter = PartitionIterator(getExpCtx().get(), mock.get(), *key); + auto partIter = PartitionIterator( + getExpCtx().get(), mock.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)); ASSERT_DOCUMENT_EQ(docs[0].getDocument(), *partIter[0]); // First advance to the final document in partition with key "1". ASSERT_ADVANCE_RESULT(PartitionIterator::AdvanceResult::kAdvanced, partIter.advance()); @@ -133,7 +136,8 @@ TEST_F(PartitionIteratorTest, AdvanceResultsInEof) { const auto mock = DocumentSourceMock::createForTest(docs, getExpCtx()); auto key = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "key", getExpCtx()->variablesParseState); - auto partIter = PartitionIterator(getExpCtx().get(), mock.get(), *key); + auto partIter = PartitionIterator( + getExpCtx().get(), mock.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)); ASSERT_DOCUMENT_EQ(docs[0].getDocument(), *partIter[0]); ASSERT_ADVANCE_RESULT(PartitionIterator::AdvanceResult::kEOF, partIter.advance()); @@ -149,7 +153,8 @@ TEST_F(PartitionIteratorTest, CurrentReturnsCorrectDocumentAsIteratorAdvances) { const auto mock = DocumentSourceMock::createForTest(docs, getExpCtx()); auto key = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "key", getExpCtx()->variablesParseState); - auto partIter = PartitionIterator(getExpCtx().get(), mock.get(), *key); + auto partIter = PartitionIterator( + getExpCtx().get(), mock.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)); ASSERT_DOCUMENT_EQ(docs[0].getDocument(), *partIter[0]); partIter.advance(); ASSERT_DOCUMENT_EQ(docs[1].getDocument(), *partIter[0]); @@ -162,7 +167,8 @@ TEST_F(PartitionIteratorTest, EmptyCollectionReturnsEOF) { const auto mock = DocumentSourceMock::createForTest(docs, getExpCtx()); auto key = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "key", getExpCtx()->variablesParseState); - auto partIter = PartitionIterator(getExpCtx().get(), mock.get(), *key); + auto partIter = PartitionIterator( + getExpCtx().get(), mock.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)); ASSERT_FALSE(partIter[0]); ASSERT_ADVANCE_RESULT(PartitionIterator::AdvanceResult::kEOF, partIter.advance()); } @@ -173,7 +179,8 @@ TEST_F(PartitionIteratorTest, PartitionByArrayErrs) { const auto mock = DocumentSourceMock::createForTest(docs, getExpCtx()); auto key = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "key", getExpCtx()->variablesParseState); - auto partIter = PartitionIterator(getExpCtx().get(), mock.get(), *key); + auto partIter = PartitionIterator( + getExpCtx().get(), mock.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)); ASSERT_DOCUMENT_EQ(docs[0].getDocument(), *partIter[0]); ASSERT_THROWS_CODE(*partIter[1], AssertionException, ErrorCodes::TypeMismatch); } @@ -184,7 +191,8 @@ TEST_F(PartitionIteratorTest, CurrentOffsetIsCorrectAfterDocumentsAreAccessed) { const auto mock = DocumentSourceMock::createForTest(docs, getExpCtx()); auto key = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "a", getExpCtx()->variablesParseState); - auto partIter = PartitionIterator(getExpCtx().get(), mock.get(), *key); + auto partIter = PartitionIterator( + getExpCtx().get(), mock.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)); ASSERT_EQ(0, partIter.getCurrentOffset()); auto doc = partIter[0]; partIter.advance(); diff --git a/src/mongo/db/pipeline/window_function/window_function_exec.cpp b/src/mongo/db/pipeline/window_function/window_function_exec.cpp new file mode 100644 index 00000000000..b4e0e96513d --- /dev/null +++ b/src/mongo/db/pipeline/window_function/window_function_exec.cpp @@ -0,0 +1,80 @@ +/** + * Copyright (C) 2021-present MongoDB, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the Server Side Public License, version 1, + * as published by MongoDB, Inc. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * Server Side Public License for more details. + * + * You should have received a copy of the Server Side Public License + * along with this program. If not, see + * <http://www.mongodb.com/licensing/server-side-public-license>. + * + * As a special exception, the copyright holders give permission to link the + * code of portions of this program with the OpenSSL library under certain + * conditions as described in each individual source file and distribute + * linked combinations including the program with the OpenSSL library. You + * must comply with the Server Side Public License in all respects for + * all of the code used other than as permitted herein. If you modify file(s) + * with this exception, you may extend this exception to your version of the + * file(s), but you are not obligated to do so. If you do not wish to do so, + * delete this exception statement from your version. If you delete this + * exception statement from all source files in the program, then also delete + * it in the license file. + */ + +#include "mongo/db/pipeline/window_function/window_function_exec.h" +#include "mongo/db/pipeline/window_function/window_function_exec_non_removable.h" + +namespace mongo { + +namespace { + +std::unique_ptr<WindowFunctionExec> translateDocumentWindow( + PartitionIterator* iter, + boost::intrusive_ptr<window_function::Expression> expr, + const WindowBounds::DocumentBased& bounds) { + uassert(5397904, + "Only 'unbounded' lower bound is currently supported", + stdx::holds_alternative<WindowBounds::Unbounded>(bounds.lower)); + uassert(5397903, + "Only 'current' upper bound is currently supported", + stdx::holds_alternative<WindowBounds::Current>(bounds.upper)); + + // A left unbounded window will always be non-removable regardless of the upper + // bound. + return std::make_unique<WindowFunctionExecNonRemovable<AccumulatorState>>( + iter, expr->input(), expr->buildAccumulatorOnly(), bounds.upper); +} + +} // namespace + +std::unique_ptr<WindowFunctionExec> WindowFunctionExec::create( + PartitionIterator* iter, const WindowFunctionStatement& functionStmt) { + uassert(5397905, + "Window functions cannot set to dotted paths", + functionStmt.fieldName.find('.') == std::string::npos); + + // Use a sentinel variable to avoid a compilation error when some cases of std::visit don't + // return a value. + std::unique_ptr<WindowFunctionExec> exec; + stdx::visit( + visit_helper::Overloaded{ + [&](const WindowBounds::DocumentBased& docBase) { + exec = translateDocumentWindow(iter, functionStmt.expr, docBase); + }, + [&](const WindowBounds::RangeBased& rangeBase) { + uasserted(5397901, "Ranged based windows not currently supported"); + }, + [&](const WindowBounds::TimeBased& timeBase) { + uasserted(5397902, "Time based windows are not currently supported"); + }}, + functionStmt.expr->bounds().bounds); + return exec; +} + +} // namespace mongo diff --git a/src/mongo/db/pipeline/window_function/window_function_exec.h b/src/mongo/db/pipeline/window_function/window_function_exec.h index 4dfcbdd8356..5217be555d8 100644 --- a/src/mongo/db/pipeline/window_function/window_function_exec.h +++ b/src/mongo/db/pipeline/window_function/window_function_exec.h @@ -32,6 +32,7 @@ #include <queue> #include "mongo/db/pipeline/document_source.h" +#include "mongo/db/pipeline/document_source_set_window_fields.h" #include "mongo/db/pipeline/expression.h" #include "mongo/db/pipeline/window_function/partition_iterator.h" #include "mongo/db/pipeline/window_function/window_bounds.h" @@ -39,6 +40,8 @@ namespace mongo { +struct WindowFunctionStatement; + /** * An interface for an executor class capable of evaluating a function over a given window * definition. The function must expose an accumulate-type interface and potentially a remove @@ -49,7 +52,14 @@ namespace mongo { */ class WindowFunctionExec { public: - WindowFunctionExec(PartitionIterator* iter) : _iter(iter){}; + /** + * Creates an appropriate WindowFunctionExec that is capable of evaluating the window function + * over the given bounds, both found within the WindowFunctionStatement. + */ + static std::unique_ptr<WindowFunctionExec> create(PartitionIterator* iter, + const WindowFunctionStatement& functionStmt); + + virtual ~WindowFunctionExec() = default; /** * Retrieve the next value computed by the window function. @@ -62,6 +72,8 @@ public: virtual void reset() = 0; protected: + WindowFunctionExec(PartitionIterator* iter) : _iter(iter){}; + PartitionIterator* _iter; }; diff --git a/src/mongo/db/pipeline/window_function/window_function_exec_test.cpp b/src/mongo/db/pipeline/window_function/window_function_exec_test.cpp index 8687f45c35f..1629c87716a 100644 --- a/src/mongo/db/pipeline/window_function/window_function_exec_test.cpp +++ b/src/mongo/db/pipeline/window_function/window_function_exec_test.cpp @@ -152,7 +152,8 @@ TEST_F(WindowFunctionExecNonRemovableTest, AccumulateOnlyWithMultiplePartitions) auto mock = DocumentSourceMock::createForTest(std::move(docs), getExpCtx()); auto key = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "key", getExpCtx()->variablesParseState); - auto iter = PartitionIterator{getExpCtx().get(), mock.get(), *key}; + auto iter = PartitionIterator( + getExpCtx().get(), mock.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)); auto input = ExpressionFieldPath::parse(getExpCtx().get(), "$a", getExpCtx()->variablesParseState); auto mgr = WindowFunctionExecNonRemovable<AccumulatorState>( @@ -328,7 +329,8 @@ TEST_F(WindowFunctionExecRemovableDocumentTest, CanResetFunction) { auto mock = DocumentSourceMock::createForTest(std::move(docs), getExpCtx()); auto key = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "key", getExpCtx()->variablesParseState); - auto iter = PartitionIterator{getExpCtx().get(), mock.get(), *key}; + auto iter = PartitionIterator{ + getExpCtx().get(), mock.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)}; auto input = ExpressionFieldPath::parse(getExpCtx().get(), "$a", getExpCtx()->variablesParseState); CollatorInterfaceMock collator = CollatorInterfaceMock::MockType::kToLowerString; @@ -355,7 +357,8 @@ TEST_F(WindowFunctionExecRemovableDocumentTest, CanResetFunction) { auto mockTwo = DocumentSourceMock::createForTest(std::move(docsTwo), getExpCtx()); auto keyTwo = ExpressionFieldPath::createPathFromString( getExpCtx().get(), "key", getExpCtx()->variablesParseState); - iter = PartitionIterator{getExpCtx().get(), mockTwo.get(), *key}; + iter = PartitionIterator{ + getExpCtx().get(), mockTwo.get(), boost::optional<boost::intrusive_ptr<Expression>>(key)}; input = ExpressionFieldPath::parse(getExpCtx().get(), "$a", getExpCtx()->variablesParseState); maxFunc = std::make_unique<WindowFunctionMax>(cmp); mgr = WindowFunctionExecRemovableDocument( diff --git a/src/mongo/db/pipeline/window_function/window_function_expression.h b/src/mongo/db/pipeline/window_function/window_function_expression.h index bf1fc38a55e..f02445314ee 100644 --- a/src/mongo/db/pipeline/window_function/window_function_expression.h +++ b/src/mongo/db/pipeline/window_function/window_function_expression.h @@ -29,6 +29,7 @@ #pragma once +#include "mongo/db/pipeline/accumulator.h" #include "mongo/db/pipeline/document_source.h" #include "mongo/db/pipeline/document_source_set_window_fields_gen.h" #include "mongo/db/pipeline/window_function/window_bounds.h" @@ -92,10 +93,13 @@ public: virtual boost::intrusive_ptr<::mongo::Expression> input() const = 0; + virtual boost::intrusive_ptr<AccumulatorState> buildAccumulatorOnly() const = 0; + private: static StringMap<Parser> parserMap; }; +template <typename NonRemovableType> class ExpressionFromAccumulator : public Expression { public: static boost::intrusive_ptr<Expression> parse(BSONElement elem, @@ -114,8 +118,8 @@ public: << " found an unknown argument: " << arg.fieldNameStringData(), allowedFields.find(arg.fieldNameStringData()) != allowedFields.end()); } - return make_intrusive<ExpressionFromAccumulator>( - std::move(accumulatorName), std::move(input), std::move(bounds)); + return make_intrusive<ExpressionFromAccumulator<NonRemovableType>>( + expCtx, std::move(accumulatorName), std::move(input), std::move(bounds)); } Value serialize(boost::optional<ExplainOptions::Verbosity> explain) const final { MutableDocument args; @@ -128,10 +132,12 @@ public: }}; } - ExpressionFromAccumulator(std::string accumulatorName, + ExpressionFromAccumulator(ExpressionContext* expCtx, + std::string accumulatorName, boost::intrusive_ptr<::mongo::Expression> input, WindowBounds bounds) - : _accumulatorName(std::move(accumulatorName)), + : _expCtx(expCtx), + _accumulatorName(std::move(accumulatorName)), _input(std::move(input)), _bounds(std::move(bounds)) {} @@ -147,8 +153,12 @@ public: return _bounds; } + boost::intrusive_ptr<AccumulatorState> buildAccumulatorOnly() const { + return NonRemovableType::create(_expCtx); + } private: + ExpressionContext* _expCtx; std::string _accumulatorName; boost::intrusive_ptr<::mongo::Expression> _input; WindowBounds _bounds; |