summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author: Ethan Zhang <ethan.zhang@mongodb.com> 2021-09-08 20:04:47 +0000
committer: Evergreen Agent <no-reply@evergreen.mongodb.com> 2021-09-08 20:56:11 +0000
commit: 810acb9723a4e74897ac446ed755221faa56e650 (patch)
tree: 31bb9940e6c6a515cac1de9bdc742716c87965a5
parent: 46b7bdbc6a3d9a3e181bf5a1b8c2739c0dc80838 (diff)
download: mongo-810acb9723a4e74897ac446ed755221faa56e650.tar.gz
SERVER-59035 Flip sbeGroupCompatible flag for unsupported accumulators
-rw-r--r-- src/mongo/db/pipeline/accumulation_statement.h | 11
-rw-r--r-- src/mongo/db/pipeline/accumulator_js_reduce.cpp | 2
-rw-r--r-- src/mongo/db/pipeline/accumulator_merge_objects.cpp | 4
-rw-r--r-- src/mongo/db/pipeline/accumulator_multi.cpp | 2
-rw-r--r-- src/mongo/db/pipeline/accumulator_std_dev.cpp | 6
-rw-r--r-- src/mongo/db/pipeline/document_source_group.cpp | 2
-rw-r--r-- src/mongo/db/pipeline/document_source_group.h | 7
-rw-r--r-- src/mongo/db/pipeline/expression_context.h | 5
-rw-r--r-- src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp | 97
9 files changed, 132 insertions, 4 deletions
diff --git a/src/mongo/db/pipeline/accumulation_statement.h b/src/mongo/db/pipeline/accumulation_statement.h
index 7144ada54fc..e473d5d7ea1 100644
--- a/src/mongo/db/pipeline/accumulation_statement.h
+++ b/src/mongo/db/pipeline/accumulation_statement.h
@@ -142,6 +142,17 @@ AccumulationExpression genericParseSingleExpressionAccumulator(ExpressionContext
}
/**
+ * A parser for any accumulator that SBE does not support and that only takes a single expression
+ * as an argument. It first sets 'expCtx->sbeGroupCompatible' to false, then delegates to
+ * genericParseSingleExpressionAccumulator<AccName> and returns that call's result unchanged.
+ * Returns the expression to be evaluated by the accumulator and an AccumulatorState::Factory.
+ *
+ * The flag is cleared before delegating, so it is already false by the time the delegated parse
+ * runs, whatever the outcome of that parse.
+ */
+template <class AccName>
+AccumulationExpression genericParseSBEUnsupportedSingleExpressionAccumulator(
+ ExpressionContext* const expCtx, BSONElement elem, VariablesParseState vps) {
+ expCtx->sbeGroupCompatible = false;
+ return genericParseSingleExpressionAccumulator<AccName>(expCtx, elem, vps);
+}
+
+/**
* A parser that desugars { $count: {} } to { $sum: 1 }.
*/
inline AccumulationExpression parseCountAccumulator(ExpressionContext* const expCtx,
diff --git a/src/mongo/db/pipeline/accumulator_js_reduce.cpp b/src/mongo/db/pipeline/accumulator_js_reduce.cpp
index 4e23b219692..71075978d6d 100644
--- a/src/mongo/db/pipeline/accumulator_js_reduce.cpp
+++ b/src/mongo/db/pipeline/accumulator_js_reduce.cpp
@@ -44,6 +44,7 @@ AccumulationExpression AccumulatorInternalJsReduce::parseInternalJsReduce(
elem.type() == BSONType::Object);
BSONObj obj = elem.embeddedObject();
+ expCtx->sbeGroupCompatible = false;
std::string funcSource;
boost::intrusive_ptr<Expression> argument;
@@ -265,6 +266,7 @@ AccumulationExpression AccumulatorJs::parse(ExpressionContext* const expCtx,
elem.type() == BSONType::Object);
BSONObj obj = elem.embeddedObject();
+ expCtx->sbeGroupCompatible = false;
std::string init, accumulate, merge;
boost::optional<std::string> finalize;
boost::intrusive_ptr<Expression> initArgs, accumulateArgs;
diff --git a/src/mongo/db/pipeline/accumulator_merge_objects.cpp b/src/mongo/db/pipeline/accumulator_merge_objects.cpp
index 43e2979e3ab..fc69a1104b8 100644
--- a/src/mongo/db/pipeline/accumulator_merge_objects.cpp
+++ b/src/mongo/db/pipeline/accumulator_merge_objects.cpp
@@ -41,8 +41,8 @@ using boost::intrusive_ptr;
/* ------------------------- AccumulatorMergeObjects ----------------------------- */
-REGISTER_ACCUMULATOR(mergeObjects,
- genericParseSingleExpressionAccumulator<AccumulatorMergeObjects>);
+REGISTER_ACCUMULATOR(
+ mergeObjects, genericParseSBEUnsupportedSingleExpressionAccumulator<AccumulatorMergeObjects>);
REGISTER_STABLE_EXPRESSION(mergeObjects, ExpressionFromAccumulator<AccumulatorMergeObjects>::parse);
intrusive_ptr<AccumulatorState> AccumulatorMergeObjects::create(ExpressionContext* const expCtx) {
diff --git a/src/mongo/db/pipeline/accumulator_multi.cpp b/src/mongo/db/pipeline/accumulator_multi.cpp
index 4dc6c807dce..78bd99e740d 100644
--- a/src/mongo/db/pipeline/accumulator_multi.cpp
+++ b/src/mongo/db/pipeline/accumulator_multi.cpp
@@ -173,6 +173,7 @@ template <MinMaxSense s>
AccumulationExpression AccumulatorMinMaxN::parseMinMaxN(ExpressionContext* const expCtx,
BSONElement elem,
VariablesParseState vps) {
+ expCtx->sbeGroupCompatible = false;
auto name = [] {
if constexpr (s == MinMaxSense::kMin) {
return AccumulatorMinN::getName();
@@ -269,6 +270,7 @@ template <FirstLastSense v>
AccumulationExpression AccumulatorFirstLastN::parseFirstLastN(ExpressionContext* const expCtx,
BSONElement elem,
VariablesParseState vps) {
+ expCtx->sbeGroupCompatible = false;
auto name = [] {
if constexpr (v == Sense::kFirst) {
return AccumulatorFirstN::getName();
diff --git a/src/mongo/db/pipeline/accumulator_std_dev.cpp b/src/mongo/db/pipeline/accumulator_std_dev.cpp
index ce2e5ac3e77..126a02af705 100644
--- a/src/mongo/db/pipeline/accumulator_std_dev.cpp
+++ b/src/mongo/db/pipeline/accumulator_std_dev.cpp
@@ -42,8 +42,10 @@
namespace mongo {
using boost::intrusive_ptr;
-REGISTER_ACCUMULATOR(stdDevPop, genericParseSingleExpressionAccumulator<AccumulatorStdDevPop>);
-REGISTER_ACCUMULATOR(stdDevSamp, genericParseSingleExpressionAccumulator<AccumulatorStdDevSamp>);
+REGISTER_ACCUMULATOR(stdDevPop,
+ genericParseSBEUnsupportedSingleExpressionAccumulator<AccumulatorStdDevPop>);
+REGISTER_ACCUMULATOR(stdDevSamp,
+ genericParseSBEUnsupportedSingleExpressionAccumulator<AccumulatorStdDevSamp>);
REGISTER_STABLE_EXPRESSION(stdDevPop, ExpressionFromAccumulator<AccumulatorStdDevPop>::parse);
REGISTER_STABLE_EXPRESSION(stdDevSamp, ExpressionFromAccumulator<AccumulatorStdDevSamp>::parse);
REGISTER_REMOVABLE_WINDOW_FUNCTION(stdDevPop, AccumulatorStdDevPop, WindowFunctionStdDevPop);
diff --git a/src/mongo/db/pipeline/document_source_group.cpp b/src/mongo/db/pipeline/document_source_group.cpp
index 6f9b20fdf5a..6f85a7c84a2 100644
--- a/src/mongo/db/pipeline/document_source_group.cpp
+++ b/src/mongo/db/pipeline/document_source_group.cpp
@@ -470,6 +470,7 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::createFromBson(
BSONObj groupObj(elem.Obj());
BSONObjIterator groupIterator(groupObj);
VariablesParseState vps = expCtx->variablesParseState;
+ expCtx->sbeGroupCompatible = true;
while (groupIterator.more()) {
BSONElement groupField(groupIterator.next());
StringData pFieldName = groupField.fieldNameStringData();
@@ -490,6 +491,7 @@ intrusive_ptr<DocumentSource> DocumentSourceGroup::createFromBson(
groupStage->_memoryTracker.set(pFieldName, 0);
}
}
+ groupStage->_sbeCompatible = expCtx->sbeGroupCompatible;
uassert(
15955, "a group specification must include an _id", !groupStage->_idExpressions.empty());
diff --git a/src/mongo/db/pipeline/document_source_group.h b/src/mongo/db/pipeline/document_source_group.h
index 4a00a4fea7b..1e2e1c1d3d4 100644
--- a/src/mongo/db/pipeline/document_source_group.h
+++ b/src/mongo/db/pipeline/document_source_group.h
@@ -189,6 +189,11 @@ public:
*/
size_t getMaxMemoryUsageBytes() const;
+ /**
+ * Returns true if this $group can be pushed down to SBE. The underlying flag is copied from
+ * 'expCtx->sbeGroupCompatible' by createFromBson() after the fields of the group specification
+ * have been iterated.
+ */
+ bool sbeCompatible() const {
+ return _sbeCompatible;
+ }
+
protected:
GetNextResult doGetNext() final;
void doDispose() final;
@@ -289,6 +294,8 @@ private:
std::unique_ptr<Sorter<Value, Value>::Iterator> _sorterIterator;
std::pair<Value, Value> _firstPartOfNextGroup;
+
+ // Whether this $group can be pushed down to SBE. Defaults to false (the conservative answer) so
+ // that sbeCompatible() never reads an indeterminate value on a construction path that does not
+ // assign the flag explicitly.
+ bool _sbeCompatible = false;
};
} // namespace mongo
diff --git a/src/mongo/db/pipeline/expression_context.h b/src/mongo/db/pipeline/expression_context.h
index 8a5b82f0254..88340b8b846 100644
--- a/src/mongo/db/pipeline/expression_context.h
+++ b/src/mongo/db/pipeline/expression_context.h
@@ -410,6 +410,11 @@ public:
// SBE expressions.
bool sbeCompatible = true;
+ // True if all accumulators in the $group stage currently being parsed using this expression
+ // context can be translated into equivalent SBE expressions. This value is transient and gets
+ // reset for every $group stage we parse. Each $group stage keeps its own per-stage flag.
+ bool sbeGroupCompatible = true;
+
// These fields can be used in a context when API version validations were not enforced during
// parse time (Example creating a view or validator), but needs to be enforce while querying
// later.
diff --git a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp
index 2147271889b..fcaab1cf532 100644
--- a/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp
+++ b/src/mongo/db/query/sbe_stage_builder_accumulator_test.cpp
@@ -27,7 +27,10 @@
* it in the license file.
*/
+#include <fmt/printf.h>
+
#include "mongo/db/exec/sbe/util/debug_print.h"
+#include "mongo/db/pipeline/document_source_group.h"
#include "mongo/db/pipeline/expression_context_for_test.h"
#include "mongo/db/query/canonical_query.h"
#include "mongo/db/query/collation/collator_interface_mock.h"
@@ -312,6 +315,42 @@ protected:
<< " but got set: " << std::make_pair(aggregatedTag, aggregatedSet);
}
+    /**
+     * Parses 'rawPipeline' with 'expCtx' and, for every stage of the resulting pipeline, checks
+     * that the stage's sbeCompatible() value agrees with what the SBE accumulator builder reports.
+     * Every stage in 'rawPipeline' is expected to be a $group stage.
+     */
+    void runSbeGroupCompatibleFlagTest(const std::vector<BSONObj>& rawPipeline,
+                                       boost::intrusive_ptr<ExpressionContext>& expCtx) {
+        // Building a DocumentSourceGroup parses its AccumulationExpressions; the parsers of
+        // accumulators that SBE does not support set expCtx->sbeGroupCompatible to false.
+        auto pipeline = Pipeline::parse(rawPipeline, expCtx);
+
+        sbe::RuntimeEnvironment env;
+        auto state = makeStageBuilderState(&env);
+        for (const auto& stage : pipeline->getSources()) {
+            auto* groupStage = dynamic_cast<DocumentSourceGroup*>(stage.get());
+            ASSERT_TRUE(groupStage);
+
+            // Work out the expected flag independently of the parser: it must be false as soon as
+            // one accumulator has no registered SBE accumulator builder function.
+            bool expectedCompatible = true;
+            for (const AccumulationStatement& accStmt : groupStage->getAccumulatedFields()) {
+                stage_builder::EvalStage inputStage;
+                auto [argExpr, argStage] = stage_builder::buildArgument(
+                    state, accStmt, std::move(inputStage), 0, kEmptyPlanNodeId);
+                try {
+                    auto [builtExprs, builtStage] = stage_builder::buildAccumulator(
+                        state, accStmt, std::move(argStage), std::move(argExpr), kEmptyPlanNodeId);
+                } catch (const DBException& ex) {
+                    // Error 5754701 is the only failure accepted here; it marks the accumulator
+                    // as unsupported in SBE, so the stage's flag is expected to be false.
+                    ASSERT_EQ(5754701, ex.code());
+                    expectedCompatible = false;
+                    break;
+                }
+            }
+            ASSERT_EQ(expectedCompatible, groupStage->sbeCompatible());
+        }
+    }
+
private:
// The slot id generator should be shared across all stages via the 'SlotBasedStageBuilder' but
// for now we will workaround by creating a generator with large enough starting id.
@@ -1026,6 +1065,7 @@ TEST_F(SbeAccumulatorBuilderTest, SumAccumulatorTranslationTwoGroupByTest) {
};
runAggregationWithGroupByTest("{x: {$sum: '$b'}}", docs, {"$a", "$c"}, BSON_ARRAY(20));
}
+
TEST_F(SbeAccumulatorBuilderTest, AddToSetAccumulatorTranslationSingleDoc) {
auto docs = std::vector<BSONArray>{BSON_ARRAY(BSON("a" << 1 << "b" << 1))};
runAddToSetTest("{x: {$addToSet: '$b'}}", docs, BSON_ARRAY(1));
@@ -1151,4 +1191,61 @@ TEST_F(SbeAccumulatorBuilderTest, PushAccumulatorTranslationVariousTypes) {
BSON_ARRAY(BSON_ARRAY(42 << 4.2 << true << strVal << bsonObj << bsonArr)));
}
+/**
+ * A stub accumulator, named "$incompatible", that stands in for an accumulator SBE cannot
+ * translate. It ignores every input and always yields the value 'true'.
+ */
+class AccumulatorSBEIncompatible final : public AccumulatorState {
+public:
+    static constexpr auto kName = "$incompatible"_sd;
+
+    explicit AccumulatorSBEIncompatible(ExpressionContext* expCtx) : AccumulatorState(expCtx) {}
+
+    static boost::intrusive_ptr<AccumulatorState> create(ExpressionContext* expCtx) {
+        return new AccumulatorSBEIncompatible(expCtx);
+    }
+
+    const char* getOpName() const final {
+        return kName.rawData();
+    }
+
+    // Inputs are discarded; this accumulator keeps no state.
+    void processInternal(const Value& input, bool merging) final {}
+
+    Value getValue(bool toBeMerged) final {
+        return Value(true);
+    }
+
+    void reset() final {}
+};
+
+// Registered through the SBE-unsupported parser, which sets expCtx->sbeGroupCompatible to false
+// whenever "$incompatible" is parsed.
+REGISTER_ACCUMULATOR(
+    incompatible,
+    genericParseSBEUnsupportedSingleExpressionAccumulator<AccumulatorSBEIncompatible>);
+
+// Checks that the per-stage SBE compatibility flag of a $group matches what the SBE accumulator
+// builder supports, first for each $group specification on its own and then for all of them parsed
+// together as a single pipeline sharing one expression context.
+TEST_F(SbeAccumulatorBuilderTest, SbeGroupCompatibleFlag) {
+    const std::vector<std::string> testCases = {
+        "agg: {$addToSet: \"$item\"}",
+        "agg: {$avg: \"$quantity\"}",
+        "agg: {$first: \"$item\"}",
+        "agg: {$last: \"$item\"}",
+        // TODO (SERVER-51541): Uncomment the following two test cases when $object support is
+        // added to SBE.
+        // "agg: {$_internalJsReduce: {data: {k: \"$word\", v: \"$val\"}, eval: \"null\"}}",
+        //
+        // R"'(agg: {$accumulator: {init: "a", accumulate: "b", accumulateArgs: ["$copies"], merge:
+        // "c", lang: "js"}})'",
+        "agg: {$mergeObjects: \"$item\"}",
+        "agg: {$min: \"$item\"}",
+        "agg: {$max: \"$item\"}",
+        "agg: {$push: \"$item\"}",
+        "agg: {$stdDevPop: \"$item\"}",
+        "agg: {$stdDevSamp: \"$item\"}",
+        "agg: {$sum: \"$item\"}",
+        // All supported case.
+        "agg1: {$sum: \"$item\"}, agg2: {$max: \"$item\"}, agg3: {$avg: \"$quantity\"}",
+        // Mix of supported/unsupported accumulators.
+        "agg1: {$sum: \"$item\"}, agg2: {$incompatible: \"$item\"}, agg3: {$avg: \"$a\"}",
+        "agg1: {$incompatible: \"$item\"}, agg2: {$min: \"$item\"}, agg3: {$avg: \"$quantity\"}",
+    };
+    boost::intrusive_ptr<ExpressionContext> expCtx(new ExpressionContextForTest());
+    std::vector<BSONObj> rawPipelines;
+    rawPipelines.reserve(testCases.size());
+    // Bind by const reference: the specifications are only read, so copying each string is waste.
+    for (const auto& testCase : testCases) {
+        auto groupObj = fromjson(fmt::sprintf("{$group: {%s, _id: null}}", testCase));
+        rawPipelines.push_back(groupObj);
+        runSbeGroupCompatibleFlagTest(makeVector(groupObj), expCtx);
+    }
+    // Parse every $group as one pipeline with the same expression context, so that each stage's
+    // flag is checked after other stages have been parsed before it.
+    runSbeGroupCompatibleFlagTest(rawPipelines, expCtx);
+}
+
} // namespace mongo